//===- TestLoopFusion.cpp - Test loop fusion ------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements a pass to test various loop fusion utility functions. // //===----------------------------------------------------------------------===// #include "mlir/Analysis/Utils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/LoopFusionUtils.h" #include "mlir/Transforms/LoopUtils.h" #include "mlir/Transforms/Passes.h" #define DEBUG_TYPE "test-loop-fusion" using namespace mlir; static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); static llvm::cl::opt clTestDependenceCheck( "test-loop-fusion-dependence-check", llvm::cl::desc("Enable testing of loop fusion dependence check"), llvm::cl::cat(clOptionsCategory)); static llvm::cl::opt clTestSliceComputation( "test-loop-fusion-slice-computation", llvm::cl::desc("Enable testing of loop fusion slice computation"), llvm::cl::cat(clOptionsCategory)); static llvm::cl::opt clTestLoopFusionTransformation( "test-loop-fusion-transformation", llvm::cl::desc("Enable testing of loop fusion transformation"), llvm::cl::cat(clOptionsCategory)); namespace { struct TestLoopFusion : public PassWrapper { void runOnFunction() override; }; } // end anonymous namespace // Run fusion dependence check on 'loops[i]' and 'loops[j]' at loop depths // in range ['loopDepth' + 1, 'maxLoopDepth']. // Emits a remark on 'loops[i]' if a fusion-preventing dependence exists. // Returns false as IR is not transformed. static bool testDependenceCheck(AffineForOp srcForOp, AffineForOp dstForOp, unsigned i, unsigned j, unsigned loopDepth, unsigned maxLoopDepth) { mlir::ComputationSliceState sliceUnion; for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) { FusionResult result = mlir::canFuseLoops(srcForOp, dstForOp, d, &sliceUnion); if (result.value == FusionResult::FailBlockDependence) { srcForOp->emitRemark("block-level dependence preventing" " fusion of loop nest ") << i << " into loop nest " << j << " at depth " << loopDepth; } } return false; } // Returns the index of 'op' in its block. static unsigned getBlockIndex(Operation &op) { unsigned index = 0; for (auto &opX : *op.getBlock()) { if (&op == &opX) break; ++index; } return index; } // Returns a string representation of 'sliceUnion'. static std::string getSliceStr(const mlir::ComputationSliceState &sliceUnion) { std::string result; llvm::raw_string_ostream os(result); // Slice insertion point format [loop-depth, operation-block-index] unsigned ipd = getNestingDepth(&*sliceUnion.insertPoint); unsigned ipb = getBlockIndex(*sliceUnion.insertPoint); os << "insert point: (" << std::to_string(ipd) << ", " << std::to_string(ipb) << ")"; assert(sliceUnion.lbs.size() == sliceUnion.ubs.size()); os << " loop bounds: "; for (unsigned k = 0, e = sliceUnion.lbs.size(); k < e; ++k) { os << '['; sliceUnion.lbs[k].print(os); os << ", "; sliceUnion.ubs[k].print(os); os << "] "; } return os.str(); } // Computes fusion slice union on 'loops[i]' and 'loops[j]' at loop depths // in range ['loopDepth' + 1, 'maxLoopDepth']. // Emits a string representation of the slice union as a remark on 'loops[j]'. // Returns false as IR is not transformed. static bool testSliceComputation(AffineForOp forOpA, AffineForOp forOpB, unsigned i, unsigned j, unsigned loopDepth, unsigned maxLoopDepth) { for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) { mlir::ComputationSliceState sliceUnion; FusionResult result = mlir::canFuseLoops(forOpA, forOpB, d, &sliceUnion); if (result.value == FusionResult::Success) { forOpB->emitRemark("slice (") << " src loop: " << i << ", dst loop: " << j << ", depth: " << d << " : " << getSliceStr(sliceUnion) << ")"; } } return false; } // Attempts to fuse 'forOpA' into 'forOpB' at loop depths in range // ['loopDepth' + 1, 'maxLoopDepth']. // Returns true if loops were successfully fused, false otherwise. static bool testLoopFusionTransformation(AffineForOp forOpA, AffineForOp forOpB, unsigned i, unsigned j, unsigned loopDepth, unsigned maxLoopDepth) { for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) { mlir::ComputationSliceState sliceUnion; FusionResult result = mlir::canFuseLoops(forOpA, forOpB, d, &sliceUnion); if (result.value == FusionResult::Success) { mlir::fuseLoops(forOpA, forOpB, sliceUnion); // Note: 'forOpA' is removed to simplify test output. A proper loop // fusion pass should check the data dependence graph and run memref // region analysis to ensure removing 'forOpA' is safe. forOpA.erase(); return true; } } return false; } using LoopFunc = function_ref; // Run tests on all combinations of src/dst loop nests in 'depthToLoops'. // If 'return_on_change' is true, returns on first invocation of 'fn' which // returns true. static bool iterateLoops(ArrayRef> depthToLoops, LoopFunc fn, bool return_on_change = false) { bool changed = false; for (unsigned loopDepth = 0, end = depthToLoops.size(); loopDepth < end; ++loopDepth) { auto &loops = depthToLoops[loopDepth]; unsigned numLoops = loops.size(); for (unsigned j = 0; j < numLoops; ++j) { for (unsigned k = 0; k < numLoops; ++k) { if (j != k) changed |= fn(loops[j], loops[k], j, k, loopDepth, depthToLoops.size()); if (changed && return_on_change) return true; } } } return changed; } void TestLoopFusion::runOnFunction() { std::vector> depthToLoops; if (clTestLoopFusionTransformation) { // Run loop fusion until a fixed point is reached. do { depthToLoops.clear(); // Gather all AffineForOps by loop depth. gatherLoops(getFunction(), depthToLoops); // Try to fuse all combinations of src/dst loop nests in 'depthToLoops'. } while (iterateLoops(depthToLoops, testLoopFusionTransformation, /*return_on_change=*/true)); return; } // Gather all AffineForOps by loop depth. gatherLoops(getFunction(), depthToLoops); // Run tests on all combinations of src/dst loop nests in 'depthToLoops'. if (clTestDependenceCheck) iterateLoops(depthToLoops, testDependenceCheck); if (clTestSliceComputation) iterateLoops(depthToLoops, testSliceComputation); } namespace mlir { namespace test { void registerTestLoopFusion() { PassRegistration("test-loop-fusion", "Tests loop fusion utility functions."); } } // namespace test } // namespace mlir