//===- Pass.cpp - MLIR pass registration generator ------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // PassCAPIGen uses the description of passes to generate C API for the passes. // //===----------------------------------------------------------------------===// #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/Pass.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" using namespace mlir; using namespace mlir::tblgen; static llvm::cl::OptionCategory passGenCat("Options for -gen-pass-capi-header and -gen-pass-capi-impl"); static llvm::cl::opt groupName("prefix", llvm::cl::desc("The prefix to use for this group of passes. The " "form will be mlirCreate, the " "prefix can avoid conflicts across libraries."), llvm::cl::cat(passGenCat)); const char *const passDecl = R"( /* Create {0} Pass. */ MLIR_CAPI_EXPORTED MlirPass mlirCreate{0}{1}(); MLIR_CAPI_EXPORTED void mlirRegister{0}{1}(); )"; const char *const fileHeader = R"( /* Autogenerated by mlir-tblgen; don't manually edit. */ #include "mlir-c/Pass.h" #ifdef __cplusplus extern "C" { #endif )"; const char *const fileFooter = R"( #ifdef __cplusplus } #endif )"; /// Emit TODO static bool emitCAPIHeader(const llvm::RecordKeeper &records, raw_ostream &os) { os << fileHeader; os << "// Registration for the entire group\n"; os << "MLIR_CAPI_EXPORTED void mlirRegister" << groupName << "Passes();\n\n"; for (const auto *def : records.getAllDerivedDefinitions("PassBase")) { Pass pass(def); StringRef defName = pass.getDef()->getName(); os << llvm::formatv(passDecl, groupName, defName); } os << fileFooter; return false; } const char *const passCreateDef = R"( MlirPass mlirCreate{0}{1}() { return wrap({2}.release()); } void mlirRegister{0}{1}() { register{1}Pass(); } )"; /// {0}: The name of the pass group. const char *const passGroupRegistrationCode = R"( //===----------------------------------------------------------------------===// // {0} Group Registration //===----------------------------------------------------------------------===// void mlirRegister{0}Passes() {{ register{0}Passes(); } )"; static bool emitCAPIImpl(const llvm::RecordKeeper &records, raw_ostream &os) { os << "/* Autogenerated by mlir-tblgen; don't manually edit. */"; os << llvm::formatv(passGroupRegistrationCode, groupName); for (const auto *def : records.getAllDerivedDefinitions("PassBase")) { Pass pass(def); StringRef defName = pass.getDef()->getName(); os << llvm::formatv(passCreateDef, groupName, defName, pass.getConstructor()); } return false; } static mlir::GenRegistration genCAPIHeader("gen-pass-capi-header", "Generate pass C API header", &emitCAPIHeader); static mlir::GenRegistration genCAPIImpl("gen-pass-capi-impl", "Generate pass C API implementation", &emitCAPIImpl);