//===- PassManagerTest.cpp - PassManager unit tests -----------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Pass/PassManager.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" #include "gtest/gtest.h" using namespace mlir; using namespace mlir::detail; namespace { /// Analysis that operates on any operation. struct GenericAnalysis { GenericAnalysis(Operation *op) : isFunc(isa(op)) {} const bool isFunc; }; /// Analysis that operates on a specific operation. struct OpSpecificAnalysis { OpSpecificAnalysis(FuncOp op) : isSecret(op.getName() == "secret") {} const bool isSecret; }; /// Simple pass to annotate a FuncOp with the results of analysis. /// Note: not using FunctionPass as it skip external functions. struct AnnotateFunctionPass : public PassWrapper> { void runOnOperation() override { FuncOp op = getOperation(); Builder builder(op.getParentOfType()); auto &ga = getAnalysis(); auto &sa = getAnalysis(); op.setAttr("isFunc", builder.getBoolAttr(ga.isFunc)); op.setAttr("isSecret", builder.getBoolAttr(sa.isSecret)); } }; TEST(PassManagerTest, OpSpecificAnalysis) { MLIRContext context; Builder builder(&context); // Create a module with 2 functions. OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context))); for (StringRef name : {"secret", "not_secret"}) { FuncOp func = FuncOp::create(builder.getUnknownLoc(), name, builder.getFunctionType(llvm::None, llvm::None)); func.setPrivate(); module->push_back(func); } // Instantiate and run our pass. PassManager pm(&context); pm.addNestedPass(std::make_unique()); LogicalResult result = pm.run(module.get()); EXPECT_TRUE(succeeded(result)); // Verify that each function got annotated with expected attributes. for (FuncOp func : module->getOps()) { ASSERT_TRUE(func.getAttr("isFunc").isa()); EXPECT_TRUE(func.getAttr("isFunc").cast().getValue()); bool isSecret = func.getName() == "secret"; ASSERT_TRUE(func.getAttr("isSecret").isa()); EXPECT_EQ(func.getAttr("isSecret").cast().getValue(), isSecret); } } namespace { struct InvalidPass : Pass { InvalidPass() : Pass(TypeID::get(), StringRef("invalid_op")) {} StringRef getName() const override { return "Invalid Pass"; } void runOnOperation() override {} /// A clone method to create a copy of this pass. std::unique_ptr clonePass() const override { return std::make_unique( *static_cast(this)); } }; } // anonymous namespace TEST(PassManagerTest, InvalidPass) { MLIRContext context; // Create a module OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context))); // Add a single "invalid_op" operation OpBuilder builder(&module->getBodyRegion()); OperationState state(UnknownLoc::get(&context), "invalid_op"); builder.insert(Operation::create(state)); // Register a diagnostic handler to capture the diagnostic so that we can // check it later. std::unique_ptr diagnostic; context.getDiagEngine().registerHandler([&](Diagnostic &diag) { diagnostic.reset(new Diagnostic(std::move(diag))); }); // Instantiate and run our pass. PassManager pm(&context); pm.nest("invalid_op").addPass(std::make_unique()); LogicalResult result = pm.run(module.get()); EXPECT_TRUE(failed(result)); ASSERT_TRUE(diagnostic.get() != nullptr); EXPECT_EQ( diagnostic->str(), "'invalid_op' op trying to schedule a pass on an unregistered operation"); // Check that adding the pass at the top-level triggers a fatal error. ASSERT_DEATH(pm.addPass(std::make_unique()), ""); } } // end namespace