You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

136 lines
5.4 KiB

//===- Builders.cpp - MLIR Declarative Builder Classes --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SCF/EDSC/Builders.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
using namespace mlir;
using namespace mlir::edsc;
/// Builds a loop nest through scf::buildLoopNest at the builder/location held
/// by the current EDSC ScopedContext. When `fun` is non-null it is invoked for
/// the innermost body with the induction variables, while a ScopedContext
/// bound to that body's builder and location is active.
mlir::scf::LoopNest
mlir::edsc::loopNestBuilder(ValueRange lbs, ValueRange ubs, ValueRange steps,
                            function_ref<void(ValueRange)> fun) {
  assert(ScopedContext::getContext() && "EDSC ScopedContext not set up");

  // Adapter from the EDSC callback shape to the one scf::buildLoopNest takes.
  auto innermostBody = [&](OpBuilder &nestedBuilder, Location nestedLoc,
                           ValueRange ivs) {
    // Route EDSC-created ops into the innermost body while `fun` runs.
    ScopedContext bodyScope(nestedBuilder, nestedLoc);
    if (!fun)
      return;
    fun(ivs);
  };

  return mlir::scf::buildLoopNest(ScopedContext::getBuilderRef(),
                                  ScopedContext::getLocation(), lbs, ubs, steps,
                                  innermostBody);
}
/// Single-loop convenience form: forwards to the ValueRange-based overload,
/// adapting a callback that takes one induction variable.
mlir::scf::LoopNest
mlir::edsc::loopNestBuilder(Value lb, Value ub, Value step,
                            function_ref<void(Value)> fun) {
  auto adaptSingleIv = [&](ValueRange ivs) {
    assert(ivs.size() == 1);
    if (!fun)
      return;
    Value iv = ivs[0];
    fun(iv);
  };
  // Each ValueRange views a single parameter, which outlives the call below.
  ValueRange lbs(lb);
  ValueRange ubs(ub);
  ValueRange steps(step);
  return loopNestBuilder(lbs, ubs, steps, adaptSingleIv);
}
/// Builds a single loop carrying `iterArgInitValues` through scf::buildLoopNest
/// at the current EDSC ScopedContext. The body callback receives the single
/// induction variable and the current iteration arguments, and returns the
/// values to carry forward. When `fun` is null, the body yields
/// `iterArgInitValues` unchanged.
mlir::scf::LoopNest mlir::edsc::loopNestBuilder(
    Value lb, Value ub, Value step, ValueRange iterArgInitValues,
    function_ref<scf::ValueVector(Value, ValueRange)> fun) {
  assert(ScopedContext::getContext() && "EDSC ScopedContext not set up");

  auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange ivs, ValueRange args) -> scf::ValueVector {
    assert(ivs.size() == 1 && "expected one induction variable");
    // EDSC ops created by `fun` land in the loop body.
    ScopedContext bodyScope(nestedBuilder, nestedLoc);
    if (!fun)
      return scf::ValueVector(iterArgInitValues.begin(),
                              iterArgInitValues.end());
    Value iv = ivs[0];
    return fun(iv, args);
  };

  return mlir::scf::buildLoopNest(ScopedContext::getBuilderRef(),
                                  ScopedContext::getLocation(), lb, ub, step,
                                  iterArgInitValues, loopBody);
}
/// Builds a loop nest carrying `iterArgInitValues` through scf::buildLoopNest
/// at the current EDSC ScopedContext. The innermost-body callback receives all
/// induction variables and the current iteration arguments, and returns the
/// values to carry forward. When `fun` is null, the innermost body yields
/// `iterArgInitValues` unchanged.
mlir::scf::LoopNest mlir::edsc::loopNestBuilder(
    ValueRange lbs, ValueRange ubs, ValueRange steps,
    ValueRange iterArgInitValues,
    function_ref<scf::ValueVector(ValueRange, ValueRange)> fun) {
  assert(ScopedContext::getContext() && "EDSC ScopedContext not set up");

  auto innermostBody = [&](OpBuilder &nestedBuilder, Location nestedLoc,
                           ValueRange ivs,
                           ValueRange args) -> scf::ValueVector {
    ScopedContext bodyScope(nestedBuilder, nestedLoc);
    // No user callback: pass the initial iteration values straight through.
    if (!fun)
      return scf::ValueVector(iterArgInitValues.begin(),
                              iterArgInitValues.end());
    return fun(ivs, args);
  };

  return mlir::scf::buildLoopNest(ScopedContext::getBuilderRef(),
                                  ScopedContext::getLocation(), lbs, ubs, steps,
                                  iterArgInitValues, innermostBody);
}
/// Adapts an EDSC-style 'if' region body to the callback shape taken by the
/// scf::IfOp builder.
///
/// The returned callback does the following:
///  - installs a ScopedContext for the region;
///  - runs `body`;
///  - terminates the region with an scf.yield of the values `body` returned.
///
/// In assert-enabled builds it also checks that those values have exactly the
/// types in `expectedTypes`. Both `body` and `expectedTypes` are non-owning
/// views copied into the callback, so the callback must be invoked while their
/// referents are still alive.
static std::function<void(OpBuilder &, Location)>
wrapIfBody(function_ref<scf::ValueVector()> body, TypeRange expectedTypes) {
  // Only read by the assert below; keep NDEBUG builds warning-free.
  (void)expectedTypes;
  std::function<void(OpBuilder &, Location)> regionBuilder =
      [=](OpBuilder &builder, Location loc) {
        ScopedContext context(builder, loc);
        scf::ValueVector returned = body();
        assert(ValueRange(returned).getTypes() == expectedTypes &&
               "'if' body builder returned values of unexpected type");
        builder.create<scf::YieldOp>(loc, returned);
      };
  return regionBuilder;
}
/// Creates an scf.if with result types `results` at the current EDSC insertion
/// point and returns the op's results.
///
/// \param results   Types of the values produced by the 'if'.
/// \param condition The branch condition.
/// \param thenBody  Mandatory. Builds the "then" region and returns the values
///                  it yields.
/// \param elseBody  Builds the "else" region and returns the values it yields.
///                  May be null only when `results` is empty; in that case no
///                  "else" body is built.
/// \param ifOp      If non-null, receives the created op.
ValueRange
mlir::edsc::conditionBuilder(TypeRange results, Value condition,
                             function_ref<scf::ValueVector()> thenBody,
                             function_ref<scf::ValueVector()> elseBody,
                             scf::IfOp *ifOp) {
  assert(ScopedContext::getContext() && "EDSC ScopedContext not set up");
  assert(thenBody && "thenBody is mandatory");
  // An 'if' that produces values must produce them on both paths.
  assert((elseBody || results.empty()) &&
         "elseBody is mandatory when the 'if' has results");

  using RegionBuilderRef = llvm::function_ref<void(OpBuilder &, Location)>;
  // A null elseBody must become a null region builder. Wrapping it would
  // produce a non-empty std::function that calls through a null function_ref.
  // The std::function temporaries returned by wrapIfBody live until the end of
  // this full-expression, i.e. for the whole duration of `create`.
  auto newOp = ScopedContext::getBuilderRef().create<scf::IfOp>(
      ScopedContext::getLocation(), results, condition,
      wrapIfBody(thenBody, results),
      elseBody ? RegionBuilderRef(wrapIfBody(elseBody, results))
               : RegionBuilderRef(nullptr));
  if (ifOp)
    *ifOp = newOp;
  return newOp.getResults();
}
/// Zero-result counterpart of wrapIfBody. The returned callback installs a
/// ScopedContext for the region, runs `body`, and terminates the region with an
/// operand-less scf.yield. `body` is a non-owning function_ref copied into the
/// callback, so the callback must not be invoked after `body`'s referent dies.
static std::function<void(OpBuilder &, Location)>
wrapZeroResultIfBody(function_ref<void()> body) {
  auto regionBuilder = [=](OpBuilder &regionOpBuilder, Location regionLoc) {
    ScopedContext regionScope(regionOpBuilder, regionLoc);
    body();
    regionOpBuilder.create<scf::YieldOp>(regionLoc);
  };
  return regionBuilder;
}
/// Creates a result-less scf.if at the current EDSC insertion point.
///
/// \param condition The branch condition.
/// \param thenBody  Mandatory. Builds the "then" region.
/// \param elseBody  Optional. When null, no "else" body is built.
/// \param ifOp      If non-null, receives the created op.
/// \returns An empty range, since the op produces no values.
ValueRange mlir::edsc::conditionBuilder(Value condition,
                                        function_ref<void()> thenBody,
                                        function_ref<void()> elseBody,
                                        scf::IfOp *ifOp) {
  assert(ScopedContext::getContext() && "EDSC ScopedContext not set up");
  assert(thenBody && "thenBody is mandatory");

  using RegionBuilderRef = llvm::function_ref<void(OpBuilder &, Location)>;
  // The std::function temporaries from wrapZeroResultIfBody live until the end
  // of this full-expression, which covers the entire `create` call.
  scf::IfOp created = ScopedContext::getBuilderRef().create<scf::IfOp>(
      ScopedContext::getLocation(), condition, wrapZeroResultIfBody(thenBody),
      elseBody ? RegionBuilderRef(wrapZeroResultIfBody(elseBody))
               : RegionBuilderRef(nullptr));
  if (ifOp != nullptr)
    *ifOp = created;
  // Nothing to hand back: the op has no results.
  return {};
}