You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

345 lines
13 KiB

//===- LinalgToLLVM.cpp - conversion from Linalg to LLVM dialect ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
#include "../PassDetail.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/ErrorHandling.h"
using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::LLVM;
using namespace mlir::linalg;
// Short EDSC builder aliases for LLVM dialect ops, used below inside an
// edsc::ScopedContext. The ValueBuilder aliases are consumed as Values in this
// file (e.g. `Value desc = llvm_undef(...)`); the OperationBuilder aliases
// (llvm_call, llvm_store, llvm_return) presumably yield the created op rather
// than a Value -- confirm against the EDSC builder headers.
// NOTE(review): llvm_bitcast, llvm_gep, llvm_call, llvm_load, llvm_store,
// llvm_ptrtoint, llvm_urem, llvm_alloca and llvm_return are not referenced in
// the code visible here; consider pruning them if nothing else uses them.
using llvm_add = ValueBuilder<LLVM::AddOp>;
using llvm_bitcast = ValueBuilder<LLVM::BitcastOp>;
using llvm_constant = ValueBuilder<LLVM::ConstantOp>;
using llvm_extractvalue = ValueBuilder<LLVM::ExtractValueOp>;
using llvm_gep = ValueBuilder<LLVM::GEPOp>;
using llvm_insertvalue = ValueBuilder<LLVM::InsertValueOp>;
using llvm_call = OperationBuilder<LLVM::CallOp>;
using llvm_icmp = ValueBuilder<LLVM::ICmpOp>;
using llvm_load = ValueBuilder<LLVM::LoadOp>;
using llvm_store = OperationBuilder<LLVM::StoreOp>;
using llvm_select = ValueBuilder<LLVM::SelectOp>;
using llvm_mul = ValueBuilder<LLVM::MulOp>;
using llvm_ptrtoint = ValueBuilder<LLVM::PtrToIntOp>;
using llvm_sub = ValueBuilder<LLVM::SubOp>;
using llvm_undef = ValueBuilder<LLVM::UndefOp>;
using llvm_urem = ValueBuilder<LLVM::URemOp>;
using llvm_alloca = ValueBuilder<LLVM::AllocaOp>;
using llvm_return = OperationBuilder<LLVM::ReturnOp>;
/// Returns an LLVM pointer type whose pointee is `lowering`'s conversion of
/// the element type of `containerType`.
///
/// `T` is any type exposing `getElementType()` (e.g. a shaped type); the
/// converted element type is required to be an LLVMType.
template <typename T>
static LLVMType getPtrToElementType(T containerType,
                                    LLVMTypeConverter &lowering) {
  auto elementType = containerType.getElementType();
  auto llvmElementType =
      lowering.convertType(elementType).template cast<LLVMType>();
  return llvmElementType.getPointerTo();
}
/// Lowers a linalg range type to its LLVM dialect descriptor: a struct of
/// three 64-bit integers holding, in order, the lower bound, the upper bound
/// and the step.
///
///   struct {
///     int64_t min;
///     int64_t max;
///     int64_t step;
///   };
static Type convertRangeType(RangeType t, LLVMTypeConverter &converter) {
  Type i64 = IntegerType::get(64, t.getContext());
  auto llvmI64 = converter.convertType(i64).cast<LLVM::LLVMType>();
  return LLVMType::getStructTy(llvmI64, llvmI64, llvmI64);
}
namespace {
/// EDSC-compatible wrapper for MemRefDescriptor.
///
/// Every accessor forwards to the wrapped MemRefDescriptor, supplying the
/// builder and location returned by ScopedContext::getBuilderRef() and
/// ScopedContext::getLocation(). Instances are therefore only usable while an
/// edsc::ScopedContext is active.
class BaseViewConversionHelper {
public:
  /// Creates a new descriptor of LLVM type `type`, initialized through
  /// MemRefDescriptor::undef. Explicit so a Type is never silently turned
  /// into a descriptor.
  explicit BaseViewConversionHelper(Type type)
      : d(MemRefDescriptor::undef(rewriter(), loc(), type)) {}
  /// Wraps an existing descriptor value `v`. Explicit for the same reason.
  explicit BaseViewConversionHelper(Value v) : d(v) {}

  /// Wrappers around MemRefDescriptor that use EDSC builder and location.
  Value allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); }
  void setAllocatedPtr(Value v) { d.setAllocatedPtr(rewriter(), loc(), v); }
  Value alignedPtr() { return d.alignedPtr(rewriter(), loc()); }
  void setAlignedPtr(Value v) { d.setAlignedPtr(rewriter(), loc(), v); }
  Value offset() { return d.offset(rewriter(), loc()); }
  void setOffset(Value v) { d.setOffset(rewriter(), loc(), v); }
  Value size(unsigned i) { return d.size(rewriter(), loc(), i); }
  void setSize(unsigned i, Value v) { d.setSize(rewriter(), loc(), i, v); }
  void setConstantSize(unsigned i, int64_t v) {
    d.setConstantSize(rewriter(), loc(), i, v);
  }
  Value stride(unsigned i) { return d.stride(rewriter(), loc(), i); }
  void setStride(unsigned i, Value v) { d.setStride(rewriter(), loc(), i, v); }
  void setConstantStride(unsigned i, int64_t v) {
    d.setConstantStride(rewriter(), loc(), i, v);
  }

  /// Implicit on purpose: lets the helper be passed wherever a Value is
  /// expected, e.g. `rewriter.replaceOp(op, {desc})`.
  operator Value() { return d; }

private:
  OpBuilder &rewriter() { return ScopedContext::getBuilderRef(); }
  Location loc() { return ScopedContext::getLocation(); }

  /// The wrapped descriptor.
  MemRefDescriptor d;
};
/// Lowers linalg.range to a fresh range descriptor: an undef value of the
/// struct type produced by convertRangeType. The converted min, max and step
/// operands are inserted into it at positions 0, 1 and 2.
class RangeOpConversion : public ConvertToLLVMPattern {
public:
  explicit RangeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
      : ConvertToLLVMPattern(RangeOp::getOperationName(), context, lowering_) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    RangeType rangeTy = cast<RangeOp>(op).getType().cast<RangeType>();
    Type llvmRangeTy = convertRangeType(rangeTy, *getTypeConverter());

    edsc::ScopedContext scope(rewriter, op->getLoc());
    RangeOpAdaptor adaptor(operands);
    // Field order must match the struct layout: {min, max, step}.
    Value fields[] = {adaptor.min(), adaptor.max(), adaptor.step()};
    Value result = llvm_undef(llvmRangeTy);
    for (int64_t position = 0; position < 3; ++position)
      result = llvm_insertvalue(result, fields[position],
                                rewriter.getI64ArrayAttr(position));

    rewriter.replaceOp(op, result);
    return success();
  }
};
/// Lowers linalg.reshape to a new memref descriptor of the result type.
///
/// The allocated pointer, aligned pointer and offset are copied from the
/// source descriptor. Sizes and strides are then written as constants taken
/// from the result type. The pattern only matches when the result type has a
/// static shape and its strides can be computed and are all static;
/// otherwise it fails.
class ReshapeOpConversion : public ConvertToLLVMPattern {
public:
  explicit ReshapeOpConversion(MLIRContext *context,
                               LLVMTypeConverter &lowering_)
      : ConvertToLLVMPattern(ReshapeOp::getOperationName(), context,
                             lowering_) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType resultTy = cast<ReshapeOp>(op).getResultType();
    if (!resultTy.hasStaticShape())
      return failure();

    // The offset is required by the query but unused: it is copied from the
    // source descriptor instead.
    int64_t offset;
    SmallVector<int64_t, 4> strides;
    if (failed(getStridesAndOffset(resultTy, strides, offset)))
      return failure();
    bool hasDynamicStride = llvm::any_of(strides, [](int64_t stride) {
      return ShapedType::isDynamicStrideOrOffset(stride);
    });
    if (hasDynamicStride)
      return failure();

    edsc::ScopedContext scope(rewriter, op->getLoc());
    ReshapeOpAdaptor adaptor(operands);
    BaseViewConversionHelper source(adaptor.src());
    BaseViewConversionHelper result(typeConverter->convertType(resultTy));
    result.setAllocatedPtr(source.allocatedPtr());
    result.setAlignedPtr(source.alignedPtr());
    result.setOffset(source.offset());

    // All sizes first, then all strides.
    ArrayRef<int64_t> shape = resultTy.getShape();
    for (unsigned i = 0, e = shape.size(); i < e; ++i)
      result.setConstantSize(i, shape[i]);
    for (unsigned i = 0, e = strides.size(); i < e; ++i)
      result.setConstantStride(i, strides[i]);

    rewriter.replaceOp(op, {result});
    return success();
  }
};
/// Conversion pattern that transforms a linalg.slice op into:
///   1. An "undef" value for the ViewDescriptor.
///   2. Updates to the ViewDescriptor to introduce the data ptr, offset, size
///      and stride corresponding to the region of memory within the bounds of
///      the parent view.
/// The linalg.slice op is replaced by the resulting descriptor value.
///
/// For each indexing i of the base view, with base stride s_i:
///   - The result offset accumulates min_i * s_i. Here min_i is the index
///     itself for an index operand, or the range's `min` field for a range
///     operand.
///   - Only range operands produce a result dimension; index operands are
///     projected away. The size is (min(max_i, baseSize_i) - min_i), clamped
///     below at 0. The stride is s_i * step_i.
/// NOTE(review): the size is not divided by the range step, so for step > 1
/// the resulting view extends past `max`. Confirm this is intended.
class SliceOpConversion : public ConvertToLLVMPattern {
public:
  explicit SliceOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
      : ConvertToLLVMPattern(SliceOp::getOperationName(), context, lowering_) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    edsc::ScopedContext context(rewriter, op->getLoc());
    SliceOpAdaptor adaptor(operands);
    BaseViewConversionHelper baseDesc(adaptor.view());
    auto sliceOp = cast<SliceOp>(op);
    auto memRefType = sliceOp.getBaseViewType();
    auto int64Ty = typeConverter->convertType(rewriter.getIntegerType(64))
                       .cast<LLVM::LLVMType>();
    BaseViewConversionHelper desc(
        typeConverter->convertType(sliceOp.getShapedType()));

    // TODO: extract sizes and emit asserts.
    // Read each base stride once; they feed both the offset and the result
    // strides.
    SmallVector<Value, 4> strides(memRefType.getRank());
    for (int i = 0, e = memRefType.getRank(); i < e; ++i)
      strides[i] = baseDesc.stride(i);

    // Builds the position attribute used by llvm.extractvalue below.
    auto pos = [&rewriter](ArrayRef<int64_t> values) {
      return rewriter.getI64ArrayAttr(values);
    };

    // Compute base offset.
    Value baseOffset = baseDesc.offset();
    for (int i = 0, e = memRefType.getRank(); i < e; ++i) {
      Value indexing = adaptor.indexings()[i];
      Value min = indexing;
      if (sliceOp.indexing(i).getType().isa<RangeType>())
        min = llvm_extractvalue(int64Ty, indexing, pos(0));
      baseOffset = llvm_add(baseOffset, llvm_mul(min, strides[i]));
    }

    // Insert the base and aligned pointers.
    desc.setAllocatedPtr(baseDesc.allocatedPtr());
    desc.setAlignedPtr(baseDesc.alignedPtr());
    // Insert base offset.
    desc.setOffset(baseOffset);

    // Corner case: with no sizes or strides, the descriptor is already
    // complete.
    if (sliceOp.getShapedType().getRank() == 0) {
      rewriter.replaceOp(op, {desc});
      return success();
    }

    Value zero = llvm_constant(
        int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
    // Compute and insert view sizes (max - min along the range) and strides.
    // Skip the non-range operands as they will be projected away from the view.
    unsigned numNewDims = 0;
    for (auto en : llvm::enumerate(sliceOp.indexings())) {
      if (!en.value().getType().isa<RangeType>())
        continue;

      // `dim` indexes the base view; `numNewDims` indexes the result view.
      unsigned dim = en.index();
      Value rangeDescriptor = adaptor.indexings()[dim];
      Value min = llvm_extractvalue(int64Ty, rangeDescriptor, pos(0));
      Value max = llvm_extractvalue(int64Ty, rangeDescriptor, pos(1));
      Value step = llvm_extractvalue(int64Ty, rangeDescriptor, pos(2));
      Value baseSize = baseDesc.size(dim);
      // Bound upper by base view upper bound.
      max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max,
                        baseSize);
      Value size = llvm_sub(max, min);
      // Bound lower by zero.
      size =
          llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size);
      Value stride = llvm_mul(strides[dim], step);
      desc.setSize(numNewDims, size);
      desc.setStride(numNewDims, stride);
      ++numNewDims;
    }
    rewriter.replaceOp(op, {desc});
    return success();
  }
};
/// Lowers linalg.yield to an LLVM::ReturnOp that returns the operands handed
/// to this pattern unchanged. These are the values supplied by the conversion
/// framework, presumably already type-converted.
class YieldOpConversion : public ConvertToLLVMPattern {
public:
  explicit YieldOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
      : ConvertToLLVMPattern(linalg::YieldOp::getOperationName(), context,
                             lowering_) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
    return success();
  }
};
} // namespace
/// Populate the given list with patterns that convert from Linalg to LLVM.
///
/// The RangeOp, ReshapeOp, SliceOp and YieldOp conversion patterns are added
/// to `patterns`, each constructed from `ctx` and `converter`. `converter`
/// also gains a type conversion that lowers linalg RangeType through
/// convertRangeType. That callback keeps a reference to `converter`, so it is
/// only valid while `converter` is alive.
void mlir::populateLinalgToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
    MLIRContext *ctx) {
  patterns.insert<RangeOpConversion, ReshapeOpConversion, SliceOpConversion,
                  YieldOpConversion>(ctx, converter);

  // Linalg-specific type conversions.
  converter.addConversion([&converter](RangeType rangeType) {
    return convertRangeType(rangeType, converter);
  });
}
namespace {
/// Module pass that lowers Linalg to the LLVM dialect, together with the
/// Affine, SCF, Standard and Vector ops handled in runOnOperation. The base
/// class (from ../PassDetail.h) is parameterized on this class (CRTP).
class ConvertLinalgToLLVMPass
    : public ConvertLinalgToLLVMBase<ConvertLinalgToLLVMPass> {
public:
  void runOnOperation() override;
};
} // namespace
void ConvertLinalgToLLVMPass::runOnOperation() {
auto module = getOperation();
// Convert to the LLVM IR dialect using the converter defined above.
OwningRewritePatternList patterns;
LLVMTypeConverter converter(&getContext());
populateAffineToStdConversionPatterns(patterns, &getContext());
populateLoopToStdConversionPatterns(patterns, &getContext());
populateStdToLLVMConversionPatterns(converter, patterns);
populateVectorToSCFConversionPatterns(patterns, &getContext());
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
populateVectorToLLVMConversionPatterns(converter, patterns);
populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext());
LLVMConversionTarget target(getContext());
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
if (failed(applyFullConversion(module, target, std::move(patterns))))
signalPassFailure();
}
/// Factory for ConvertLinalgToLLVMPass, the module pass in this file that
/// lowers Linalg (and the Affine/SCF/Standard/Vector ops it handles) to the
/// LLVM dialect.
std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertLinalgToLLVMPass() {
return std::make_unique<ConvertLinalgToLLVMPass>();
}