You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
125 lines
4.1 KiB
125 lines
4.1 KiB
7 months ago
|
//===- PassManagerTest.cpp - PassManager unit tests -----------------------===//
|
||
|
//
|
||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||
|
//
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
#include "mlir/Pass/PassManager.h"
|
||
|
#include "mlir/IR/Builders.h"
|
||
|
#include "mlir/IR/BuiltinOps.h"
|
||
|
#include "mlir/Pass/Pass.h"
|
||
|
#include "gtest/gtest.h"
|
||
|
|
||
|
using namespace mlir;
|
||
|
using namespace mlir::detail;
|
||
|
|
||
|
namespace {
|
||
|
/// Analysis that operates on any operation.
|
||
|
struct GenericAnalysis {
|
||
|
GenericAnalysis(Operation *op) : isFunc(isa<FuncOp>(op)) {}
|
||
|
const bool isFunc;
|
||
|
};
|
||
|
|
||
|
/// Analysis that operates on a specific operation.
|
||
|
struct OpSpecificAnalysis {
|
||
|
OpSpecificAnalysis(FuncOp op) : isSecret(op.getName() == "secret") {}
|
||
|
const bool isSecret;
|
||
|
};
|
||
|
|
||
|
/// Simple pass to annotate a FuncOp with the results of analysis.
|
||
|
/// Note: not using FunctionPass as it skip external functions.
|
||
|
struct AnnotateFunctionPass
|
||
|
: public PassWrapper<AnnotateFunctionPass, OperationPass<FuncOp>> {
|
||
|
void runOnOperation() override {
|
||
|
FuncOp op = getOperation();
|
||
|
Builder builder(op.getParentOfType<ModuleOp>());
|
||
|
|
||
|
auto &ga = getAnalysis<GenericAnalysis>();
|
||
|
auto &sa = getAnalysis<OpSpecificAnalysis>();
|
||
|
|
||
|
op.setAttr("isFunc", builder.getBoolAttr(ga.isFunc));
|
||
|
op.setAttr("isSecret", builder.getBoolAttr(sa.isSecret));
|
||
|
}
|
||
|
};
|
||
|
|
||
|
TEST(PassManagerTest, OpSpecificAnalysis) {
|
||
|
MLIRContext context;
|
||
|
Builder builder(&context);
|
||
|
|
||
|
// Create a module with 2 functions.
|
||
|
OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context)));
|
||
|
for (StringRef name : {"secret", "not_secret"}) {
|
||
|
FuncOp func =
|
||
|
FuncOp::create(builder.getUnknownLoc(), name,
|
||
|
builder.getFunctionType(llvm::None, llvm::None));
|
||
|
func.setPrivate();
|
||
|
module->push_back(func);
|
||
|
}
|
||
|
|
||
|
// Instantiate and run our pass.
|
||
|
PassManager pm(&context);
|
||
|
pm.addNestedPass<FuncOp>(std::make_unique<AnnotateFunctionPass>());
|
||
|
LogicalResult result = pm.run(module.get());
|
||
|
EXPECT_TRUE(succeeded(result));
|
||
|
|
||
|
// Verify that each function got annotated with expected attributes.
|
||
|
for (FuncOp func : module->getOps<FuncOp>()) {
|
||
|
ASSERT_TRUE(func.getAttr("isFunc").isa<BoolAttr>());
|
||
|
EXPECT_TRUE(func.getAttr("isFunc").cast<BoolAttr>().getValue());
|
||
|
|
||
|
bool isSecret = func.getName() == "secret";
|
||
|
ASSERT_TRUE(func.getAttr("isSecret").isa<BoolAttr>());
|
||
|
EXPECT_EQ(func.getAttr("isSecret").cast<BoolAttr>().getValue(), isSecret);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
namespace {
|
||
|
struct InvalidPass : Pass {
|
||
|
InvalidPass() : Pass(TypeID::get<InvalidPass>(), StringRef("invalid_op")) {}
|
||
|
StringRef getName() const override { return "Invalid Pass"; }
|
||
|
void runOnOperation() override {}
|
||
|
|
||
|
/// A clone method to create a copy of this pass.
|
||
|
std::unique_ptr<Pass> clonePass() const override {
|
||
|
return std::make_unique<InvalidPass>(
|
||
|
*static_cast<const InvalidPass *>(this));
|
||
|
}
|
||
|
};
|
||
|
} // anonymous namespace
|
||
|
|
||
|
TEST(PassManagerTest, InvalidPass) {
|
||
|
MLIRContext context;
|
||
|
|
||
|
// Create a module
|
||
|
OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context)));
|
||
|
|
||
|
// Add a single "invalid_op" operation
|
||
|
OpBuilder builder(&module->getBodyRegion());
|
||
|
OperationState state(UnknownLoc::get(&context), "invalid_op");
|
||
|
builder.insert(Operation::create(state));
|
||
|
|
||
|
// Register a diagnostic handler to capture the diagnostic so that we can
|
||
|
// check it later.
|
||
|
std::unique_ptr<Diagnostic> diagnostic;
|
||
|
context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
|
||
|
diagnostic.reset(new Diagnostic(std::move(diag)));
|
||
|
});
|
||
|
|
||
|
// Instantiate and run our pass.
|
||
|
PassManager pm(&context);
|
||
|
pm.nest("invalid_op").addPass(std::make_unique<InvalidPass>());
|
||
|
LogicalResult result = pm.run(module.get());
|
||
|
EXPECT_TRUE(failed(result));
|
||
|
ASSERT_TRUE(diagnostic.get() != nullptr);
|
||
|
EXPECT_EQ(
|
||
|
diagnostic->str(),
|
||
|
"'invalid_op' op trying to schedule a pass on an unregistered operation");
|
||
|
|
||
|
// Check that adding the pass at the top-level triggers a fatal error.
|
||
|
ASSERT_DEATH(pm.addPass(std::make_unique<InvalidPass>()), "");
|
||
|
}
|
||
|
|
||
|
} // end namespace
|