You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
136 lines
5.4 KiB
136 lines
5.4 KiB
//===- Builders.cpp - MLIR Declarative Builder Classes --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/SCF/EDSC/Builders.h"
|
|
#include "mlir/Dialect/SCF/SCF.h"
|
|
#include "mlir/IR/AffineExpr.h"
|
|
#include "mlir/IR/AffineMap.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::edsc;
|
|
|
|
/// Builds a loop nest from the given lower bounds, upper bounds and steps by
/// forwarding to scf::buildLoopNest, using the builder and location obtained
/// from the current ScopedContext.
///
/// `fun` may be null. When it is non-null it is called with the induction
/// variables handed to the body callback; when it is null the callback does
/// nothing beyond setting up its ScopedContext.
mlir::scf::LoopNest
mlir::edsc::loopNestBuilder(ValueRange lbs, ValueRange ubs, ValueRange steps,
                            function_ref<void(ValueRange)> fun) {
  assert(ScopedContext::getContext() && "EDSC ScopedContext not set up");

  // Adapter from the (builder, location, ivs) callback shape taken by
  // scf::buildLoopNest to the ivs-only shape of `fun`.
  auto bodyBuilder = [&](OpBuilder &nestedBuilder, Location nestedLoc,
                         ValueRange inductionVars) {
    // Lives only for the duration of this callback invocation.
    ScopedContext nestedContext(nestedBuilder, nestedLoc);
    if (!fun)
      return;
    fun(inductionVars);
  };

  return mlir::scf::buildLoopNest(ScopedContext::getBuilderRef(),
                                  ScopedContext::getLocation(), lbs, ubs,
                                  steps, bodyBuilder);
}
|
|
|
|
/// Single-loop convenience overload. Wraps `lb`, `ub` and `step` into
/// one-element ranges and forwards to the ValueRange-based loopNestBuilder.
///
/// `fun` may be null. When it is non-null it receives the single induction
/// variable of the loop.
mlir::scf::LoopNest
mlir::edsc::loopNestBuilder(Value lb, Value ub, Value step,
                            function_ref<void(Value)> fun) {
  // Unpacks the one-element induction variable range for the single-Value
  // callback.
  auto unpackSingleIv = [&](ValueRange ivs) {
    assert(ivs.size() == 1);
    if (!fun)
      return;
    fun(ivs[0]);
  };

  ValueRange lowerBounds(lb);
  ValueRange upperBounds(ub);
  ValueRange stepValues(step);
  return loopNestBuilder(lowerBounds, upperBounds, stepValues, unpackSingleIv);
}
|
|
|
|
/// Builds a single loop carrying iteration arguments by forwarding to
/// scf::buildLoopNest, using the builder and location obtained from the
/// current ScopedContext.
///
/// `fun` may be null. When it is non-null it is called with the induction
/// variable and the iteration arguments passed to the body callback, and its
/// result is returned from that callback. When it is null the callback returns
/// a copy of `iterArgInitValues` instead.
mlir::scf::LoopNest mlir::edsc::loopNestBuilder(
    Value lb, Value ub, Value step, ValueRange iterArgInitValues,
    function_ref<scf::ValueVector(Value, ValueRange)> fun) {
  assert(ScopedContext::getContext() && "EDSC ScopedContext not set up");

  // Adapter from the callback shape taken by scf::buildLoopNest to the
  // (single iv, iteration arguments) shape of `fun`.
  auto bodyBuilder = [&](OpBuilder &nestedBuilder, Location nestedLoc,
                         ValueRange ivs, ValueRange iterArgs) {
    assert(ivs.size() == 1 && "expected one induction variable");
    // Lives only for the duration of this callback invocation.
    ScopedContext nestedContext(nestedBuilder, nestedLoc);
    // No user callback: hand back the initial iteration argument values.
    if (!fun)
      return scf::ValueVector(iterArgInitValues.begin(),
                              iterArgInitValues.end());
    return fun(ivs[0], iterArgs);
  };

  return mlir::scf::buildLoopNest(ScopedContext::getBuilderRef(),
                                  ScopedContext::getLocation(), lb, ub, step,
                                  iterArgInitValues, bodyBuilder);
}
|
|
|
|
/// Builds a loop nest carrying iteration arguments by forwarding to
/// scf::buildLoopNest, using the builder and location obtained from the
/// current ScopedContext.
///
/// `fun` may be null. When it is non-null it is called with the induction
/// variables and the iteration arguments passed to the body callback, and its
/// result is returned from that callback. When it is null the callback returns
/// a copy of `iterArgInitValues` instead.
mlir::scf::LoopNest mlir::edsc::loopNestBuilder(
    ValueRange lbs, ValueRange ubs, ValueRange steps,
    ValueRange iterArgInitValues,
    function_ref<scf::ValueVector(ValueRange, ValueRange)> fun) {
  assert(ScopedContext::getContext() && "EDSC ScopedContext not set up");

  // Adapter from the callback shape taken by scf::buildLoopNest to the
  // (ivs, iteration arguments) shape of `fun`.
  auto bodyBuilder = [&](OpBuilder &nestedBuilder, Location nestedLoc,
                         ValueRange inductionVars, ValueRange iterArgs) {
    // Lives only for the duration of this callback invocation.
    ScopedContext nestedContext(nestedBuilder, nestedLoc);
    // No user callback: hand back the initial iteration argument values.
    if (!fun)
      return scf::ValueVector(iterArgInitValues.begin(),
                              iterArgInitValues.end());
    return fun(inductionVars, iterArgs);
  };

  return mlir::scf::buildLoopNest(ScopedContext::getBuilderRef(),
                                  ScopedContext::getLocation(), lbs, ubs,
                                  steps, iterArgInitValues, bodyBuilder);
}
|
|
|
|
static std::function<void(OpBuilder &, Location)>
|
|
wrapIfBody(function_ref<scf::ValueVector()> body, TypeRange expectedTypes) {
|
|
(void)expectedTypes;
|
|
return [=](OpBuilder &builder, Location loc) {
|
|
ScopedContext context(builder, loc);
|
|
scf::ValueVector returned = body();
|
|
assert(ValueRange(returned).getTypes() == expectedTypes &&
|
|
"'if' body builder returned values of unexpected type");
|
|
builder.create<scf::YieldOp>(loc, returned);
|
|
};
|
|
}
|
|
|
|
/// Creates an scf::IfOp with result types `results` and condition `condition`
/// using the builder and location obtained from the current ScopedContext.
///
/// \param results    Types of the values produced by the 'if'.
/// \param condition  Condition value of the 'if'.
/// \param thenBody   Mandatory callback producing the values yielded by the
///                   'then' branch.
/// \param elseBody   Optional callback producing the values yielded by the
///                   'else' branch. When null, no 'else' builder is forwarded
///                   to the op builder, matching the zero-result overload.
///                   NOTE(review): an scf.if that defines results is expected
///                   to have an 'else' block, so callers passing non-empty
///                   `results` should normally also pass `elseBody` - confirm
///                   against the op verifier.
/// \param ifOp       Optional out-parameter receiving the created operation.
/// \returns The results of the created operation.
ValueRange
mlir::edsc::conditionBuilder(TypeRange results, Value condition,
                             function_ref<scf::ValueVector()> thenBody,
                             function_ref<scf::ValueVector()> elseBody,
                             scf::IfOp *ifOp) {
  assert(ScopedContext::getContext() && "EDSC ScopedContext not set up");
  assert(thenBody && "thenBody is mandatory");

  using BodyBuilderRef = llvm::function_ref<void(OpBuilder &, Location)>;

  // Only wrap `elseBody` when it is provided: the wrapper calls the body
  // unconditionally, so wrapping a null callback would produce a non-null
  // builder that invokes a null function_ref. The wrapper temporaries live
  // until the end of this full-expression, which covers the `create` call.
  auto newOp = ScopedContext::getBuilderRef().create<scf::IfOp>(
      ScopedContext::getLocation(), results, condition,
      wrapIfBody(thenBody, results),
      elseBody ? BodyBuilderRef(wrapIfBody(elseBody, results))
               : BodyBuilderRef(nullptr));
  if (ifOp)
    *ifOp = newOp;
  return newOp.getResults();
}
|
|
|
|
static std::function<void(OpBuilder &, Location)>
|
|
wrapZeroResultIfBody(function_ref<void()> body) {
|
|
return [=](OpBuilder &builder, Location loc) {
|
|
ScopedContext context(builder, loc);
|
|
body();
|
|
builder.create<scf::YieldOp>(loc);
|
|
};
|
|
}
|
|
|
|
/// Creates a result-less scf::IfOp with condition `condition` using the
/// builder and location obtained from the current ScopedContext.
///
/// \param condition  Condition value of the 'if'.
/// \param thenBody   Mandatory callback populating the 'then' branch.
/// \param elseBody   Optional callback populating the 'else' branch; when
///                   null, a null 'else' builder is forwarded to the op
///                   builder.
/// \param ifOp       Optional out-parameter receiving the created operation.
/// \returns An empty ValueRange.
ValueRange mlir::edsc::conditionBuilder(Value condition,
                                        function_ref<void()> thenBody,
                                        function_ref<void()> elseBody,
                                        scf::IfOp *ifOp) {
  assert(ScopedContext::getContext() && "EDSC ScopedContext not set up");
  assert(thenBody && "thenBody is mandatory");

  using RegionBuilderRef = llvm::function_ref<void(OpBuilder &, Location)>;

  // Named wrappers stay alive for the whole function, which covers every use
  // of the non-owning references handed to `create` below.
  auto thenBuilder = wrapZeroResultIfBody(thenBody);
  std::function<void(OpBuilder &, Location)> elseBuilder;
  RegionBuilderRef elseBuilderRef(nullptr);
  if (elseBody) {
    // Wrap only a provided body: the wrapper calls it unconditionally.
    elseBuilder = wrapZeroResultIfBody(elseBody);
    elseBuilderRef = elseBuilder;
  }

  scf::IfOp newOp = ScopedContext::getBuilderRef().create<scf::IfOp>(
      ScopedContext::getLocation(), condition, thenBuilder, elseBuilderRef);
  if (ifOp)
    *ifOp = newOp;
  return {};
}
|