//===- TestPDLByteCode.cpp - Test rewriter bytecode functionality ---------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; /// Custom constraint invoked from PDL. static LogicalResult customSingleEntityConstraint(PDLValue value, ArrayAttr constantParams, PatternRewriter &rewriter) { Operation *rootOp = value.cast(); return success(rootOp->getName().getStringRef() == "test.op"); } static LogicalResult customMultiEntityConstraint(ArrayRef values, ArrayAttr constantParams, PatternRewriter &rewriter) { return customSingleEntityConstraint(values[1], constantParams, rewriter); } // Custom creator invoked from PDL. static PDLValue customCreate(ArrayRef args, ArrayAttr constantParams, PatternRewriter &rewriter) { return rewriter.createOperation( OperationState(args[0].cast()->getLoc(), "test.success")); } /// Custom rewriter invoked from PDL. static void customRewriter(Operation *root, ArrayRef args, ArrayAttr constantParams, PatternRewriter &rewriter) { OperationState successOpState(root->getLoc(), "test.success"); successOpState.addOperands(args[0].cast()); successOpState.addAttribute("constantParams", constantParams); rewriter.createOperation(successOpState); rewriter.eraseOp(root); } namespace { struct TestPDLByteCodePass : public PassWrapper> { void runOnOperation() final { ModuleOp module = getOperation(); // The test cases are encompassed via two modules, one containing the // patterns and one containing the operations to rewrite. ModuleOp patternModule = module.lookupSymbol("patterns"); ModuleOp irModule = module.lookupSymbol("ir"); if (!patternModule || !irModule) return; // Process the pattern module. patternModule.getOperation()->remove(); PDLPatternModule pdlPattern(patternModule); pdlPattern.registerConstraintFunction("multi_entity_constraint", customMultiEntityConstraint); pdlPattern.registerConstraintFunction("single_entity_constraint", customSingleEntityConstraint); pdlPattern.registerCreateFunction("creator", customCreate); pdlPattern.registerRewriteFunction("rewriter", customRewriter); OwningRewritePatternList patternList(std::move(pdlPattern)); // Invoke the pattern driver with the provided patterns. (void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(), std::move(patternList)); } }; } // end anonymous namespace namespace mlir { namespace test { void registerTestPDLByteCodePass() { PassRegistration("test-pdl-bytecode-pass", "Test PDL ByteCode functionality"); } } // namespace test } // namespace mlir