//===- OpInterfacesGen.cpp - MLIR op interface utility generator ----------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // OpInterfacesGen generates definitions for operation interfaces. // //===----------------------------------------------------------------------===// #include "DocGenUtilities.h" #include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/Interfaces.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" using namespace mlir; using mlir::tblgen::Interface; using mlir::tblgen::InterfaceMethod; using mlir::tblgen::OpInterface; /// Emit a string corresponding to a C++ type, followed by a space if necessary. static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) { type = type.trim(); os << type; if (type.back() != '&' && type.back() != '*') os << " "; return os; } /// Emit the method name and argument list for the given method. If 'addThisArg' /// is true, then an argument is added to the beginning of the argument list for /// the concrete value. static void emitMethodNameAndArgs(const InterfaceMethod &method, raw_ostream &os, StringRef valueType, bool addThisArg, bool addConst) { os << method.getName() << '('; if (addThisArg) emitCPPType(valueType, os) << "tablegen_opaque_val" << (method.arg_empty() ? 
"" : ", "); llvm::interleaveComma(method.getArguments(), os, [&](const InterfaceMethod::Argument &arg) { os << arg.type << " " << arg.name; }); os << ')'; if (addConst) os << " const"; } /// Get an array of all OpInterface definitions but exclude those subclassing /// "DeclareOpInterfaceMethods". static std::vector getAllOpInterfaceDefinitions(const llvm::RecordKeeper &recordKeeper) { std::vector defs = recordKeeper.getAllDerivedDefinitions("OpInterface"); llvm::erase_if(defs, [](const llvm::Record *def) { return def->isSubClassOf("DeclareOpInterfaceMethods"); }); return defs; } namespace { /// This struct is the base generator used when processing tablegen interfaces. class InterfaceGenerator { public: bool emitInterfaceDefs(); bool emitInterfaceDecls(); bool emitInterfaceDocs(); protected: InterfaceGenerator(std::vector &&defs, raw_ostream &os) : defs(std::move(defs)), os(os) {} void emitConceptDecl(Interface &interface); void emitModelDecl(Interface &interface); void emitModelMethodsDef(Interface &interface); void emitTraitDecl(Interface &interface, StringRef interfaceName, StringRef interfaceTraitsName); void emitInterfaceDecl(Interface interface); /// The set of interface records to emit. std::vector defs; // The stream to emit to. raw_ostream &os; /// The C++ value type of the interface, e.g. Operation*. StringRef valueType; /// The C++ base interface type. StringRef interfaceBaseType; /// The name of the typename for the value template. StringRef valueTemplate; /// The format context to use for methods. tblgen::FmtContext nonStaticMethodFmt; tblgen::FmtContext traitMethodFmt; }; /// A specialized generator for attribute interfaces. 
struct AttrInterfaceGenerator : public InterfaceGenerator {
  AttrInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
      : InterfaceGenerator(records.getAllDerivedDefinitions("AttrInterface"),
                           os) {
    valueType = "::mlir::Attribute";
    interfaceBaseType = "AttributeInterface";
    valueTemplate = "ConcreteAttr";
    // In model methods the opaque `::mlir::Attribute` value is cast back to the
    // concrete attribute class; `$_attr` and `$_self` both expand to it.
    StringRef castCode = "(tablegen_opaque_val.cast<ConcreteAttr>())";
    nonStaticMethodFmt.addSubst("_attr", castCode).withSelf(castCode);
    // Attribute trait methods are emitted `const`, so `this` must be cast to a
    // pointer-to-const of the concrete class.
    traitMethodFmt.addSubst("_attr",
                            "(*static_cast<const ConcreteAttr *>(this))");
  }
};

/// A specialized generator for operation interfaces.
struct OpInterfaceGenerator : public InterfaceGenerator {
  OpInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
      : InterfaceGenerator(getAllOpInterfaceDefinitions(records), os) {
    valueType = "::mlir::Operation *";
    interfaceBaseType = "OpInterface";
    valueTemplate = "ConcreteOp";
    // In model methods the opaque `Operation *` is cast to the concrete op
    // class; `$_op` and `$_self` both expand to it.
    StringRef castCode = "(llvm::cast<ConcreteOp>(tablegen_opaque_val))";
    nonStaticMethodFmt.withOp(castCode).withSelf(castCode);
    // Op trait methods are emitted non-const, so a non-const cast is used.
    traitMethodFmt.withOp("(*static_cast<ConcreteOp *>(this))");
  }
};

/// A specialized generator for type interfaces.
struct TypeInterfaceGenerator : public InterfaceGenerator { TypeInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os) : InterfaceGenerator(records.getAllDerivedDefinitions("TypeInterface"), os) { valueType = "::mlir::Type"; interfaceBaseType = "TypeInterface"; valueTemplate = "ConcreteType"; StringRef castCode = "(tablegen_opaque_val.cast())"; nonStaticMethodFmt.addSubst("_type", castCode).withSelf(castCode); traitMethodFmt.addSubst("_type", "(*static_cast(this))"); } }; } // end anonymous namespace //===----------------------------------------------------------------------===// // GEN: Interface definitions //===----------------------------------------------------------------------===// static void emitInterfaceDef(Interface interface, StringRef valueType, raw_ostream &os) { StringRef interfaceName = interface.getName(); StringRef cppNamespace = interface.getCppNamespace(); cppNamespace.consume_front("::"); // Insert the method definitions. bool isOpInterface = isa(interface); for (auto &method : interface.getMethods()) { emitCPPType(method.getReturnType(), os); if (!cppNamespace.empty()) os << cppNamespace << "::"; os << interfaceName << "::"; emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false, /*addConst=*/!isOpInterface); // Forward to the method on the concrete operation type. os << " {\n return getImpl()->" << method.getName() << '('; if (!method.isStatic()) { os << (isOpInterface ? "getOperation()" : "*this"); os << (method.arg_empty() ? 
"" : ", "); } llvm::interleaveComma( method.getArguments(), os, [&](const InterfaceMethod::Argument &arg) { os << arg.name; }); os << ");\n }\n"; } } bool InterfaceGenerator::emitInterfaceDefs() { llvm::emitSourceFileHeader("Interface Definitions", os); for (const auto *def : defs) emitInterfaceDef(Interface(def), valueType, os); return false; } //===----------------------------------------------------------------------===// // GEN: Interface declarations //===----------------------------------------------------------------------===// void InterfaceGenerator::emitConceptDecl(Interface &interface) { os << " struct Concept {\n"; // Insert each of the pure virtual concept methods. for (auto &method : interface.getMethods()) { os << " "; emitCPPType(method.getReturnType(), os); os << "(*" << method.getName() << ")("; if (!method.isStatic()) emitCPPType(valueType, os) << (method.arg_empty() ? "" : ", "); llvm::interleaveComma( method.getArguments(), os, [&](const InterfaceMethod::Argument &arg) { os << arg.type; }); os << ");\n"; } os << " };\n"; } void InterfaceGenerator::emitModelDecl(Interface &interface) { os << " template\n"; os << " class Model : public Concept {\n public:\n"; os << " Model() : Concept{"; llvm::interleaveComma( interface.getMethods(), os, [&](const InterfaceMethod &method) { os << method.getName(); }); os << "} {}\n\n"; // Insert each of the virtual method overrides. 
for (auto &method : interface.getMethods()) { emitCPPType(method.getReturnType(), os << " static inline "); emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/!method.isStatic(), /*addConst=*/false); os << ";\n"; } os << " };\n"; } void InterfaceGenerator::emitModelMethodsDef(Interface &interface) { for (auto &method : interface.getMethods()) { os << "template\n"; emitCPPType(method.getReturnType(), os); os << "detail::" << interface.getName() << "InterfaceTraits::Model<" << valueTemplate << ">::"; emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/!method.isStatic(), /*addConst=*/false); os << " {\n "; // Check for a provided body to the function. if (Optional body = method.getBody()) { if (method.isStatic()) os << body->trim(); else os << tblgen::tgfmt(body->trim(), &nonStaticMethodFmt); os << "\n}\n"; continue; } // Forward to the method on the concrete operation type. if (method.isStatic()) os << "return " << valueTemplate << "::"; else os << tblgen::tgfmt("return $_self.", &nonStaticMethodFmt); // Add the arguments to the call. os << method.getName() << '('; llvm::interleaveComma( method.getArguments(), os, [&](const InterfaceMethod::Argument &arg) { os << arg.name; }); os << ");\n}\n"; } } void InterfaceGenerator::emitTraitDecl(Interface &interface, StringRef interfaceName, StringRef interfaceTraitsName) { os << llvm::formatv(" template \n" " struct {0}Trait : public ::mlir::{2}<{0}," " detail::{1}>::Trait<{3}> {{\n", interfaceName, interfaceTraitsName, interfaceBaseType, valueTemplate); // Insert the default implementation for any methods. bool isOpInterface = isa(interface); for (auto &method : interface.getMethods()) { // Flag interface methods named verifyTrait. 
if (method.getName() == "verifyTrait") PrintFatalError( formatv("'verifyTrait' method cannot be specified as interface " "method for '{0}'; use the 'verify' field instead", interfaceName)); auto defaultImpl = method.getDefaultImplementation(); if (!defaultImpl) continue; os << " " << (method.isStatic() ? "static " : ""); emitCPPType(method.getReturnType(), os); emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false, /*addConst=*/!isOpInterface); os << " {\n " << tblgen::tgfmt(defaultImpl->trim(), &traitMethodFmt) << "\n }\n"; } if (auto verify = interface.getVerify()) { assert(isa(interface) && "only OpInterface supports 'verify'"); tblgen::FmtContext verifyCtx; verifyCtx.withOp("op"); os << " static ::mlir::LogicalResult verifyTrait(::mlir::Operation *op) " "{\n " << tblgen::tgfmt(verify->trim(), &verifyCtx) << "\n }\n"; } if (auto extraTraitDecls = interface.getExtraTraitClassDeclaration()) os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n"; os << " };\n"; // Emit a utility wrapper trait class. os << llvm::formatv(" template \n" " struct Trait : public {0}Trait<{1}> {{};\n", interfaceName, valueTemplate); } void InterfaceGenerator::emitInterfaceDecl(Interface interface) { llvm::SmallVector namespaces; llvm::SplitString(interface.getCppNamespace(), namespaces, "::"); for (StringRef ns : namespaces) os << "namespace " << ns << " {\n"; StringRef interfaceName = interface.getName(); auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str(); // Emit a forward declaration of the interface class so that it becomes usable // in the signature of its methods. os << "class " << interfaceName << ";\n"; // Emit the traits struct containing the concept and model declarations. os << "namespace detail {\n" << "struct " << interfaceTraitsName << " {\n"; emitConceptDecl(interface); emitModelDecl(interface); os << "};\n} // end namespace detail\n"; // Emit the main interface class declaration. 
os << llvm::formatv("class {0} : public ::mlir::{3}<{1}, detail::{2}> {\n" "public:\n" " using ::mlir::{3}<{1}, detail::{2}>::{3};\n", interfaceName, interfaceName, interfaceTraitsName, interfaceBaseType); // Emit the derived trait for the interface. emitTraitDecl(interface, interfaceName, interfaceTraitsName); // Insert the method declarations. bool isOpInterface = isa(interface); for (auto &method : interface.getMethods()) { emitCPPType(method.getReturnType(), os << " "); emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false, /*addConst=*/!isOpInterface); os << ";\n"; } // Emit any extra declarations. if (Optional extraDecls = interface.getExtraClassDeclaration()) os << *extraDecls << "\n"; os << "};\n"; emitModelMethodsDef(interface); for (StringRef ns : llvm::reverse(namespaces)) os << "} // namespace " << ns << "\n"; } bool InterfaceGenerator::emitInterfaceDecls() { llvm::emitSourceFileHeader("Interface Declarations", os); for (const auto *def : defs) emitInterfaceDecl(Interface(def)); return false; } //===----------------------------------------------------------------------===// // GEN: Interface documentation //===----------------------------------------------------------------------===// static void emitInterfaceDoc(const llvm::Record &interfaceDef, raw_ostream &os) { Interface interface(&interfaceDef); // Emit the interface name followed by the description. os << "## " << interface.getName() << " (" << interfaceDef.getName() << ")"; if (auto description = interface.getDescription()) mlir::tblgen::emitDescription(*description, os); // Emit the methods required by the interface. os << "\n### Methods:\n"; for (const auto &method : interface.getMethods()) { // Emit the method name. os << "#### `" << method.getName() << "`\n\n```c++\n"; // Emit the method signature. 
if (method.isStatic()) os << "static "; emitCPPType(method.getReturnType(), os) << method.getName() << '('; llvm::interleaveComma(method.getArguments(), os, [&](const InterfaceMethod::Argument &arg) { emitCPPType(arg.type, os) << arg.name; }); os << ");\n```\n"; // Emit the description. if (auto description = method.getDescription()) mlir::tblgen::emitDescription(*description, os); // If the body is not provided, this method must be provided by the user. if (!method.getBody()) os << "\nNOTE: This method *must* be implemented by the user.\n\n"; } } bool InterfaceGenerator::emitInterfaceDocs() { os << "\n"; os << "# " << interfaceBaseType << " definitions\n"; for (const auto *def : defs) emitInterfaceDoc(*def, os); return false; } //===----------------------------------------------------------------------===// // GEN: Interface registration hooks //===----------------------------------------------------------------------===// namespace { template struct InterfaceGenRegistration { InterfaceGenRegistration(StringRef genArg) : genDeclArg(("gen-" + genArg + "-interface-decls").str()), genDefArg(("gen-" + genArg + "-interface-defs").str()), genDocArg(("gen-" + genArg + "-interface-docs").str()), genDecls(genDeclArg, "Generate interface declarations", [](const llvm::RecordKeeper &records, raw_ostream &os) { return GeneratorT(records, os).emitInterfaceDecls(); }), genDefs(genDefArg, "Generate interface definitions", [](const llvm::RecordKeeper &records, raw_ostream &os) { return GeneratorT(records, os).emitInterfaceDefs(); }), genDocs(genDocArg, "Generate interface documentation", [](const llvm::RecordKeeper &records, raw_ostream &os) { return GeneratorT(records, os).emitInterfaceDocs(); }) {} std::string genDeclArg, genDefArg, genDocArg; mlir::GenRegistration genDecls, genDefs, genDocs; }; } // end anonymous namespace static InterfaceGenRegistration attrGen("attr"); static InterfaceGenRegistration opGen("op"); static InterfaceGenRegistration typeGen("type");