//===- LinalgToLLVM.cpp - conversion from Linalg to LLVM dialect ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
|
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"

#include "../PassDetail.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/ErrorHandling.h"
|
using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::LLVM;
using namespace mlir::linalg;

using llvm_add = ValueBuilder<LLVM::AddOp>;
using llvm_bitcast = ValueBuilder<LLVM::BitcastOp>;
using llvm_constant = ValueBuilder<LLVM::ConstantOp>;
using llvm_extractvalue = ValueBuilder<LLVM::ExtractValueOp>;
using llvm_gep = ValueBuilder<LLVM::GEPOp>;
using llvm_insertvalue = ValueBuilder<LLVM::InsertValueOp>;
using llvm_call = OperationBuilder<LLVM::CallOp>;
using llvm_icmp = ValueBuilder<LLVM::ICmpOp>;
using llvm_load = ValueBuilder<LLVM::LoadOp>;
using llvm_store = OperationBuilder<LLVM::StoreOp>;
using llvm_select = ValueBuilder<LLVM::SelectOp>;
using llvm_mul = ValueBuilder<LLVM::MulOp>;
using llvm_ptrtoint = ValueBuilder<LLVM::PtrToIntOp>;
using llvm_sub = ValueBuilder<LLVM::SubOp>;
using llvm_undef = ValueBuilder<LLVM::UndefOp>;
using llvm_urem = ValueBuilder<LLVM::URemOp>;
using llvm_alloca = ValueBuilder<LLVM::AllocaOp>;
using llvm_return = OperationBuilder<LLVM::ReturnOp>;
|
template <typename T>
|
|
static LLVMType getPtrToElementType(T containerType,
|
|
LLVMTypeConverter &lowering) {
|
|
return lowering.convertType(containerType.getElementType())
|
|
.template cast<LLVMType>()
|
|
.getPointerTo();
|
|
}
|
|
|
|
/// Convert the given range descriptor type to the LLVMIR dialect.
|
|
/// Range descriptor contains the range bounds and the step as 64-bit integers.
|
|
///
|
|
/// struct {
|
|
/// int64_t min;
|
|
/// int64_t max;
|
|
/// int64_t step;
|
|
/// };
|
|
static Type convertRangeType(RangeType t, LLVMTypeConverter &converter) {
|
|
auto *context = t.getContext();
|
|
auto int64Ty = converter.convertType(IntegerType::get(64, context))
|
|
.cast<LLVM::LLVMType>();
|
|
return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);
|
|
}
|
|
|
|
namespace {
|
|
/// EDSC-compatible wrapper for MemRefDescriptor.
|
|
class BaseViewConversionHelper {
|
|
public:
|
|
BaseViewConversionHelper(Type type)
|
|
: d(MemRefDescriptor::undef(rewriter(), loc(), type)) {}
|
|
|
|
BaseViewConversionHelper(Value v) : d(v) {}
|
|
|
|
/// Wrappers around MemRefDescriptor that use EDSC builder and location.
|
|
Value allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); }
|
|
void setAllocatedPtr(Value v) { d.setAllocatedPtr(rewriter(), loc(), v); }
|
|
Value alignedPtr() { return d.alignedPtr(rewriter(), loc()); }
|
|
void setAlignedPtr(Value v) { d.setAlignedPtr(rewriter(), loc(), v); }
|
|
Value offset() { return d.offset(rewriter(), loc()); }
|
|
void setOffset(Value v) { d.setOffset(rewriter(), loc(), v); }
|
|
Value size(unsigned i) { return d.size(rewriter(), loc(), i); }
|
|
void setSize(unsigned i, Value v) { d.setSize(rewriter(), loc(), i, v); }
|
|
void setConstantSize(unsigned i, int64_t v) {
|
|
d.setConstantSize(rewriter(), loc(), i, v);
|
|
}
|
|
Value stride(unsigned i) { return d.stride(rewriter(), loc(), i); }
|
|
void setStride(unsigned i, Value v) { d.setStride(rewriter(), loc(), i, v); }
|
|
void setConstantStride(unsigned i, int64_t v) {
|
|
d.setConstantStride(rewriter(), loc(), i, v);
|
|
}
|
|
|
|
operator Value() { return d; }
|
|
|
|
private:
|
|
OpBuilder &rewriter() { return ScopedContext::getBuilderRef(); }
|
|
Location loc() { return ScopedContext::getLocation(); }
|
|
|
|
MemRefDescriptor d;
|
|
};
|
|
|
|
// RangeOp creates a new range descriptor.
|
|
class RangeOpConversion : public ConvertToLLVMPattern {
|
|
public:
|
|
explicit RangeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
|
|
: ConvertToLLVMPattern(RangeOp::getOperationName(), context, lowering_) {}
|
|
|
|
LogicalResult
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto rangeOp = cast<RangeOp>(op);
|
|
auto rangeDescriptorTy = convertRangeType(
|
|
rangeOp.getType().cast<RangeType>(), *getTypeConverter());
|
|
|
|
edsc::ScopedContext context(rewriter, op->getLoc());
|
|
|
|
// Fill in an aggregate value of the descriptor.
|
|
RangeOpAdaptor adaptor(operands);
|
|
Value desc = llvm_undef(rangeDescriptorTy);
|
|
desc = llvm_insertvalue(desc, adaptor.min(), rewriter.getI64ArrayAttr(0));
|
|
desc = llvm_insertvalue(desc, adaptor.max(), rewriter.getI64ArrayAttr(1));
|
|
desc = llvm_insertvalue(desc, adaptor.step(), rewriter.getI64ArrayAttr(2));
|
|
rewriter.replaceOp(op, desc);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// ReshapeOp creates a new view descriptor of the proper rank.
|
|
// For now, the only conversion supported is for target MemRef with static sizes
|
|
// and strides.
|
|
class ReshapeOpConversion : public ConvertToLLVMPattern {
|
|
public:
|
|
explicit ReshapeOpConversion(MLIRContext *context,
|
|
LLVMTypeConverter &lowering_)
|
|
: ConvertToLLVMPattern(ReshapeOp::getOperationName(), context,
|
|
lowering_) {}
|
|
|
|
LogicalResult
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto reshapeOp = cast<ReshapeOp>(op);
|
|
MemRefType dstType = reshapeOp.getResultType();
|
|
|
|
if (!dstType.hasStaticShape())
|
|
return failure();
|
|
|
|
int64_t offset;
|
|
SmallVector<int64_t, 4> strides;
|
|
auto res = getStridesAndOffset(dstType, strides, offset);
|
|
if (failed(res) || llvm::any_of(strides, [](int64_t val) {
|
|
return ShapedType::isDynamicStrideOrOffset(val);
|
|
}))
|
|
return failure();
|
|
|
|
edsc::ScopedContext context(rewriter, op->getLoc());
|
|
ReshapeOpAdaptor adaptor(operands);
|
|
BaseViewConversionHelper baseDesc(adaptor.src());
|
|
BaseViewConversionHelper desc(typeConverter->convertType(dstType));
|
|
desc.setAllocatedPtr(baseDesc.allocatedPtr());
|
|
desc.setAlignedPtr(baseDesc.alignedPtr());
|
|
desc.setOffset(baseDesc.offset());
|
|
for (auto en : llvm::enumerate(dstType.getShape()))
|
|
desc.setConstantSize(en.index(), en.value());
|
|
for (auto en : llvm::enumerate(strides))
|
|
desc.setConstantStride(en.index(), en.value());
|
|
rewriter.replaceOp(op, {desc});
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Conversion pattern that transforms a linalg.slice op into:
///   1. An "undef" value for the ViewDescriptor.
///   2. Updates to the ViewDescriptor to introduce the data ptr, offset, size
///      and stride corresponding to the region of memory within the bounds of
///      the parent view.
/// The linalg.slice op is replaced by the alloca'ed pointer.
class SliceOpConversion : public ConvertToLLVMPattern {
public:
  explicit SliceOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
      : ConvertToLLVMPattern(SliceOp::getOperationName(), context, lowering_) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // All EDSC helpers below (llvm_extractvalue, llvm_add, ...) implicitly
    // use the builder and location installed by this scoped context.
    edsc::ScopedContext context(rewriter, op->getLoc());
    SliceOpAdaptor adaptor(operands);
    // Descriptor of the (already converted) base view being sliced.
    BaseViewConversionHelper baseDesc(adaptor.view());

    auto sliceOp = cast<SliceOp>(op);
    auto memRefType = sliceOp.getBaseViewType();
    auto int64Ty = typeConverter->convertType(rewriter.getIntegerType(64))
                       .cast<LLVM::LLVMType>();

    // Descriptor for the result view, filled in field by field below.
    BaseViewConversionHelper desc(
        typeConverter->convertType(sliceOp.getShapedType()));

    // TODO: extract sizes and emit asserts.
    // Cache the base view's strides; they are reused for both the offset and
    // the per-dimension stride computations.
    SmallVector<Value, 4> strides(memRefType.getRank());
    for (int i = 0, e = memRefType.getRank(); i < e; ++i)
      strides[i] = baseDesc.stride(i);

    // Helper building the position attribute for extractvalue.
    auto pos = [&rewriter](ArrayRef<int64_t> values) {
      return rewriter.getI64ArrayAttr(values);
    };

    // Compute base offset: accumulate min(indexing) * stride over every
    // dimension of the base view. A range indexing contributes its `min`
    // field (position 0 of the range descriptor); a scalar indexing
    // contributes its own value.
    Value baseOffset = baseDesc.offset();
    for (int i = 0, e = memRefType.getRank(); i < e; ++i) {
      Value indexing = adaptor.indexings()[i];
      Value min = indexing;
      if (sliceOp.indexing(i).getType().isa<RangeType>())
        min = llvm_extractvalue(int64Ty, indexing, pos(0));
      baseOffset = llvm_add(baseOffset, llvm_mul(min, strides[i]));
    }

    // Insert the base and aligned pointers.
    desc.setAllocatedPtr(baseDesc.allocatedPtr());
    desc.setAlignedPtr(baseDesc.alignedPtr());

    // Insert base offset.
    desc.setOffset(baseOffset);

    // Corner case, no sizes or strides: early return the descriptor.
    if (sliceOp.getShapedType().getRank() == 0)
      return rewriter.replaceOp(op, {desc}), success();

    Value zero = llvm_constant(
        int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
    // Compute and insert view sizes (max - min along the range) and strides.
    // Skip the non-range operands as they will be projected away from the view.
    int numNewDims = 0;
    for (auto en : llvm::enumerate(sliceOp.indexings())) {
      Value indexing = en.value();
      if (indexing.getType().isa<RangeType>()) {
        int rank = en.index();
        // Unpack {min, max, step} from the range descriptor.
        Value rangeDescriptor = adaptor.indexings()[rank];
        Value min = llvm_extractvalue(int64Ty, rangeDescriptor, pos(0));
        Value max = llvm_extractvalue(int64Ty, rangeDescriptor, pos(1));
        Value step = llvm_extractvalue(int64Ty, rangeDescriptor, pos(2));
        Value baseSize = baseDesc.size(rank);

        // Bound upper by base view upper bound.
        max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max,
                          baseSize);
        Value size = llvm_sub(max, min);
        // Bound lower by zero.
        size =
            llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size);
        // The result stride is the base stride scaled by the range step.
        Value stride = llvm_mul(strides[rank], step);
        desc.setSize(numNewDims, size);
        desc.setStride(numNewDims, stride);
        ++numNewDims;
      }
    }

    rewriter.replaceOp(op, {desc});
    return success();
  }
};
|
|
|
|
// YieldOp produces and LLVM::ReturnOp.
|
|
class YieldOpConversion : public ConvertToLLVMPattern {
|
|
public:
|
|
explicit YieldOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
|
|
: ConvertToLLVMPattern(linalg::YieldOp::getOperationName(), context,
|
|
lowering_) {}
|
|
|
|
LogicalResult
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
|
|
return success();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
/// Populate the given list with patterns that convert from Linalg to LLVM.
|
|
void mlir::populateLinalgToLLVMConversionPatterns(
|
|
LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
|
|
MLIRContext *ctx) {
|
|
patterns.insert<RangeOpConversion, ReshapeOpConversion, SliceOpConversion,
|
|
YieldOpConversion>(ctx, converter);
|
|
|
|
// Populate the type conversions for the linalg types.
|
|
converter.addConversion(
|
|
[&](RangeType type) { return convertRangeType(type, converter); });
|
|
}
|
|
|
|
namespace {
/// Module pass lowering Linalg (plus the Affine/SCF/Vector/Std ops it relies
/// on) to the LLVM dialect; created by createConvertLinalgToLLVMPass.
struct ConvertLinalgToLLVMPass
    : public ConvertLinalgToLLVMBase<ConvertLinalgToLLVMPass> {
  void runOnOperation() override;
};
} // namespace
|
|
|
|
void ConvertLinalgToLLVMPass::runOnOperation() {
|
|
auto module = getOperation();
|
|
|
|
// Convert to the LLVM IR dialect using the converter defined above.
|
|
OwningRewritePatternList patterns;
|
|
LLVMTypeConverter converter(&getContext());
|
|
populateAffineToStdConversionPatterns(patterns, &getContext());
|
|
populateLoopToStdConversionPatterns(patterns, &getContext());
|
|
populateStdToLLVMConversionPatterns(converter, patterns);
|
|
populateVectorToSCFConversionPatterns(patterns, &getContext());
|
|
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
|
|
populateVectorToLLVMConversionPatterns(converter, patterns);
|
|
populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext());
|
|
|
|
LLVMConversionTarget target(getContext());
|
|
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
|
if (failed(applyFullConversion(module, target, std::move(patterns))))
|
|
signalPassFailure();
|
|
}
|
|
|
|
/// Factory for the Linalg-to-LLVM lowering pass defined above.
std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertLinalgToLLVMPass() {
  return std::make_unique<ConvertLinalgToLLVMPass>();
}
|