//===- SDBMTest.cpp - SDBM expression unit tests --------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/SDBM/SDBM.h" #include "mlir/Dialect/SDBM/SDBMDialect.h" #include "mlir/Dialect/SDBM/SDBMExpr.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/MLIRContext.h" #include "gtest/gtest.h" #include "llvm/ADT/DenseSet.h" using namespace mlir; static MLIRContext *ctx() { static thread_local MLIRContext context; context.getOrLoadDialect(); return &context; } static SDBMDialect *dialect() { static thread_local SDBMDialect *d = nullptr; if (!d) { d = ctx()->getOrLoadDialect(); } return d; } static SDBMExpr dim(unsigned pos) { return SDBMDimExpr::get(dialect(), pos); } static SDBMExpr symb(unsigned pos) { return SDBMSymbolExpr::get(dialect(), pos); } namespace { using namespace mlir::ops_assertions; TEST(SDBMOperators, Add) { auto expr = dim(0) + 42; auto sumExpr = expr.dyn_cast(); ASSERT_TRUE(sumExpr); EXPECT_EQ(sumExpr.getLHS(), dim(0)); EXPECT_EQ(sumExpr.getRHS().getValue(), 42); } TEST(SDBMOperators, AddFolding) { auto constant = SDBMConstantExpr::get(dialect(), 2) + 42; auto constantExpr = constant.dyn_cast(); ASSERT_TRUE(constantExpr); EXPECT_EQ(constantExpr.getValue(), 44); auto expr = (dim(0) + 10) + 32; auto sumExpr = expr.dyn_cast(); ASSERT_TRUE(sumExpr); EXPECT_EQ(sumExpr.getRHS().getValue(), 42); expr = dim(0) + SDBMNegExpr::get(SDBMDimExpr::get(dialect(), 1)); auto diffExpr = expr.dyn_cast(); ASSERT_TRUE(diffExpr); EXPECT_EQ(diffExpr.getLHS(), dim(0)); EXPECT_EQ(diffExpr.getRHS(), dim(1)); auto inverted = SDBMNegExpr::get(SDBMDimExpr::get(dialect(), 1)) + dim(0); EXPECT_EQ(inverted, expr); // Check that opposite values cancel each other, and that we elide the zero // constant. expr = dim(0) + 42; auto onlyDim = expr - 42; EXPECT_EQ(onlyDim, dim(0)); // Check that we can sink a constant under a negation. expr = -(dim(0) + 2); auto negatedSum = (expr + 10).dyn_cast(); ASSERT_TRUE(negatedSum); auto sum = negatedSum.getVar().dyn_cast(); ASSERT_TRUE(sum); EXPECT_EQ(sum.getRHS().getValue(), -8); // Sum with zero is the same as the original expression. EXPECT_EQ(dim(0) + 0, dim(0)); // Sum of opposite differences is zero. auto diffOfDiffs = ((dim(0) - dim(1)) + (dim(1) - dim(0))).dyn_cast(); EXPECT_EQ(diffOfDiffs.getValue(), 0); } TEST(SDBMOperators, AddNegativeTerms) { const int64_t A = 7; const int64_t B = -5; auto x = SDBMDimExpr::get(dialect(), 0); auto y = SDBMDimExpr::get(dialect(), 1); // Check the simplification patterns in addition where one of the variables is // cancelled out and the result remains an SDBM. EXPECT_EQ(-(x + A) + ((x + B) - y), -(y + (A - B))); EXPECT_EQ((x + A) + ((y + B) - x), (y + B) + A); EXPECT_EQ(((x + A) - y) + (-(x + B)), -(y + (B - A))); EXPECT_EQ(((x + A) - y) + (y + B), (x + A) + B); } TEST(SDBMOperators, Diff) { auto expr = dim(0) - dim(1); auto diffExpr = expr.dyn_cast(); ASSERT_TRUE(diffExpr); EXPECT_EQ(diffExpr.getLHS(), dim(0)); EXPECT_EQ(diffExpr.getRHS(), dim(1)); } TEST(SDBMOperators, DiffFolding) { auto constant = SDBMConstantExpr::get(dialect(), 10) - 3; auto constantExpr = constant.dyn_cast(); ASSERT_TRUE(constantExpr); EXPECT_EQ(constantExpr.getValue(), 7); auto expr = dim(0) - 3; auto sumExpr = expr.dyn_cast(); ASSERT_TRUE(sumExpr); EXPECT_EQ(sumExpr.getRHS().getValue(), -3); auto zero = dim(0) - dim(0); constantExpr = zero.dyn_cast(); ASSERT_TRUE(constantExpr); EXPECT_EQ(constantExpr.getValue(), 0); // Check that the constant terms in difference-of-sums are folded. // (d0 - 3) - (d1 - 5) = (d0 + 2) - d1 auto diffOfSums = ((dim(0) - 3) - (dim(1) - 5)).dyn_cast(); ASSERT_TRUE(diffOfSums); auto lhs = diffOfSums.getLHS().dyn_cast(); ASSERT_TRUE(lhs); EXPECT_EQ(lhs.getLHS(), dim(0)); EXPECT_EQ(lhs.getRHS().getValue(), 2); EXPECT_EQ(diffOfSums.getRHS(), dim(1)); // Check that identical dimensions with opposite signs cancel each other. auto cstOnly = ((dim(0) + 42) - dim(0)).dyn_cast(); ASSERT_TRUE(cstOnly); EXPECT_EQ(cstOnly.getValue(), 42); // Check that identical terms in sum of diffs cancel out. auto dimOnly = (-dim(0) + (dim(0) - dim(1))); EXPECT_EQ(dimOnly, -dim(1)); dimOnly = (dim(0) - dim(1)) + (-dim(0)); EXPECT_EQ(dimOnly, -dim(1)); dimOnly = (dim(0) - dim(1)) + dim(1); EXPECT_EQ(dimOnly, dim(0)); dimOnly = dim(0) + (dim(1) - dim(0)); EXPECT_EQ(dimOnly, dim(1)); // Top-level zero constant is fine. cstOnly = (-symb(1) + symb(1)).dyn_cast(); ASSERT_TRUE(cstOnly); EXPECT_EQ(cstOnly.getValue(), 0); } TEST(SDBMOperators, Negate) { auto sum = dim(0) + 3; auto negated = (-sum).dyn_cast(); ASSERT_TRUE(negated); EXPECT_EQ(negated.getVar(), sum); } TEST(SDBMOperators, Stripe) { auto expr = stripe(dim(0), 3); auto stripeExpr = expr.dyn_cast(); ASSERT_TRUE(stripeExpr); EXPECT_EQ(stripeExpr.getLHS(), dim(0)); EXPECT_EQ(stripeExpr.getStripeFactor().getValue(), 3); } TEST(SDBM, RoundTripEqs) { // Build an SDBM defined by // // d0 = s0 # 3 # 5 // s0 # 3 # 5 - d1 + 42 = 0 // // and perform a double round-trip between the "list of equalities" and SDBM // representation. After the first round-trip, the equalities may be // different due to simplification or equivalent substitutions (e.g., the // second equality may become d0 - d1 + 42 = 0). However, there should not // be any further simplification after the second round-trip, // Build the SDBM from a pair of equalities and extract back the lists of // inequalities and equalities. Check that all equalities are properly // detected and none of them decayed into inequalities. auto s = stripe(stripe(symb(0), 3), 5); auto sdbm = SDBM::get(llvm::None, {s - dim(0), s - dim(1) + 42}); SmallVector eqs, ineqs; sdbm.getSDBMExpressions(dialect(), ineqs, eqs); ASSERT_TRUE(ineqs.empty()); // Do the second round-trip. auto sdbm2 = SDBM::get(llvm::None, eqs); SmallVector eqs2, ineqs2; sdbm2.getSDBMExpressions(dialect(), ineqs2, eqs2); ASSERT_EQ(eqs.size(), eqs2.size()); // Check that the sets of equalities are equal, their order is not relevant. llvm::DenseSet eqSet, eq2Set; eqSet.insert(eqs.begin(), eqs.end()); eq2Set.insert(eqs2.begin(), eqs2.end()); EXPECT_EQ(eqSet, eq2Set); } TEST(SDBMExpr, Constant) { // We can create constants and query them. auto expr = SDBMConstantExpr::get(dialect(), 42); EXPECT_EQ(expr.getValue(), 42); // Two separately created constants with identical values are trivially equal. auto expr2 = SDBMConstantExpr::get(dialect(), 42); EXPECT_EQ(expr, expr2); // Hierarchy is okay. auto generic = static_cast(expr); EXPECT_TRUE(generic.isa()); } TEST(SDBMExpr, Dim) { // We can create dimension expressions and query them. auto expr = SDBMDimExpr::get(dialect(), 0); EXPECT_EQ(expr.getPosition(), 0u); // Two separately created dimensions with the same position are trivially // equal. auto expr2 = SDBMDimExpr::get(dialect(), 0); EXPECT_EQ(expr, expr2); // Hierarchy is okay. auto generic = static_cast(expr); EXPECT_TRUE(generic.isa()); EXPECT_TRUE(generic.isa()); EXPECT_TRUE(generic.isa()); EXPECT_TRUE(generic.isa()); EXPECT_TRUE(generic.isa()); // Dimensions are not Symbols. auto symbol = SDBMSymbolExpr::get(dialect(), 0); EXPECT_NE(expr, symbol); EXPECT_FALSE(expr.isa()); } TEST(SDBMExpr, Symbol) { // We can create symbol expressions and query them. auto expr = SDBMSymbolExpr::get(dialect(), 0); EXPECT_EQ(expr.getPosition(), 0u); // Two separately created symbols with the same position are trivially equal. auto expr2 = SDBMSymbolExpr::get(dialect(), 0); EXPECT_EQ(expr, expr2); // Hierarchy is okay. auto generic = static_cast(expr); EXPECT_TRUE(generic.isa()); EXPECT_TRUE(generic.isa()); EXPECT_TRUE(generic.isa()); EXPECT_TRUE(generic.isa()); EXPECT_TRUE(generic.isa()); // Dimensions are not Symbols. auto symbol = SDBMDimExpr::get(dialect(), 0); EXPECT_NE(expr, symbol); EXPECT_FALSE(expr.isa()); } TEST(SDBMExpr, Stripe) { auto cst2 = SDBMConstantExpr::get(dialect(), 2); auto cst0 = SDBMConstantExpr::get(dialect(), 0); auto var = SDBMSymbolExpr::get(dialect(), 0); // We can create stripe expressions and query them. auto expr = SDBMStripeExpr::get(var, cst2); EXPECT_EQ(expr.getLHS(), var); EXPECT_EQ(expr.getStripeFactor(), cst2); // Two separately created stripe expressions with the same LHS and RHS are // trivially equal. auto expr2 = SDBMStripeExpr::get(SDBMSymbolExpr::get(dialect(), 0), cst2); EXPECT_EQ(expr, expr2); // Stripes can be nested. SDBMStripeExpr::get(expr, SDBMConstantExpr::get(dialect(), 4)); // Non-positive stripe factors are not allowed. EXPECT_DEATH(SDBMStripeExpr::get(var, cst0), "non-positive"); // Stripes can have sums on the LHS. SDBMStripeExpr::get(SDBMSumExpr::get(var, cst2), cst2); // Hierarchy is okay. auto generic = static_cast(expr); EXPECT_TRUE(generic.isa()); EXPECT_TRUE(generic.isa()); EXPECT_TRUE(generic.isa()); EXPECT_TRUE(generic.isa()); } TEST(SDBMExpr, Neg) { auto cst2 = SDBMConstantExpr::get(dialect(), 2); auto var = SDBMSymbolExpr::get(dialect(), 0); auto stripe = SDBMStripeExpr::get(var, cst2); // We can create negation expressions and query them. auto expr = SDBMNegExpr::get(var); EXPECT_EQ(expr.getVar(), var); auto expr2 = SDBMNegExpr::get(stripe); EXPECT_EQ(expr2.getVar(), stripe); // Neg expressions are trivially comparable. EXPECT_EQ(expr, SDBMNegExpr::get(var)); // Hierarchy is okay. auto generic = static_cast(expr); EXPECT_TRUE(generic.isa()); EXPECT_TRUE(generic.isa()); } TEST(SDBMExpr, Sum) { auto cst2 = SDBMConstantExpr::get(dialect(), 2); auto var = SDBMSymbolExpr::get(dialect(), 0); auto stripe = SDBMStripeExpr::get(var, cst2); // We can create sum expressions and query them. auto expr = SDBMSumExpr::get(var, cst2); EXPECT_EQ(expr.getLHS(), var); EXPECT_EQ(expr.getRHS(), cst2); auto expr2 = SDBMSumExpr::get(stripe, cst2); EXPECT_EQ(expr2.getLHS(), stripe); EXPECT_EQ(expr2.getRHS(), cst2); // Sum expressions are trivially comparable. EXPECT_EQ(expr, SDBMSumExpr::get(var, cst2)); // Hierarchy is okay. auto generic = static_cast(expr); EXPECT_TRUE(generic.isa()); EXPECT_TRUE(generic.isa()); EXPECT_TRUE(generic.isa()); } TEST(SDBMExpr, Diff) { auto cst2 = SDBMConstantExpr::get(dialect(), 2); auto var = SDBMSymbolExpr::get(dialect(), 0); auto stripe = SDBMStripeExpr::get(var, cst2); // We can create sum expressions and query them. auto expr = SDBMDiffExpr::get(var, stripe); EXPECT_EQ(expr.getLHS(), var); EXPECT_EQ(expr.getRHS(), stripe); auto expr2 = SDBMDiffExpr::get(stripe, var); EXPECT_EQ(expr2.getLHS(), stripe); EXPECT_EQ(expr2.getRHS(), var); // Sum expressions are trivially comparable. EXPECT_EQ(expr, SDBMDiffExpr::get(var, stripe)); // Hierarchy is okay. auto generic = static_cast(expr); EXPECT_TRUE(generic.isa()); EXPECT_TRUE(generic.isa()); } TEST(SDBMExpr, AffineRoundTrip) { // Build an expression (s0 - s0 # 2) auto cst2 = SDBMConstantExpr::get(dialect(), 2); auto var = SDBMSymbolExpr::get(dialect(), 0); auto stripe = SDBMStripeExpr::get(var, cst2); auto expr = SDBMDiffExpr::get(var, stripe); // Check that it can be converted to AffineExpr and back, i.e. stripe // detection works correctly. Optional roundtripped = SDBMExpr::tryConvertAffineExpr(expr.getAsAffineExpr()); ASSERT_TRUE(roundtripped.hasValue()); EXPECT_EQ(roundtripped, static_cast(expr)); // Check that (s0 # 2 # 5) can be converted to AffineExpr, i.e. stripe // detection supports nested expressions. auto cst5 = SDBMConstantExpr::get(dialect(), 5); auto outerStripe = SDBMStripeExpr::get(stripe, cst5); roundtripped = SDBMExpr::tryConvertAffineExpr(outerStripe.getAsAffineExpr()); ASSERT_TRUE(roundtripped.hasValue()); EXPECT_EQ(roundtripped, static_cast(outerStripe)); // Check that ((s0 + 2) # 5) can be round-tripped through AffineExpr, i.e. // stripe detection supports sum expressions. auto inner = SDBMSumExpr::get(var, cst2); auto stripeSum = SDBMStripeExpr::get(inner, cst5); roundtripped = SDBMExpr::tryConvertAffineExpr(stripeSum.getAsAffineExpr()); ASSERT_TRUE(roundtripped.hasValue()); EXPECT_EQ(roundtripped, static_cast(stripeSum)); // Check that (s0 # 2 # 5 - s0 # 2) + 2 can be converted as an example of a // deeper expression tree. auto sum = SDBMSumExpr::get(outerStripe, cst2); auto diff = SDBMDiffExpr::get(sum, stripe); roundtripped = SDBMExpr::tryConvertAffineExpr(diff.getAsAffineExpr()); ASSERT_TRUE(roundtripped.hasValue()); EXPECT_EQ(roundtripped, static_cast(diff)); // Check a nested stripe-sum combination. auto cst7 = SDBMConstantExpr::get(dialect(), 7); auto nestedStripe = SDBMStripeExpr::get(SDBMSumExpr::get(stripeSum, cst2), cst7); diff = SDBMDiffExpr::get(nestedStripe, stripe); roundtripped = SDBMExpr::tryConvertAffineExpr(diff.getAsAffineExpr()); ASSERT_TRUE(roundtripped.hasValue()); EXPECT_EQ(roundtripped, static_cast(diff)); } TEST(SDBMExpr, MatchStripeMulPattern) { // Make sure conversion from AffineExpr recognizes multiplicative stripe // pattern (x floordiv B) * B == x # B. auto cst = getAffineConstantExpr(42, ctx()); auto dim = getAffineDimExpr(0, ctx()); auto floor = dim.floorDiv(cst); auto mul = cst * floor; Optional converted = SDBMStripeExpr::tryConvertAffineExpr(mul); ASSERT_TRUE(converted.hasValue()); EXPECT_TRUE(converted->isa()); } TEST(SDBMExpr, NonSDBM) { auto d0 = getAffineDimExpr(0, ctx()); auto d1 = getAffineDimExpr(1, ctx()); auto sum = d0 + d1; auto c2 = getAffineConstantExpr(2, ctx()); auto prod = d0 * c2; auto ceildiv = d1.ceilDiv(c2); // The following are not valid SDBM expressions: // - a sum of two variables EXPECT_FALSE(SDBMExpr::tryConvertAffineExpr(sum).hasValue()); // - a variable with coefficient other than 1 or -1 EXPECT_FALSE(SDBMExpr::tryConvertAffineExpr(prod).hasValue()); // - a ceildiv expression EXPECT_FALSE(SDBMExpr::tryConvertAffineExpr(ceildiv).hasValue()); } } // end namespace