You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
137 lines
5.3 KiB
137 lines
5.3 KiB
//===- ConvertShapeConstraints.cpp - Conversion of shape constraints ------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
|
|
|
|
#include "../PassDetail.h"
|
|
#include "mlir/Dialect/SCF/SCF.h"
|
|
#include "mlir/Dialect/Shape/IR/Shape.h"
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Pass/PassRegistry.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
class ConvertCstrBroadcastableOp
|
|
: public OpRewritePattern<shape::CstrBroadcastableOp> {
|
|
public:
|
|
using OpRewritePattern::OpRewritePattern;
|
|
LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
if (op.getType().isa<shape::ShapeType>() ||
|
|
op.lhs().getType().isa<shape::ShapeType>() ||
|
|
op.rhs().getType().isa<shape::ShapeType>()) {
|
|
return rewriter.notifyMatchFailure(
|
|
op, "cannot convert error-propagating shapes");
|
|
}
|
|
|
|
auto loc = op.getLoc();
|
|
Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
|
|
Value one = rewriter.create<ConstantIndexOp>(loc, 1);
|
|
|
|
// Find smaller and greater rank and extent tensor.
|
|
Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
|
|
Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
|
|
Value lhsRankULE =
|
|
rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
|
|
Type indexTy = rewriter.getIndexType();
|
|
Value lesserRank =
|
|
rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
|
|
Value greaterRank =
|
|
rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
|
|
Value lesserRankOperand =
|
|
rewriter.create<SelectOp>(loc, lhsRankULE, op.lhs(), op.rhs());
|
|
Value greaterRankOperand =
|
|
rewriter.create<SelectOp>(loc, lhsRankULE, op.rhs(), op.lhs());
|
|
|
|
Value rankDiff =
|
|
rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
|
|
|
|
// Generate code to compare the shapes extent by extent, and emit errors for
|
|
// non-broadcast-compatible shapes.
|
|
// Two extents are broadcast-compatible if
|
|
// 1. they are both equal, or
|
|
// 2. at least one of them is 1.
|
|
|
|
rewriter.create<scf::ForOp>(
|
|
loc, rankDiff, greaterRank, one, llvm::None,
|
|
[&](OpBuilder &b, Location loc, Value iv, ValueRange) {
|
|
Value greaterRankOperandExtent = b.create<ExtractElementOp>(
|
|
loc, greaterRankOperand, ValueRange{iv});
|
|
Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
|
|
Value lesserRankOperandExtent = b.create<ExtractElementOp>(
|
|
loc, lesserRankOperand, ValueRange{ivShifted});
|
|
|
|
Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
|
|
loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
|
|
Value lesserRankOperandExtentIsOne = b.create<CmpIOp>(
|
|
loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
|
|
Value extentsAgree =
|
|
b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent,
|
|
lesserRankOperandExtent);
|
|
auto broadcastIsValid =
|
|
b.create<OrOp>(loc, b.getI1Type(), extentsAgree,
|
|
b.create<OrOp>(loc, greaterRankOperandExtentIsOne,
|
|
lesserRankOperandExtentIsOne));
|
|
b.create<AssertOp>(loc, broadcastIsValid, "invalid broadcast");
|
|
b.create<scf::YieldOp>(loc);
|
|
});
|
|
|
|
rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
|
|
return success();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
namespace {
|
|
/// Rewrites `shape.cstr_require` into an assertion on its predicate, carrying
/// over the op's message attribute, and replaces the op with
/// `shape.const_witness true`.
struct ConvertCstrRequireOp : public OpRewritePattern<shape::CstrRequireOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::CstrRequireOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto predicate = op.pred();
    auto message = op.msgAttr();

    // Emit the eager check first; the witness is then trivially satisfied.
    rewriter.create<AssertOp>(loc, predicate, message);
    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
    return success();
  }
};
|
|
} // namespace
|
|
|
|
/// Adds the patterns that convert shape constraint ops
/// (`ConvertCstrBroadcastableOp`, `ConvertCstrRequireOp`) to `patterns`,
/// constructing each of them with `ctx`.
void mlir::populateConvertShapeConstraintsConversionPatterns(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  patterns.insert<ConvertCstrBroadcastableOp, ConvertCstrRequireOp>(ctx);
}
|
|
|
|
namespace {
|
|
// This pass eliminates shape constraints from the program, converting them to
|
|
// eager (side-effecting) error handling code. After eager error handling code
|
|
// is emitted, witnesses are satisfied, so they are replace with
|
|
// `shape.const_witness true`.
|
|
class ConvertShapeConstraints
|
|
: public ConvertShapeConstraintsBase<ConvertShapeConstraints> {
|
|
void runOnOperation() override {
|
|
auto func = getOperation();
|
|
auto *context = &getContext();
|
|
|
|
OwningRewritePatternList patterns;
|
|
populateConvertShapeConstraintsConversionPatterns(patterns, context);
|
|
|
|
if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
|
|
return signalPassFailure();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
/// Creates a new instance of the `ConvertShapeConstraints` pass.
std::unique_ptr<OperationPass<FuncOp>>
mlir::createConvertShapeConstraintsPass() {
  return std::make_unique<ConvertShapeConstraints>();
}
|