//===- Builders.cpp - MLIR Declarative Builder Classes --------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/EDSC/Builders.h" #include "mlir/Dialect/StandardOps/EDSC/Builders.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" using namespace mlir; using namespace mlir::edsc; void mlir::edsc::affineLoopNestBuilder( ValueRange lbs, ValueRange ubs, ArrayRef steps, function_ref bodyBuilderFn) { assert(ScopedContext::getContext() && "EDSC ScopedContext not set up"); // Wrap the body builder function into an interface compatible with the main // builder. auto wrappedBuilderFn = [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange ivs) { ScopedContext context(nestedBuilder, nestedLoc); bodyBuilderFn(ivs); }; function_ref wrapper; if (bodyBuilderFn) wrapper = wrappedBuilderFn; // Extract the builder, location and construct the loop nest. OpBuilder &builder = ScopedContext::getBuilderRef(); Location loc = ScopedContext::getLocation(); buildAffineLoopNest(builder, loc, lbs, ubs, steps, wrapper); } void mlir::edsc::affineLoopBuilder(ValueRange lbs, ValueRange ubs, int64_t step, function_ref bodyBuilderFn) { // Fetch the builder and location. assert(ScopedContext::getContext() && "EDSC ScopedContext not set up"); OpBuilder &builder = ScopedContext::getBuilderRef(); Location loc = ScopedContext::getLocation(); // Create the actual loop and call the body builder, if provided, after // updating the scoped context. 
builder.create( loc, lbs, builder.getMultiDimIdentityMap(lbs.size()), ubs, builder.getMultiDimIdentityMap(ubs.size()), step, llvm::None, [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, ValueRange itrArgs) { if (bodyBuilderFn) { ScopedContext nestedContext(nestedBuilder, nestedLoc); OpBuilder::InsertionGuard guard(nestedBuilder); bodyBuilderFn(iv); } nestedBuilder.create(nestedLoc); }); } void mlir::edsc::affineLoopBuilder( ValueRange lbs, ValueRange ubs, int64_t step, ValueRange iterArgs, function_ref bodyBuilderFn) { // Fetch the builder and location. assert(ScopedContext::getContext() && "EDSC ScopedContext not set up"); OpBuilder &builder = ScopedContext::getBuilderRef(); Location loc = ScopedContext::getLocation(); // Create the actual loop and call the body builder, if provided, after // updating the scoped context. builder.create( loc, lbs, builder.getMultiDimIdentityMap(lbs.size()), ubs, builder.getMultiDimIdentityMap(ubs.size()), step, iterArgs, [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, ValueRange itrArgs) { if (bodyBuilderFn) { ScopedContext nestedContext(nestedBuilder, nestedLoc); OpBuilder::InsertionGuard guard(nestedBuilder); bodyBuilderFn(iv, itrArgs); } else if (itrArgs.empty()) nestedBuilder.create(nestedLoc); }); } static std::pair categorizeValueByAffineType(MLIRContext *context, Value val, unsigned &numDims, unsigned &numSymbols) { AffineExpr d; Value resultVal = nullptr; if (auto constant = val.getDefiningOp()) { d = getAffineConstantExpr(constant.getValue(), context); } else if (isValidSymbol(val) && !isValidDim(val)) { d = getAffineSymbolExpr(numSymbols++, context); resultVal = val; } else { d = getAffineDimExpr(numDims++, context); resultVal = val; } return std::make_pair(d, resultVal); } static Value createBinaryIndexHandle( Value lhs, Value rhs, function_ref affCombiner) { MLIRContext *context = ScopedContext::getContext(); unsigned numDims = 0, numSymbols = 0; AffineExpr d0, d1; Value v0, v1; std::tie(d0, v0) = 
categorizeValueByAffineType(context, lhs, numDims, numSymbols); std::tie(d1, v1) = categorizeValueByAffineType(context, rhs, numDims, numSymbols); SmallVector operands; if (v0) operands.push_back(v0); if (v1) operands.push_back(v1); auto map = AffineMap::get(numDims, numSymbols, affCombiner(d0, d1)); // TODO: createOrFold when available. Operation *op = makeComposedAffineApply(ScopedContext::getBuilderRef(), ScopedContext::getLocation(), map, operands) .getOperation(); assert(op->getNumResults() == 1 && "Expected single result AffineApply"); return op->getResult(0); } template static Value createBinaryHandle( Value lhs, Value rhs, function_ref affCombiner) { auto thisType = lhs.getType(); auto thatType = rhs.getType(); assert(thisType == thatType && "cannot mix types in operators"); (void)thisType; (void)thatType; if (thisType.isIndex()) { return createBinaryIndexHandle(lhs, rhs, affCombiner); } else if (thisType.isSignlessInteger()) { return ValueBuilder(lhs, rhs); } else if (thisType.isa()) { return ValueBuilder(lhs, rhs); } else if (thisType.isa()) { auto aggregateType = thisType.cast(); if (aggregateType.getElementType().isSignlessInteger()) return ValueBuilder(lhs, rhs); else if (aggregateType.getElementType().isa()) return ValueBuilder(lhs, rhs); } llvm_unreachable("failed to create a Value"); } Value mlir::edsc::op::operator+(Value lhs, Value rhs) { return createBinaryHandle( lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 + d1; }); } Value mlir::edsc::op::operator-(Value lhs, Value rhs) { return createBinaryHandle( lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 - d1; }); } Value mlir::edsc::op::operator*(Value lhs, Value rhs) { return createBinaryHandle( lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 * d1; }); } Value mlir::edsc::op::operator/(Value lhs, Value rhs) { return createBinaryHandle( lhs, rhs, [](AffineExpr d0, AffineExpr d1) -> AffineExpr { llvm_unreachable("only exprs of non-index type support operator/"); }); } Value 
mlir::edsc::op::operator%(Value lhs, Value rhs) { return createBinaryHandle( lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 % d1; }); } Value mlir::edsc::op::floorDiv(Value lhs, Value rhs) { return createBinaryIndexHandle( lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0.floorDiv(d1); }); } Value mlir::edsc::op::ceilDiv(Value lhs, Value rhs) { return createBinaryIndexHandle( lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0.ceilDiv(d1); }); } Value mlir::edsc::op::negate(Value value) { assert(value.getType().isInteger(1) && "expected boolean expression"); return ValueBuilder(1, 1) - value; } Value mlir::edsc::op::operator&&(Value lhs, Value rhs) { assert(lhs.getType().isInteger(1) && "expected boolean expression on LHS"); assert(rhs.getType().isInteger(1) && "expected boolean expression on RHS"); return ValueBuilder(lhs, rhs); } Value mlir::edsc::op::operator||(Value lhs, Value rhs) { assert(lhs.getType().isInteger(1) && "expected boolean expression on LHS"); assert(rhs.getType().isInteger(1) && "expected boolean expression on RHS"); return ValueBuilder(lhs, rhs); } static Value createIComparisonExpr(CmpIPredicate predicate, Value lhs, Value rhs) { auto lhsType = lhs.getType(); auto rhsType = rhs.getType(); (void)lhsType; (void)rhsType; assert(lhsType == rhsType && "cannot mix types in operators"); assert((lhsType.isa() || lhsType.isSignlessInteger()) && "only integer comparisons are supported"); return ScopedContext::getBuilderRef().create( ScopedContext::getLocation(), predicate, lhs, rhs); } static Value createFComparisonExpr(CmpFPredicate predicate, Value lhs, Value rhs) { auto lhsType = lhs.getType(); auto rhsType = rhs.getType(); (void)lhsType; (void)rhsType; assert(lhsType == rhsType && "cannot mix types in operators"); assert(lhsType.isa() && "only float comparisons are supported"); return ScopedContext::getBuilderRef().create( ScopedContext::getLocation(), predicate, lhs, rhs); } // All floating point comparison are ordered through EDSL 
/// Shared dispatch for the comparison helpers below: emits a CmpFOp with
/// `floatPredicate` when `lhs` has a float type, otherwise a CmpIOp with
/// `intPredicate`. Float comparisons always use ordered predicates, which
/// yield false if either operand is NaN. In particular, `ne` maps to ONE, and
/// the "unsigned" helpers (ult/ule/ugt/uge) map to the same ordered float
/// predicates as their signed counterparts.
static Value createComparisonExpr(CmpFPredicate floatPredicate,
                                  CmpIPredicate intPredicate, Value lhs,
                                  Value rhs) {
  return lhs.getType().isa<FloatType>()
             ? createFComparisonExpr(floatPredicate, lhs, rhs)
             : createIComparisonExpr(intPredicate, lhs, rhs);
}

Value mlir::edsc::op::eq(Value lhs, Value rhs) {
  return createComparisonExpr(CmpFPredicate::OEQ, CmpIPredicate::eq, lhs, rhs);
}

Value mlir::edsc::op::ne(Value lhs, Value rhs) {
  return createComparisonExpr(CmpFPredicate::ONE, CmpIPredicate::ne, lhs, rhs);
}

Value mlir::edsc::op::slt(Value lhs, Value rhs) {
  return createComparisonExpr(CmpFPredicate::OLT, CmpIPredicate::slt, lhs, rhs);
}

Value mlir::edsc::op::sle(Value lhs, Value rhs) {
  return createComparisonExpr(CmpFPredicate::OLE, CmpIPredicate::sle, lhs, rhs);
}

Value mlir::edsc::op::sgt(Value lhs, Value rhs) {
  return createComparisonExpr(CmpFPredicate::OGT, CmpIPredicate::sgt, lhs, rhs);
}

Value mlir::edsc::op::sge(Value lhs, Value rhs) {
  return createComparisonExpr(CmpFPredicate::OGE, CmpIPredicate::sge, lhs, rhs);
}

Value mlir::edsc::op::ult(Value lhs, Value rhs) {
  return createComparisonExpr(CmpFPredicate::OLT, CmpIPredicate::ult, lhs, rhs);
}

Value mlir::edsc::op::ule(Value lhs, Value rhs) {
  return createComparisonExpr(CmpFPredicate::OLE, CmpIPredicate::ule, lhs, rhs);
}

Value mlir::edsc::op::ugt(Value lhs, Value rhs) {
  return createComparisonExpr(CmpFPredicate::OGT, CmpIPredicate::ugt, lhs, rhs);
}

Value mlir::edsc::op::uge(Value lhs, Value rhs) {
  return createComparisonExpr(CmpFPredicate::OGE, CmpIPredicate::uge, lhs, rhs);
}