You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
733 lines
26 KiB
733 lines
26 KiB
//===- SDBMExpr.cpp - MLIR SDBM Expression implementation -----------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// A striped difference-bound matrix (SDBM) expression is a constant expression,
// an identifier, a binary expression with a constant RHS and either the `+` or
// the stripe operator, or a difference expression between two identifiers.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/SDBM/SDBMExpr.h"
|
|
#include "SDBMExprDetail.h"
|
|
#include "mlir/Dialect/SDBM/SDBMDialect.h"
|
|
#include "mlir/IR/AffineExpr.h"
|
|
#include "mlir/IR/AffineExprVisitor.h"
|
|
|
|
#include "llvm/Support/raw_ostream.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
/// A simple compositional matcher for AffineExpr
|
|
///
|
|
/// Example usage:
|
|
///
|
|
/// ```c++
|
|
/// AffineExprMatcher x, C, m;
|
|
/// AffineExprMatcher pattern1 = ((x % C) * m) + x;
|
|
/// AffineExprMatcher pattern2 = x + ((x % C) * m);
|
|
/// if (pattern1.match(expr) || pattern2.match(expr)) {
|
|
/// ...
|
|
/// }
|
|
/// ```
|
|
class AffineExprMatcherStorage;
|
|
class AffineExprMatcher {
|
|
public:
|
|
AffineExprMatcher();
|
|
AffineExprMatcher(const AffineExprMatcher &other);
|
|
|
|
AffineExprMatcher operator+(AffineExprMatcher other) {
|
|
return AffineExprMatcher(AffineExprKind::Add, *this, other);
|
|
}
|
|
AffineExprMatcher operator*(AffineExprMatcher other) {
|
|
return AffineExprMatcher(AffineExprKind::Mul, *this, other);
|
|
}
|
|
AffineExprMatcher floorDiv(AffineExprMatcher other) {
|
|
return AffineExprMatcher(AffineExprKind::FloorDiv, *this, other);
|
|
}
|
|
AffineExprMatcher ceilDiv(AffineExprMatcher other) {
|
|
return AffineExprMatcher(AffineExprKind::CeilDiv, *this, other);
|
|
}
|
|
AffineExprMatcher operator%(AffineExprMatcher other) {
|
|
return AffineExprMatcher(AffineExprKind::Mod, *this, other);
|
|
}
|
|
|
|
AffineExpr match(AffineExpr expr);
|
|
AffineExpr matched();
|
|
Optional<int> getMatchedConstantValue();
|
|
|
|
private:
|
|
AffineExprMatcher(AffineExprKind k, AffineExprMatcher a, AffineExprMatcher b);
|
|
AffineExprKind kind; // only used to match in binary op cases.
|
|
// A shared_ptr allows multiple references to same matcher storage without
|
|
// worrying about ownership or dealing with an arena. To be cleaned up if we
|
|
// go with this.
|
|
std::shared_ptr<AffineExprMatcherStorage> storage;
|
|
};
|
|
|
|
class AffineExprMatcherStorage {
|
|
public:
|
|
AffineExprMatcherStorage() {}
|
|
AffineExprMatcherStorage(const AffineExprMatcherStorage &other)
|
|
: subExprs(other.subExprs.begin(), other.subExprs.end()),
|
|
matched(other.matched) {}
|
|
AffineExprMatcherStorage(ArrayRef<AffineExprMatcher> exprs)
|
|
: subExprs(exprs.begin(), exprs.end()) {}
|
|
AffineExprMatcherStorage(AffineExprMatcher &a, AffineExprMatcher &b)
|
|
: subExprs({a, b}) {}
|
|
SmallVector<AffineExprMatcher, 0> subExprs;
|
|
AffineExpr matched;
|
|
};
|
|
} // namespace
|
|
|
|
// A default-constructed matcher is a capture variable (non-binary kind) with
// its own fresh, unbound storage.
AffineExprMatcher::AffineExprMatcher()
    : kind(AffineExprKind::Constant),
      storage(std::make_shared<AffineExprMatcherStorage>()) {}
|
|
|
|
// Copying shares the storage pointer, so every copy observes the same binding.
AffineExprMatcher::AffineExprMatcher(const AffineExprMatcher &other) = default;
|
|
|
|
/// Returns the value of the expression bound to this matcher if it is an
/// affine constant representable as `int`. Returns None when nothing is bound
/// yet, when the bound expression is not a constant, or when the constant does
/// not fit in `int` (rather than silently truncating it).
Optional<int> AffineExprMatcher::getMatchedConstantValue() {
  // An unbound matcher holds a null expression; don't try to inspect it.
  if (!storage->matched)
    return None;

  auto cst = storage->matched.dyn_cast<AffineConstantExpr>();
  if (!cst)
    return None;

  int64_t value = cst.getValue();
  // The return type is narrower than the stored constant: refuse values that
  // would change when narrowed.
  if (static_cast<int64_t>(static_cast<int>(value)) != value)
    return None;
  return static_cast<int>(value);
}
|
|
|
|
/// Tries to match `expr` against the pattern rooted at this matcher.
/// Returns the matched expression on success and a null AffineExpr on failure.
///
/// A matcher with a non-binary kind is a capture variable: it binds to the
/// first expression it is matched against and afterwards only matches that
/// same expression. A binary matcher requires `expr` to have the same binary
/// kind, and its two sub-matchers to match the LHS and RHS respectively.
///
/// On failure, every binding created during this call is undone. Otherwise a
/// partial match would leave capture variables bound, and a later attempt with
/// another pattern sharing those variables (e.g.
/// `pattern1.match(e) || pattern2.match(e)`) could fail spuriously.
AffineExpr AffineExprMatcher::match(AffineExpr expr) {
  // Capture variable: bind, or compare against the existing binding.
  if (kind > AffineExprKind::LAST_AFFINE_BINARY_OP) {
    if (storage->matched && storage->matched != expr)
      return AffineExpr();
    storage->matched = expr;
    return storage->matched;
  }

  if (kind != expr.getKind())
    return AffineExpr();

  auto bin = expr.dyn_cast<AffineBinaryOpExpr>();
  if (!bin)
    llvm_unreachable("binary expected");

  // Record every matcher of this pattern that is unbound right now, so a
  // partial match can be rolled back. A shared sub-matcher may be recorded
  // more than once, which is harmless.
  SmallVector<AffineExprMatcherStorage *, 8> unbound;
  SmallVector<AffineExprMatcherStorage *, 8> worklist;
  worklist.push_back(storage.get());
  while (!worklist.empty()) {
    AffineExprMatcherStorage *current = worklist.pop_back_val();
    if (!current->matched)
      unbound.push_back(current);
    for (AffineExprMatcher &sub : current->subExprs)
      worklist.push_back(sub.storage.get());
  }

  // Undo the bindings made during this call and report failure.
  auto fail = [&unbound]() {
    for (AffineExprMatcherStorage *s : unbound)
      s->matched = AffineExpr();
    return AffineExpr();
  };

  auto &subExprs = storage->subExprs;
  if (!subExprs.empty() && !subExprs[0].match(bin.getLHS()))
    return fail();
  if (!subExprs.empty() && !subExprs[1].match(bin.getRHS()))
    return fail();

  if (storage->matched && storage->matched != expr)
    return fail();
  storage->matched = expr;
  return storage->matched;
}
|
|
|
|
// Returns the expression currently bound to this matcher, or a null
// AffineExpr if it has not been bound.
AffineExpr AffineExprMatcher::matched() {
  const AffineExpr &bound = storage->matched;
  return bound;
}
|
|
|
|
// Builds a binary matcher of kind `k` with `a` and `b` as its LHS and RHS
// sub-matchers. The two-argument storage constructor already records {a, b}
// as the sub-matchers; appending them again here used to produce the
// redundant list {a, b, a, b}.
AffineExprMatcher::AffineExprMatcher(AffineExprKind k, AffineExprMatcher a,
                                     AffineExprMatcher b)
    : kind(k), storage(std::make_shared<AffineExprMatcherStorage>(a, b)) {}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SDBMExpr
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// The expression kind is read from the storage object.
SDBMExprKind SDBMExpr::getKind() const {
  return impl->getKind();
}
|
|
|
|
// The MLIR context is the one owning this expression's dialect.
MLIRContext *SDBMExpr::getContext() const {
  SDBMDialect *owner = getDialect();
  return owner->getContext();
}
|
|
|
|
// Returns the dialect recorded in the expression storage.
SDBMDialect *SDBMExpr::getDialect() const {
  SDBMDialect *owner = impl->dialect;
  return owner;
}
|
|
|
|
void SDBMExpr::print(raw_ostream &os) const {
  // Visitor emitting the textual form of an SDBM expression, for example
  // "d0 + 3", "s0 - d1", "(d0 + 1) # 4" or "-(d0 + 2)".
  struct Printer : public SDBMVisitor<Printer> {
    Printer(raw_ostream &stream) : out(stream) {}

    // Prints `lhs`, then the operator text `op`, then `rhs`.
    void printInfix(SDBMExpr lhs, const char *op, SDBMExpr rhs) {
      visit(lhs);
      out << op;
      visit(rhs);
    }

    // Prints `e`, surrounded by parentheses when `parenthesize` is set.
    void printGrouped(SDBMExpr e, bool parenthesize) {
      if (parenthesize)
        out << '(';
      visit(e);
      if (parenthesize)
        out << ')';
    }

    void visitSum(SDBMSumExpr expr) {
      printInfix(expr.getLHS(), " + ", expr.getRHS());
    }
    void visitDiff(SDBMDiffExpr expr) {
      printInfix(expr.getLHS(), " - ", expr.getRHS());
    }
    void visitDim(SDBMDimExpr expr) { out << 'd' << expr.getPosition(); }
    void visitSymbol(SDBMSymbolExpr expr) { out << 's' << expr.getPosition(); }
    void visitStripe(SDBMStripeExpr expr) {
      // Only a term operand is printed without parentheses.
      SDBMDirectExpr lhs = expr.getLHS();
      printGrouped(lhs, /*parenthesize=*/!lhs.isa<SDBMTermExpr>());
      out << " # ";
      visitConstant(expr.getStripeFactor());
    }
    void visitNeg(SDBMNegExpr expr) {
      // A negated sum is parenthesized: "-(x + C)".
      SDBMDirectExpr var = expr.getVar();
      out << '-';
      printGrouped(var, /*parenthesize=*/var.isa<SDBMSumExpr>());
    }
    void visitConstant(SDBMConstantExpr expr) { out << expr.getValue(); }

    raw_ostream &out;
  };

  Printer printer(os);
  printer.visit(*this);
}
|
|
|
|
void SDBMExpr::dump() const {
|
|
print(llvm::errs());
|
|
llvm::errs() << '\n';
|
|
}
|
|
|
|
namespace {
|
|
// Helper class to perform negation of an SDBM expression.
|
|
struct SDBMNegator : public SDBMVisitor<SDBMNegator, SDBMExpr> {
|
|
// Any term expression is wrapped into a negation expression.
|
|
// -(x) = -x
|
|
SDBMExpr visitDirect(SDBMDirectExpr expr) { return SDBMNegExpr::get(expr); }
|
|
// A negation expression is unwrapped.
|
|
// -(-x) = x
|
|
SDBMExpr visitNeg(SDBMNegExpr expr) { return expr.getVar(); }
|
|
// The value of the constant is negated.
|
|
SDBMExpr visitConstant(SDBMConstantExpr expr) {
|
|
return SDBMConstantExpr::get(expr.getDialect(), -expr.getValue());
|
|
}
|
|
|
|
// Terms of a difference are interchanged. Since only the LHS of a diff
|
|
// expression is allowed to be a sum with a constant, we need to recreate the
|
|
// sum with the negated value:
|
|
// -((x + C) - y) = (y - C) - x.
|
|
SDBMExpr visitDiff(SDBMDiffExpr expr) {
|
|
// If the LHS is just a term, we can do straightforward interchange.
|
|
if (auto term = expr.getLHS().dyn_cast<SDBMTermExpr>())
|
|
return SDBMDiffExpr::get(expr.getRHS(), term);
|
|
|
|
auto sum = expr.getLHS().cast<SDBMSumExpr>();
|
|
auto cst = visitConstant(sum.getRHS()).cast<SDBMConstantExpr>();
|
|
return SDBMDiffExpr::get(SDBMSumExpr::get(expr.getRHS(), cst),
|
|
sum.getLHS());
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
// Negation is delegated to the SDBMNegator visitor.
SDBMExpr SDBMExpr::operator-() {
  SDBMNegator negator;
  return negator.visit(*this);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SDBMSumExpr
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Returns the uniqued sum expression `lhs + rhs`.
SDBMSumExpr SDBMSumExpr::get(SDBMTermExpr lhs, SDBMConstantExpr rhs) {
  assert(lhs && "expected SDBM variable expression");
  assert(rhs && "expected SDBM constant");

  SDBMTermExpr term = lhs;
  SDBMConstantExpr offset = rhs;

  // (x + A) + B = x + (B + A): if the LHS is itself a sum, merge the two
  // constant parts so that the result carries a single constant.
  if (auto nested = term.dyn_cast<SDBMSumExpr>()) {
    term = nested.getLHS();
    offset = SDBMConstantExpr::get(
        offset.getDialect(), offset.getValue() + nested.getRHS().getValue());
  }

  StorageUniquer &uniquer = term.getDialect()->getUniquer();
  return uniquer.get<detail::SDBMBinaryExprStorage>(
      /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Add), term, offset);
}
|
|
|
|
// The stored LHS is narrowed to the term type a sum is built from.
SDBMTermExpr SDBMSumExpr::getLHS() const {
  auto *storage = static_cast<ImplType *>(impl);
  return storage->lhs.cast<SDBMTermExpr>();
}
|
|
|
|
// The constant part of the sum.
SDBMConstantExpr SDBMSumExpr::getRHS() const {
  auto *storage = static_cast<ImplType *>(impl);
  return storage->rhs;
}
|
|
|
|
AffineExpr SDBMExpr::getAsAffineExpr() const {
  // Visitor rebuilding the equivalent AffineExpr bottom-up.
  struct Converter : public SDBMVisitor<Converter, AffineExpr> {
    // Leaves map directly onto affine constants, dimensions and symbols.
    AffineExpr visitConstant(SDBMConstantExpr expr) {
      return getAffineConstantExpr(expr.getValue(), expr.getContext());
    }
    AffineExpr visitDim(SDBMDimExpr expr) {
      return getAffineDimExpr(expr.getPosition(), expr.getContext());
    }
    AffineExpr visitSymbol(SDBMSymbolExpr expr) {
      return getAffineSymbolExpr(expr.getPosition(), expr.getContext());
    }

    // x + C.
    AffineExpr visitSum(SDBMSumExpr expr) {
      AffineExpr term = visit(expr.getLHS());
      AffineExpr constant = visit(expr.getRHS());
      return term + constant;
    }

    // x - y.
    AffineExpr visitDiff(SDBMDiffExpr expr) {
      AffineExpr minuend = visit(expr.getLHS());
      AffineExpr subtrahend = visit(expr.getRHS());
      return minuend - subtrahend;
    }

    // x # C is expressed as x - (x mod C).
    AffineExpr visitStripe(SDBMStripeExpr expr) {
      AffineExpr value = visit(expr.getLHS());
      AffineExpr factor = visit(expr.getStripeFactor());
      return value - (value % factor);
    }

    // -x is expressed as (-1) * x.
    AffineExpr visitNeg(SDBMNegExpr expr) {
      AffineExpr minusOne = getAffineConstantExpr(-1, expr.getContext());
      AffineExpr operand = visit(expr.getVar());
      return getAffineBinaryOpExpr(AffineExprKind::Mul, minusOne, operand);
    }
  };

  Converter converter;
  return converter.visit(*this);
}
|
|
|
|
// Given a direct expression `expr`, add the given constant to it and pass the
|
|
// resulting expression to `builder` before returning its result. If the
|
|
// expression is already a sum expression, update its constant and extract the
|
|
// LHS if the constant becomes zero. Otherwise, construct a sum expression.
|
|
template <typename Result>
|
|
static Result addConstantAndSink(SDBMDirectExpr expr, int64_t constant,
|
|
bool negated,
|
|
function_ref<Result(SDBMDirectExpr)> builder) {
|
|
SDBMDialect *dialect = expr.getDialect();
|
|
if (auto sumExpr = expr.dyn_cast<SDBMSumExpr>()) {
|
|
if (negated)
|
|
constant = sumExpr.getRHS().getValue() - constant;
|
|
else
|
|
constant += sumExpr.getRHS().getValue();
|
|
|
|
if (constant != 0) {
|
|
auto sum = SDBMSumExpr::get(sumExpr.getLHS(),
|
|
SDBMConstantExpr::get(dialect, constant));
|
|
return builder(sum);
|
|
} else {
|
|
return builder(sumExpr.getLHS());
|
|
}
|
|
}
|
|
if (constant != 0)
|
|
return builder(SDBMSumExpr::get(
|
|
expr.cast<SDBMTermExpr>(),
|
|
SDBMConstantExpr::get(dialect, negated ? -constant : constant)));
|
|
return expr;
|
|
}
|
|
|
|
// Construct an expression lhs + constant while maintaining the canonical form
|
|
// of the SDBM expressions, in particular sink the constant expression to the
|
|
// nearest sum expression in the left subtree of the expression tree.
|
|
static SDBMExpr addConstant(SDBMVaryingExpr lhs, int64_t constant) {
|
|
if (auto lhsDiff = lhs.dyn_cast<SDBMDiffExpr>())
|
|
return addConstantAndSink<SDBMExpr>(
|
|
lhsDiff.getLHS(), constant, /*negated=*/false,
|
|
[lhsDiff](SDBMDirectExpr e) {
|
|
return SDBMDiffExpr::get(e, lhsDiff.getRHS());
|
|
});
|
|
if (auto lhsNeg = lhs.dyn_cast<SDBMNegExpr>())
|
|
return addConstantAndSink<SDBMExpr>(
|
|
lhsNeg.getVar(), constant, /*negated=*/true,
|
|
[](SDBMDirectExpr e) { return SDBMNegExpr::get(e); });
|
|
if (auto lhsSum = lhs.dyn_cast<SDBMSumExpr>())
|
|
return addConstantAndSink<SDBMExpr>(lhsSum, constant, /*negated=*/false,
|
|
[](SDBMDirectExpr e) { return e; });
|
|
if (constant != 0)
|
|
return SDBMSumExpr::get(lhs.cast<SDBMTermExpr>(),
|
|
SDBMConstantExpr::get(lhs.getDialect(), constant));
|
|
return lhs;
|
|
}
|
|
|
|
// Build a difference expression given a direct expression and a negation
|
|
// expression.
|
|
static SDBMExpr buildDiffExpr(SDBMDirectExpr lhs, SDBMNegExpr rhs) {
|
|
// Fold (x + C) - (x + D) = C - D.
|
|
if (lhs.getTerm() == rhs.getVar().getTerm())
|
|
return SDBMConstantExpr::get(
|
|
lhs.getDialect(), lhs.getConstant() - rhs.getVar().getConstant());
|
|
|
|
return SDBMDiffExpr::get(
|
|
addConstantAndSink<SDBMDirectExpr>(lhs, -rhs.getVar().getConstant(),
|
|
/*negated=*/false,
|
|
[](SDBMDirectExpr e) { return e; }),
|
|
rhs.getVar().getTerm());
|
|
}
|
|
|
|
// Try folding an expression (lhs + rhs) where at least one of the operands
|
|
// contains a negated variable, i.e. is a negation or a difference expression.
|
|
static SDBMExpr foldSumDiff(SDBMExpr lhs, SDBMExpr rhs) {
|
|
// If exactly one of LHS, RHS is a negation expression, we can construct
|
|
// a difference expression, which is a special kind in SDBM.
|
|
auto lhsDirect = lhs.dyn_cast<SDBMDirectExpr>();
|
|
auto rhsDirect = rhs.dyn_cast<SDBMDirectExpr>();
|
|
auto lhsNeg = lhs.dyn_cast<SDBMNegExpr>();
|
|
auto rhsNeg = rhs.dyn_cast<SDBMNegExpr>();
|
|
|
|
if (lhsDirect && rhsNeg)
|
|
return buildDiffExpr(lhsDirect, rhsNeg);
|
|
if (lhsNeg && rhsDirect)
|
|
return buildDiffExpr(rhsDirect, lhsNeg);
|
|
|
|
// If a subexpression appears in a diff expression on the LHS(RHS) of a
|
|
// sum expression where it also appears on the RHS(LHS) with the opposite
|
|
// sign, we can simplify it away and obtain the SDBM form.
|
|
auto lhsDiff = lhs.dyn_cast<SDBMDiffExpr>();
|
|
auto rhsDiff = rhs.dyn_cast<SDBMDiffExpr>();
|
|
|
|
// -(x + A) + ((x + B) - y) = -(y + (A - B))
|
|
if (lhsNeg && rhsDiff &&
|
|
lhsNeg.getVar().getTerm() == rhsDiff.getLHS().getTerm()) {
|
|
int64_t constant =
|
|
lhsNeg.getVar().getConstant() - rhsDiff.getLHS().getConstant();
|
|
// RHS of the diff is a term expression, its sum with a constant is a direct
|
|
// expression.
|
|
return SDBMNegExpr::get(
|
|
addConstant(rhsDiff.getRHS(), constant).cast<SDBMDirectExpr>());
|
|
}
|
|
|
|
// (x + A) + ((y + B) - x) = (y + B) + A.
|
|
if (lhsDirect && rhsDiff && lhsDirect.getTerm() == rhsDiff.getRHS())
|
|
return addConstant(rhsDiff.getLHS(), lhsDirect.getConstant());
|
|
|
|
// ((x + A) - y) + (-(x + B)) = -(y + (B - A)).
|
|
if (lhsDiff && rhsNeg &&
|
|
lhsDiff.getLHS().getTerm() == rhsNeg.getVar().getTerm()) {
|
|
int64_t constant =
|
|
rhsNeg.getVar().getConstant() - lhsDiff.getLHS().getConstant();
|
|
// RHS of the diff is a term expression, its sum with a constant is a direct
|
|
// expression.
|
|
return SDBMNegExpr::get(
|
|
addConstant(lhsDiff.getRHS(), constant).cast<SDBMDirectExpr>());
|
|
}
|
|
|
|
// ((x + A) - y) + (y + B) = (x + A) + B.
|
|
if (rhsDirect && lhsDiff && rhsDirect.getTerm() == lhsDiff.getRHS())
|
|
return addConstant(lhsDiff.getLHS(), rhsDirect.getConstant());
|
|
|
|
return {};
|
|
}
|
|
|
|
/// Tries to convert an AffineExpr into an equivalent SDBM expression.
/// Returns None if some subexpression has no SDBM counterpart.
Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
  // Visitor converting affine subexpressions; every hook returns a null
  // SDBMExpr when the subexpression cannot be represented in SDBM.
  struct Converter : public AffineExprVisitor<Converter, SDBMExpr> {
    SDBMExpr visitAddExpr(AffineBinaryOpExpr expr) {
      auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
      if (!lhs || !rhs)
        return {};

      // In a "add" AffineExpr, the constant always appears on the right. If
      // there were two constants, they would have been folded away.
      assert(!lhs.isa<SDBMConstantExpr>() && "non-canonical affine expression");

      // If RHS is a constant, we can always extend the SDBM expression to
      // include it by sinking the constant into the nearest sum expression.
      if (auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>()) {
        int64_t constant = rhsConstant.getValue();
        auto varying = lhs.dyn_cast<SDBMVaryingExpr>();
        assert(varying && "unexpected uncanonicalized sum of constants");
        return addConstant(varying, constant);
      }

      // Try building a difference expression if one of the values is negated,
      // or check if a difference on either hand side cancels out the outer term
      // so as to remain correct within SDBM. Return null otherwise.
      return foldSumDiff(lhs, rhs);
    }

    SDBMExpr visitMulExpr(AffineBinaryOpExpr expr) {
      // Attempt to recover a stripe expression "x # C = (x floordiv C) * C".
      // `C` is a capture variable and can bind any expression, so the
      // rewrite only applies when it bound a positive integer constant.
      // Previously a non-constant `C` dereferenced an empty Optional, and a
      // non-positive one made SDBMStripeExpr::get abort with a fatal error.
      AffineExprMatcher x, C;
      AffineExprMatcher pattern = (x.floorDiv(C)) * C;
      if (pattern.match(expr)) {
        auto factor = C.getMatchedConstantValue();
        if (factor && *factor > 0) {
          if (SDBMExpr converted = visit(x.matched())) {
            if (auto varConverted = converted.dyn_cast<SDBMTermExpr>())
              // TODO: return varConverted.stripe(C.getConstantValue());
              return SDBMStripeExpr::get(
                  varConverted, SDBMConstantExpr::get(dialect, *factor));
          }
        }
      }

      auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
      if (!lhs || !rhs)
        return {};

      // In a "mul" AffineExpr, the constant always appears on the right. If
      // there were two constants, they would have been folded away.
      assert(!lhs.isa<SDBMConstantExpr>() && "non-canonical affine expression");
      auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
      if (!rhsConstant)
        return {};

      // The only supported "multiplication" expression in an SDBM is
      // negation, that is a product with the constant -1.
      if (rhsConstant.getValue() != -1)
        return {};

      if (auto lhsVar = lhs.dyn_cast<SDBMTermExpr>())
        return SDBMNegExpr::get(lhsVar);
      if (auto lhsDiff = lhs.dyn_cast<SDBMDiffExpr>())
        return SDBMNegator().visitDiff(lhsDiff);

      // Other multiplications are not allowed in SDBM.
      return {};
    }

    SDBMExpr visitModExpr(AffineBinaryOpExpr expr) {
      auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
      if (!lhs || !rhs)
        return {};

      // 'mod' can only be converted to SDBM if its LHS is a direct expression
      // and its RHS is a positive constant c; then `x mod c = x - x # c`.
      // Non-positive constants are rejected here instead of letting
      // SDBMStripeExpr::get abort on a non-positive stripe factor.
      auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
      auto lhsVar = lhs.dyn_cast<SDBMDirectExpr>();
      if (!lhsVar || !rhsConstant || rhsConstant.getValue() <= 0)
        return {};
      return SDBMDiffExpr::get(lhsVar,
                               SDBMStripeExpr::get(lhsVar, rhsConstant));
    }

    // `a floordiv b = (a stripe b) / b`, but we have no division in SDBM
    SDBMExpr visitFloorDivExpr(AffineBinaryOpExpr expr) { return {}; }
    SDBMExpr visitCeilDivExpr(AffineBinaryOpExpr expr) { return {}; }

    // Dimensions, symbols and constants are converted trivially.
    SDBMExpr visitConstantExpr(AffineConstantExpr expr) {
      return SDBMConstantExpr::get(dialect, expr.getValue());
    }
    SDBMExpr visitDimExpr(AffineDimExpr expr) {
      return SDBMDimExpr::get(dialect, expr.getPosition());
    }
    SDBMExpr visitSymbolExpr(AffineSymbolExpr expr) {
      return SDBMSymbolExpr::get(dialect, expr.getPosition());
    }

    SDBMDialect *dialect;
  } converter;
  converter.dialect = affine.getContext()->getOrLoadDialect<SDBMDialect>();

  if (auto result = converter.visit(affine))
    return result;
  return None;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SDBMDiffExpr
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Returns the uniqued difference expression `lhs - rhs`, created in the
// dialect that owns `lhs`.
SDBMDiffExpr SDBMDiffExpr::get(SDBMDirectExpr lhs, SDBMTermExpr rhs) {
  assert(lhs && "expected SDBM dimension");
  assert(rhs && "expected SDBM dimension");

  SDBMDialect *dialect = lhs.getDialect();
  return dialect->getUniquer().get<detail::SDBMDiffExprStorage>(
      /*initFn=*/{}, lhs, rhs);
}
|
|
|
|
// The minuend of the difference.
SDBMDirectExpr SDBMDiffExpr::getLHS() const {
  auto *storage = static_cast<ImplType *>(impl);
  return storage->lhs;
}
|
|
|
|
// The subtrahend of the difference.
SDBMTermExpr SDBMDiffExpr::getRHS() const {
  auto *storage = static_cast<ImplType *>(impl);
  return storage->rhs;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SDBMDirectExpr
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// For a sum `x + C` returns `x`; a term is returned as is.
SDBMTermExpr SDBMDirectExpr::getTerm() {
  auto sum = dyn_cast<SDBMSumExpr>();
  return sum ? sum.getLHS() : cast<SDBMTermExpr>();
}
|
|
|
|
int64_t SDBMDirectExpr::getConstant() {
|
|
if (auto sum = dyn_cast<SDBMSumExpr>())
|
|
return sum.getRHS().getValue();
|
|
return 0;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SDBMStripeExpr
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Returns the uniqued stripe expression `var # stripeFactor`. A non-positive
// factor is a fatal error, also in builds without assertions.
SDBMStripeExpr SDBMStripeExpr::get(SDBMDirectExpr var,
                                   SDBMConstantExpr stripeFactor) {
  assert(var && "expected SDBM variable expression");
  assert(stripeFactor && "expected non-null stripe factor");

  bool isPositive = stripeFactor.getValue() > 0;
  if (!isPositive)
    llvm::report_fatal_error("non-positive stripe factor");

  SDBMDialect *dialect = var.getDialect();
  return dialect->getUniquer().get<detail::SDBMBinaryExprStorage>(
      /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Stripe), var,
      stripeFactor);
}
|
|
|
|
// The striped operand, or a null expression if none is stored.
SDBMDirectExpr SDBMStripeExpr::getLHS() const {
  SDBMVaryingExpr operand = static_cast<ImplType *>(impl)->lhs;
  if (!operand)
    return {};
  return operand.cast<SDBMDirectExpr>();
}
|
|
|
|
// The constant stripe factor.
SDBMConstantExpr SDBMStripeExpr::getStripeFactor() const {
  auto *storage = static_cast<ImplType *>(impl);
  return storage->rhs;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SDBMInputExpr
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
unsigned SDBMInputExpr::getPosition() const {
|
|
return static_cast<ImplType *>(impl)->position;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SDBMDimExpr
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Returns the uniqued dimension expression `d<position>` in `dialect`.
SDBMDimExpr SDBMDimExpr::get(SDBMDialect *dialect, unsigned position) {
  assert(dialect && "expected non-null dialect");

  // Freshly created storage records the dialect that owns it.
  auto initStorage = [dialect](detail::SDBMTermExprStorage *storage) {
    storage->dialect = dialect;
  };

  return dialect->getUniquer().get<detail::SDBMTermExprStorage>(
      initStorage, static_cast<unsigned>(SDBMExprKind::DimId), position);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SDBMSymbolExpr
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Returns the uniqued symbol expression `s<position>` in `dialect`.
SDBMSymbolExpr SDBMSymbolExpr::get(SDBMDialect *dialect, unsigned position) {
  assert(dialect && "expected non-null dialect");

  // Freshly created storage records the dialect that owns it.
  auto initStorage = [dialect](detail::SDBMTermExprStorage *storage) {
    storage->dialect = dialect;
  };

  return dialect->getUniquer().get<detail::SDBMTermExprStorage>(
      initStorage, static_cast<unsigned>(SDBMExprKind::SymbolId), position);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SDBMConstantExpr
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Returns the uniqued constant expression with the given value in `dialect`.
SDBMConstantExpr SDBMConstantExpr::get(SDBMDialect *dialect, int64_t value) {
  assert(dialect && "expected non-null dialect");

  // Freshly created storage records the dialect that owns it.
  auto initStorage = [dialect](detail::SDBMConstantExprStorage *storage) {
    storage->dialect = dialect;
  };

  return dialect->getUniquer().get<detail::SDBMConstantExprStorage>(
      initStorage, value);
}
|
|
|
|
int64_t SDBMConstantExpr::getValue() const {
|
|
return static_cast<ImplType *>(impl)->constant;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SDBMNegExpr
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Returns the uniqued negation `-var`, created in the dialect owning `var`.
SDBMNegExpr SDBMNegExpr::get(SDBMDirectExpr var) {
  assert(var && "expected non-null SDBM direct expression");

  SDBMDialect *dialect = var.getDialect();
  return dialect->getUniquer().get<detail::SDBMNegExprStorage>(
      /*initFn=*/{}, var);
}
|
|
|
|
// The negated direct expression.
SDBMDirectExpr SDBMNegExpr::getVar() const {
  auto *storage = static_cast<ImplType *>(impl);
  return storage->expr;
}
|
|
|
|
/// Builds `lhs + rhs` in SDBM form, folding where possible. The operands must
/// form a valid SDBM sum; otherwise this asserts.
SDBMExpr mlir::ops_assertions::operator+(SDBMExpr lhs, SDBMExpr rhs) {
  // Sums with a negation or a difference on either side may fold directly.
  if (auto folded = foldSumDiff(lhs, rhs))
    return folded;
  assert(!(lhs.isa<SDBMNegExpr>() && rhs.isa<SDBMNegExpr>()) &&
         "a sum of negated expressions is a negation of a sum of variables and "
         "not a correct SDBM");

  // Fold ((x + A) - y) + ((y + B) - x) = A + B; in particular
  // (x - y) + (y - x) = 0. Previously only the constant-free form was folded,
  // and the general form fell through to the constant-operand assertion below.
  auto lhsDiff = lhs.dyn_cast<SDBMDiffExpr>();
  auto rhsDiff = rhs.dyn_cast<SDBMDiffExpr>();
  if (lhsDiff && rhsDiff) {
    SDBMDirectExpr lhsMinuend = lhsDiff.getLHS();
    SDBMDirectExpr rhsMinuend = rhsDiff.getLHS();
    if (lhsMinuend.getTerm() == rhsDiff.getRHS() &&
        lhsDiff.getRHS() == rhsMinuend.getTerm())
      return SDBMConstantExpr::get(lhs.getDialect(),
                                   lhsMinuend.getConstant() +
                                       rhsMinuend.getConstant());
  }

  // If LHS is a constant and RHS is not, swap the order to get into a supported
  // sum case. From now on, RHS must be a constant.
  auto lhsConstant = lhs.dyn_cast<SDBMConstantExpr>();
  auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
  if (!rhsConstant && lhsConstant) {
    std::swap(lhs, rhs);
    std::swap(lhsConstant, rhsConstant);
  }
  assert(rhsConstant && "at least one operand must be a constant");

  // Constant-fold if LHS is also a constant.
  if (lhsConstant)
    return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant.getValue() +
                                                       rhsConstant.getValue());
  return addConstant(lhs.cast<SDBMVaryingExpr>(), rhsConstant.getValue());
}
|
|
|
|
/// Builds `lhs - rhs` in SDBM form, folding where possible. The operands must
/// form a valid SDBM difference; otherwise this asserts (in operator+).
SDBMExpr mlir::ops_assertions::operator-(SDBMExpr lhs, SDBMExpr rhs) {
  // Fold x - x == 0.
  if (lhs == rhs)
    return SDBMConstantExpr::get(lhs.getDialect(), 0);

  // LHS and RHS may be constants.
  auto lhsConstant = lhs.dyn_cast<SDBMConstantExpr>();
  auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();

  // Constant fold if both LHS and RHS are constants.
  if (lhsConstant && rhsConstant)
    return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant.getValue() -
                                                       rhsConstant.getValue());

  // Replace a difference with a sum with a negated value if one of LHS and RHS
  // is a constant:
  //   x - C == x + (-C);
  //   C - x == -x + C.
  // This calls into operator+ for further simplification.
  if (rhsConstant)
    return lhs + (-rhsConstant);
  if (lhsConstant)
    return -rhs + lhsConstant;

  // x - y == x + (-y). Negation keeps `-y` in SDBM form, and operator+ does
  // the folding. Two direct operands produce the same difference expression
  // as before, since operator+ hands (direct, negation) to buildDiffExpr.
  // Other representable cases, such as x - (x - y) = y, (x - y) - x = -y and
  // -x - (-y) = y - x, now fold instead of failing a cast. Sums with no SDBM
  // form are still rejected by operator+'s assertions.
  return lhs + (-rhs);
}
|
|
|
|
// Builds `expr # factor`; `factor` must be a positive constant and `expr` a
// direct expression.
SDBMExpr mlir::ops_assertions::stripe(SDBMExpr expr, SDBMExpr factor) {
  auto constantFactor = factor.cast<SDBMConstantExpr>();
  int64_t value = constantFactor.getValue();
  assert(value > 0 && "non-positive stripe");

  // x # 1 = x: striping by one is the identity.
  if (value != 1)
    return SDBMStripeExpr::get(expr.cast<SDBMDirectExpr>(), constantFactor);
  return expr;
}
|