You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
298 lines
11 KiB
298 lines
11 KiB
//===- LLVMIRConversionGen.cpp - MLIR LLVM IR builder generator -----------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file uses tablegen definitions of the LLVM IR Dialect operations to
|
|
// generate the code building the LLVM IR from it.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Support/LogicalResult.h"
|
|
#include "mlir/TableGen/Attribute.h"
|
|
#include "mlir/TableGen/GenInfo.h"
|
|
#include "mlir/TableGen/Operator.h"
|
|
|
|
#include "llvm/ADT/StringExtras.h"
|
|
#include "llvm/ADT/Twine.h"
|
|
#include "llvm/Support/FormatVariadic.h"
|
|
#include "llvm/Support/raw_ostream.h"
|
|
#include "llvm/TableGen/Record.h"
|
|
#include "llvm/TableGen/TableGenBackend.h"
|
|
|
|
using namespace llvm;
|
|
using namespace mlir;
|
|
|
|
static bool emitError(const Twine &message) {
|
|
llvm::errs() << message << "\n";
|
|
return false;
|
|
}
|
|
|
|
namespace {
|
|
// Helper structure to return a position of the substring in a string.
|
|
struct StringLoc {
|
|
size_t pos;
|
|
size_t length;
|
|
|
|
// Take a substring identified by this location in the given string.
|
|
StringRef in(StringRef str) const { return str.substr(pos, length); }
|
|
|
|
// A location is invalid if its position is outside the string.
|
|
explicit operator bool() { return pos != std::string::npos; }
|
|
};
|
|
} // namespace
|
|
|
|
// Find the next TableGen variable in the given pattern. These variables start
|
|
// with a `$` character and can contain alphanumeric characters or underscores.
|
|
// Return the position of the variable in the pattern and its length, including
|
|
// the `$` character. The escape syntax `$$` is also detected and returned.
|
|
static StringLoc findNextVariable(StringRef str) {
|
|
size_t startPos = str.find('$');
|
|
if (startPos == std::string::npos)
|
|
return {startPos, 0};
|
|
|
|
// If we see "$$", return immediately.
|
|
if (startPos != str.size() - 1 && str[startPos + 1] == '$')
|
|
return {startPos, 2};
|
|
|
|
// Otherwise, the symbol spans until the first character that is not
|
|
// alphanumeric or '_'.
|
|
size_t endPos = str.find_if_not([](char c) { return isAlnum(c) || c == '_'; },
|
|
startPos + 1);
|
|
if (endPos == std::string::npos)
|
|
endPos = str.size();
|
|
|
|
return {startPos, endPos - startPos};
|
|
}
|
|
|
|
// Check if `name` is the name of the variadic operand of `op`. The variadic
|
|
// operand can only appear at the last position in the list of operands.
|
|
static bool isVariadicOperandName(const tblgen::Operator &op, StringRef name) {
|
|
unsigned numOperands = op.getNumOperands();
|
|
if (numOperands == 0)
|
|
return false;
|
|
const auto &operand = op.getOperand(numOperands - 1);
|
|
return operand.isVariableLength() && operand.name == name;
|
|
}
|
|
|
|
// Check if `result` is a known name of a result of `op`.
|
|
static bool isResultName(const tblgen::Operator &op, StringRef name) {
|
|
for (int i = 0, e = op.getNumResults(); i < e; ++i)
|
|
if (op.getResultName(i) == name)
|
|
return true;
|
|
return false;
|
|
}
|
|
|
|
// Check if `name` is a known name of an attribute of `op`.
|
|
static bool isAttributeName(const tblgen::Operator &op, StringRef name) {
|
|
return llvm::any_of(
|
|
op.getAttributes(),
|
|
[name](const tblgen::NamedAttribute &attr) { return attr.name == name; });
|
|
}
|
|
|
|
// Check if `name` is a known name of an operand of `op`.
|
|
static bool isOperandName(const tblgen::Operator &op, StringRef name) {
|
|
for (int i = 0, e = op.getNumOperands(); i < e; ++i)
|
|
if (op.getOperand(i).name == name)
|
|
return true;
|
|
return false;
|
|
}
|
|
|
|
// Emit to `os` the operator-name driven check and the call to LLVM IRBuilder
|
|
// for one definition of an LLVM IR Dialect operation. Return true on success.
|
|
static bool emitOneBuilder(const Record &record, raw_ostream &os) {
|
|
auto op = tblgen::Operator(record);
|
|
|
|
if (!record.getValue("llvmBuilder"))
|
|
return emitError("no 'llvmBuilder' field for op " + op.getOperationName());
|
|
|
|
// Return early if there is no builder specified.
|
|
auto builderStrRef = record.getValueAsString("llvmBuilder");
|
|
if (builderStrRef.empty())
|
|
return true;
|
|
|
|
// Progressively create the builder string by replacing $-variables with
|
|
// value lookups. Keep only the not-yet-traversed part of the builder pattern
|
|
// to avoid re-traversing the string multiple times.
|
|
std::string builder;
|
|
llvm::raw_string_ostream bs(builder);
|
|
while (auto loc = findNextVariable(builderStrRef)) {
|
|
auto name = loc.in(builderStrRef).drop_front();
|
|
// First, insert the non-matched part as is.
|
|
bs << builderStrRef.substr(0, loc.pos);
|
|
// Then, rewrite the name based on its kind.
|
|
bool isVariadicOperand = isVariadicOperandName(op, name);
|
|
if (isOperandName(op, name)) {
|
|
auto result = isVariadicOperand
|
|
? formatv("lookupValues(op.{0}())", name)
|
|
: formatv("valueMapping.lookup(op.{0}())", name);
|
|
bs << result;
|
|
} else if (isAttributeName(op, name)) {
|
|
bs << formatv("op.{0}()", name);
|
|
} else if (isResultName(op, name)) {
|
|
bs << formatv("valueMapping[op.{0}()]", name);
|
|
} else if (name == "_resultType") {
|
|
bs << "convertType(op.getResult().getType().cast<LLVM::LLVMType>())";
|
|
} else if (name == "_hasResult") {
|
|
bs << "opInst.getNumResults() == 1";
|
|
} else if (name == "_location") {
|
|
bs << "opInst.getLoc()";
|
|
} else if (name == "_numOperands") {
|
|
bs << "opInst.getNumOperands()";
|
|
} else if (name == "$") {
|
|
bs << '$';
|
|
} else {
|
|
return emitError(name + " is neither an argument nor a result of " +
|
|
op.getOperationName());
|
|
}
|
|
// Finally, only keep the untraversed part of the string.
|
|
builderStrRef = builderStrRef.substr(loc.pos + loc.length);
|
|
}
|
|
|
|
// Output the check and the rewritten builder string.
|
|
os << "if (auto op = dyn_cast<" << op.getQualCppClassName()
|
|
<< ">(opInst)) {\n";
|
|
os << bs.str() << builderStrRef << "\n";
|
|
os << " return success();\n";
|
|
os << "}\n";
|
|
|
|
return true;
|
|
}
|
|
|
|
// Emit all builders. Returns false on success because of the generator
|
|
// registration requirements.
|
|
static bool emitBuilders(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
|
for (const auto *def : recordKeeper.getAllDerivedDefinitions("LLVM_OpBase")) {
|
|
if (!emitOneBuilder(*def, os))
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
namespace {
|
|
// Wrapper class around a Tablegen definition of an LLVM enum attribute case.
|
|
class LLVMEnumAttrCase : public tblgen::EnumAttrCase {
|
|
public:
|
|
using tblgen::EnumAttrCase::EnumAttrCase;
|
|
|
|
// Constructs a case from a non LLVM-specific enum attribute case.
|
|
explicit LLVMEnumAttrCase(const tblgen::EnumAttrCase &other)
|
|
: tblgen::EnumAttrCase(&other.getDef()) {}
|
|
|
|
// Returns the C++ enumerant for the LLVM API.
|
|
StringRef getLLVMEnumerant() const {
|
|
return def->getValueAsString("llvmEnumerant");
|
|
}
|
|
};
|
|
|
|
// Wraper class around a Tablegen definition of an LLVM enum attribute.
|
|
class LLVMEnumAttr : public tblgen::EnumAttr {
|
|
public:
|
|
using tblgen::EnumAttr::EnumAttr;
|
|
|
|
// Returns the C++ enum name for the LLVM API.
|
|
StringRef getLLVMClassName() const {
|
|
return def->getValueAsString("llvmClassName");
|
|
}
|
|
|
|
// Returns all associated cases viewed as LLVM-specific enum cases.
|
|
std::vector<LLVMEnumAttrCase> getAllCases() const {
|
|
std::vector<LLVMEnumAttrCase> cases;
|
|
|
|
for (auto &c : tblgen::EnumAttr::getAllCases())
|
|
cases.push_back(LLVMEnumAttrCase(c));
|
|
|
|
return cases;
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
// Emits conversion function "LLVMClass convertEnumToLLVM(Enum)" and containing
|
|
// switch-based logic to convert from the MLIR LLVM dialect enum attribute case
|
|
// (Enum) to the corresponding LLVM API enumerant
|
|
static void emitOneEnumToConversion(const llvm::Record *record,
|
|
raw_ostream &os) {
|
|
LLVMEnumAttr enumAttr(record);
|
|
StringRef llvmClass = enumAttr.getLLVMClassName();
|
|
StringRef cppClassName = enumAttr.getEnumClassName();
|
|
StringRef cppNamespace = enumAttr.getCppNamespace();
|
|
|
|
// Emit the function converting the enum attribute to its LLVM counterpart.
|
|
os << formatv("static {0} convert{1}ToLLVM({2}::{1} value) {{\n", llvmClass,
|
|
cppClassName, cppNamespace);
|
|
os << " switch (value) {\n";
|
|
|
|
for (const auto &enumerant : enumAttr.getAllCases()) {
|
|
StringRef llvmEnumerant = enumerant.getLLVMEnumerant();
|
|
StringRef cppEnumerant = enumerant.getSymbol();
|
|
os << formatv(" case {0}::{1}::{2}:\n", cppNamespace, cppClassName,
|
|
cppEnumerant);
|
|
os << formatv(" return {0}::{1};\n", llvmClass, llvmEnumerant);
|
|
}
|
|
|
|
os << " }\n";
|
|
os << formatv(" llvm_unreachable(\"unknown {0} type\");\n",
|
|
enumAttr.getEnumClassName());
|
|
os << "}\n\n";
|
|
}
|
|
|
|
// Emits conversion function "Enum convertEnumFromLLVM(LLVMClass)" and
|
|
// containing switch-based logic to convert from the LLVM API enumerant to MLIR
|
|
// LLVM dialect enum attribute (Enum).
|
|
static void emitOneEnumFromConversion(const llvm::Record *record,
|
|
raw_ostream &os) {
|
|
LLVMEnumAttr enumAttr(record);
|
|
StringRef llvmClass = enumAttr.getLLVMClassName();
|
|
StringRef cppClassName = enumAttr.getEnumClassName();
|
|
StringRef cppNamespace = enumAttr.getCppNamespace();
|
|
|
|
// Emit the function converting the enum attribute from its LLVM counterpart.
|
|
os << formatv("inline {0}::{1} convert{1}FromLLVM({2} value) {{\n",
|
|
cppNamespace, cppClassName, llvmClass);
|
|
os << " switch (value) {\n";
|
|
|
|
for (const auto &enumerant : enumAttr.getAllCases()) {
|
|
StringRef llvmEnumerant = enumerant.getLLVMEnumerant();
|
|
StringRef cppEnumerant = enumerant.getSymbol();
|
|
os << formatv(" case {0}::{1}:\n", llvmClass, llvmEnumerant);
|
|
os << formatv(" return {0}::{1}::{2};\n", cppNamespace, cppClassName,
|
|
cppEnumerant);
|
|
}
|
|
|
|
os << " }\n";
|
|
os << formatv(" llvm_unreachable(\"unknown {0} type\");",
|
|
enumAttr.getLLVMClassName());
|
|
os << "}\n\n";
|
|
}
|
|
|
|
// Emits conversion functions between MLIR enum attribute case and corresponding
|
|
// LLVM API enumerants for all registered LLVM dialect enum attributes.
|
|
template <bool ConvertTo>
|
|
static bool emitEnumConversionDefs(const RecordKeeper &recordKeeper,
|
|
raw_ostream &os) {
|
|
for (const auto *def : recordKeeper.getAllDerivedDefinitions("LLVM_EnumAttr"))
|
|
if (ConvertTo)
|
|
emitOneEnumToConversion(def, os);
|
|
else
|
|
emitOneEnumFromConversion(def, os);
|
|
|
|
return false;
|
|
}
|
|
|
|
// Registers the generators defined in this file, each under the given option
// name and description.

// LLVM IR builders for all LLVM dialect operations.
static mlir::GenRegistration genLLVMIRConversions{
    "gen-llvmir-conversions", "Generate LLVM IR conversions", emitBuilders};

// Enum attribute -> LLVM API enumerant conversion functions.
static mlir::GenRegistration genEnumToLLVMConversion{
    "gen-enum-to-llvmir-conversions",
    "Generate conversions of EnumAttrs to LLVM IR",
    emitEnumConversionDefs</*ConvertTo=*/true>};

// LLVM API enumerant -> enum attribute conversion functions.
static mlir::GenRegistration genEnumFromLLVMConversion{
    "gen-enum-from-llvmir-conversions",
    "Generate conversions of EnumAttrs from LLVM IR",
    emitEnumConversionDefs</*ConvertTo=*/false>};
|