//===- StructsGen.cpp - MLIR struct utility generator ---------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // StructsGen generates common utility functions for grouping attributes into a // set of structured data. // //===----------------------------------------------------------------------===// #include "mlir/TableGen/Attribute.h" #include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/Operator.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" using llvm::raw_ostream; using llvm::Record; using llvm::RecordKeeper; using llvm::StringRef; using mlir::tblgen::FmtContext; using mlir::tblgen::StructAttr; static void emitStructClass(const Record &structDef, StringRef structName, llvm::ArrayRef fields, StringRef description, raw_ostream &os) { const char *structInfo = R"( // {0} class {1} : public ::mlir::DictionaryAttr)"; const char *structInfoEnd = R"( { public: using ::mlir::DictionaryAttr::DictionaryAttr; static bool classof(::mlir::Attribute attr); )"; os << formatv(structInfo, description, structName) << structInfoEnd; // Declares a constructor function for the tablegen structure. // TblgenStruct::get(MLIRContext context, Type1 Field1, Type2 Field2, ...); const char *getInfoDecl = " static {0} get(\n"; const char *getInfoDeclArg = " {0} {1},\n"; const char *getInfoDeclEnd = " ::mlir::MLIRContext* context);\n\n"; os << llvm::formatv(getInfoDecl, structName); for (auto field : fields) { auto name = field.getName(); auto type = field.getType(); auto storage = type.getStorageType(); os << llvm::formatv(getInfoDeclArg, storage, name); } os << getInfoDeclEnd; // Declares an accessor for the fields owned by the tablegen structure. // namespace::storage TblgenStruct::field1() const; const char *fieldInfo = R"( {0} {1}() const; )"; for (auto field : fields) { auto name = field.getName(); auto type = field.getType(); auto storage = type.getStorageType(); os << formatv(fieldInfo, storage, name); } os << "};\n\n"; } static void emitStructDecl(const Record &structDef, raw_ostream &os) { StructAttr structAttr(&structDef); StringRef structName = structAttr.getStructClassName(); StringRef cppNamespace = structAttr.getCppNamespace(); StringRef description = structAttr.getDescription(); auto fields = structAttr.getAllFields(); // Wrap in the appropriate namespace. llvm::SmallVector namespaces; llvm::SplitString(cppNamespace, namespaces, "::"); for (auto ns : namespaces) os << "namespace " << ns << " {\n"; // Emit the struct class definition emitStructClass(structDef, structName, fields, description, os); // Close the declared namespace. for (auto ns : namespaces) os << "} // namespace " << ns << "\n"; } static bool emitStructDecls(const RecordKeeper &recordKeeper, raw_ostream &os) { llvm::emitSourceFileHeader("Struct Utility Declarations", os); auto defs = recordKeeper.getAllDerivedDefinitions("StructAttr"); for (const auto *def : defs) { emitStructDecl(*def, os); } return false; } static void emitFactoryDef(llvm::StringRef structName, llvm::ArrayRef fields, raw_ostream &os) { const char *getInfoDecl = "{0} {0}::get(\n"; const char *getInfoDeclArg = " {0} {1},\n"; const char *getInfoDeclEnd = " ::mlir::MLIRContext* context) {"; os << llvm::formatv(getInfoDecl, structName); for (auto field : fields) { auto name = field.getName(); auto type = field.getType(); auto storage = type.getStorageType(); os << llvm::formatv(getInfoDeclArg, storage, name); } os << getInfoDeclEnd; const char *fieldStart = R"( ::llvm::SmallVector<::mlir::NamedAttribute, {0}> fields; )"; os << llvm::formatv(fieldStart, fields.size()); const char *getFieldInfo = R"( assert({0}); auto {0}_id = ::mlir::Identifier::get("{0}", context); fields.emplace_back({0}_id, {0}); )"; const char *getFieldInfoOptional = R"( if ({0}) { auto {0}_id = ::mlir::Identifier::get("{0}", context); fields.emplace_back({0}_id, {0}); } )"; for (auto field : fields) { if (field.getType().isOptional() || field.getType().hasDefaultValue()) os << llvm::formatv(getFieldInfoOptional, field.getName()); else os << llvm::formatv(getFieldInfo, field.getName()); } const char *getEndInfo = R"( ::mlir::Attribute dict = ::mlir::DictionaryAttr::get(fields, context); return dict.dyn_cast<{0}>(); } )"; os << llvm::formatv(getEndInfo, structName); } static void emitClassofDef(llvm::StringRef structName, llvm::ArrayRef fields, raw_ostream &os) { const char *classofInfo = R"( bool {0}::classof(::mlir::Attribute attr))"; const char *classofInfoHeader = R"( if (!attr) return false; auto derived = attr.dyn_cast<::mlir::DictionaryAttr>(); if (!derived) return false; int num_absent_attrs = 0; )"; os << llvm::formatv(classofInfo, structName) << " {"; os << llvm::formatv(classofInfoHeader); FmtContext fctx; const char *classofArgInfo = R"( auto {0} = derived.get("{0}"); if (!{0} || !({1})) return false; )"; const char *classofArgInfoOptional = R"( auto {0} = derived.get("{0}"); if (!{0}) ++num_absent_attrs; else if (!({1})) return false; )"; for (auto field : fields) { auto name = field.getName(); auto type = field.getType(); std::string condition = std::string(tgfmt(type.getConditionTemplate(), &fctx.withSelf(name))); if (type.isOptional() || type.hasDefaultValue()) os << llvm::formatv(classofArgInfoOptional, name, condition); else os << llvm::formatv(classofArgInfo, name, condition); } const char *classofEndInfo = R"( return derived.size() + num_absent_attrs == {0}; } )"; os << llvm::formatv(classofEndInfo, fields.size()); } static void emitAccessorDef(llvm::StringRef structName, llvm::ArrayRef fields, raw_ostream &os) { const char *fieldInfo = R"( {0} {2}::{1}() const { auto derived = this->cast<::mlir::DictionaryAttr>(); auto {1} = derived.get("{1}"); assert({1} && "attribute not found."); assert({1}.isa<{0}>() && "incorrect Attribute type found."); return {1}.cast<{0}>(); } )"; const char *fieldInfoOptional = R"( {0} {2}::{1}() const { auto derived = this->cast<::mlir::DictionaryAttr>(); auto {1} = derived.get("{1}"); if (!{1}) return nullptr; assert({1}.isa<{0}>() && "incorrect Attribute type found."); return {1}.cast<{0}>(); } )"; const char *fieldInfoDefaultValued = R"( {0} {2}::{1}() const { auto derived = this->cast<::mlir::DictionaryAttr>(); auto {1} = derived.get("{1}"); if (!{1}) { ::mlir::Builder builder(getContext()); return {3}; } assert({1}.isa<{0}>() && "incorrect Attribute type found."); return {1}.cast<{0}>(); } )"; FmtContext fmtCtx; fmtCtx.withBuilder("builder"); for (auto field : fields) { auto name = field.getName(); auto type = field.getType(); auto storage = type.getStorageType(); if (type.isOptional()) { os << llvm::formatv(fieldInfoOptional, storage, name, structName); } else if (type.hasDefaultValue()) { std::string defaultValue = tgfmt(type.getConstBuilderTemplate(), &fmtCtx, type.getDefaultValue()); os << llvm::formatv(fieldInfoDefaultValued, storage, name, structName, defaultValue); } else { os << llvm::formatv(fieldInfo, storage, name, structName); } } } static void emitStructDef(const Record &structDef, raw_ostream &os) { StructAttr structAttr(&structDef); StringRef cppNamespace = structAttr.getCppNamespace(); StringRef structName = structAttr.getStructClassName(); mlir::tblgen::FmtContext ctx; auto fields = structAttr.getAllFields(); llvm::SmallVector namespaces; llvm::SplitString(cppNamespace, namespaces, "::"); for (auto ns : namespaces) os << "namespace " << ns << " {\n"; emitFactoryDef(structName, fields, os); emitClassofDef(structName, fields, os); emitAccessorDef(structName, fields, os); for (auto ns : llvm::reverse(namespaces)) os << "} // namespace " << ns << "\n"; } static bool emitStructDefs(const RecordKeeper &recordKeeper, raw_ostream &os) { llvm::emitSourceFileHeader("Struct Utility Definitions", os); auto defs = recordKeeper.getAllDerivedDefinitions("StructAttr"); for (const auto *def : defs) emitStructDef(*def, os); return false; } // Registers the struct utility generator to mlir-tblgen. static mlir::GenRegistration genStructDecls("gen-struct-attr-decls", "Generate struct utility declarations", [](const RecordKeeper &records, raw_ostream &os) { return emitStructDecls(records, os); }); // Registers the struct utility generator to mlir-tblgen. static mlir::GenRegistration genStructDefs("gen-struct-attr-defs", "Generate struct utility definitions", [](const RecordKeeper &records, raw_ostream &os) { return emitStructDefs(records, os); });