# Chapter 6: Lowering to LLVM and Code Generation

[TOC]

In the [previous chapter](Ch-5.md), we introduced the
[dialect conversion](../../DialectConversion.md) framework and partially lowered
many of the `Toy` operations to affine loop nests for optimization. In this
chapter, we will finally lower to LLVM for code generation.

## Lowering to LLVM

For this lowering, we will again use the dialect conversion framework to perform
the heavy lifting. However, this time, we will be performing a full conversion
to the [LLVM dialect](../../Dialects/LLVM.md). Thankfully, we have already
lowered all but one of the `toy` operations, the last one being `toy.print`.
Before going over the conversion to LLVM, let's lower the `toy.print` operation.
We will lower this operation to a non-affine loop nest that invokes `printf` for
each element. Note that, because the dialect conversion framework supports
[transitive lowering](../../../getting_started/Glossary.md#transitive-lowering), we don't need to
directly emit operations in the LLVM dialect. By transitive lowering, we mean
that the conversion framework may apply multiple patterns to fully legalize an
operation. In this example, we are generating a structured loop nest instead of
the branch form in the LLVM dialect. As long as we then have a lowering from the
loop operations to LLVM, the lowering will still succeed.
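
To make transitive lowering concrete, here is a minimal, self-contained C++ model of the idea. It is not MLIR's API. Operations are plain strings, and a "pattern" rewrites one op into a list of ops. The driver keeps applying patterns until every op is legal. The op names used below (`loop.for`, `std.br`, `std.cond_br`, `llvm.br`, ...) are illustrative, and `Patterns`, `isLegal`, `allLegal`, and `legalize` are hypothetical names.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical model of transitive lowering; this is NOT MLIR's API.
// An operation is just its name, and a "pattern" rewrites one operation
// into a sequence of (possibly still illegal) operations.
using Patterns = std::map<std::string, std::vector<std::string>>;

// In this model, an op is legal if it belongs to the LLVM dialect.
bool isLegal(const std::string &op) { return op.rfind("llvm.", 0) == 0; }

bool allLegal(const std::vector<std::string> &ops) {
  for (const std::string &op : ops)
    if (!isLegal(op))
      return false;
  return true;
}

// Keep applying patterns until every op is legal. Each round rewrites every
// illegal op once, so one op may pass through several patterns (for example
// loop -> std -> llvm). Fails if an illegal op has no pattern, or if
// `maxRounds` is reached (a guard against patterns that never converge).
bool legalize(std::vector<std::string> &ops, const Patterns &patterns,
              int maxRounds = 16) {
  assert(maxRounds >= 0 && "round limit must be non-negative");
  for (int round = 0; round < maxRounds && !allLegal(ops); ++round) {
    std::vector<std::string> next;
    for (const std::string &op : ops) {
      if (isLegal(op)) {
        next.push_back(op);
        continue;
      }
      auto it = patterns.find(op);
      if (it == patterns.end())
        return false; // nothing knows how to legalize this op
      next.insert(next.end(), it->second.begin(), it->second.end());
    }
    ops = std::move(next);
  }
  return allLegal(ops);
}
```

The example patterns are `toy.print` to `{loop.for, llvm.call}`, `loop.for` to `{std.br, std.cond_br}`, and each `std` branch to its `llvm` counterpart. With these patterns, `{"toy.print"}` ends as `{"llvm.br", "llvm.cond_br", "llvm.call"}` after three rounds, even though no single pattern goes directly from `toy` to the LLVM branch form.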

During lowering, we can get, or build, the declaration for `printf` as follows:

```c++
/// Return a symbol reference to the printf function, inserting it into the
/// module if necessary.
static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
                                           ModuleOp module,
                                           LLVM::LLVMDialect *llvmDialect) {
  auto *context = module.getContext();
  if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
    return SymbolRefAttr::get("printf", context);

  // Create a function declaration for printf, the signature is:
  //   * `i32 (i8*, ...)`
  auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
  auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
  auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmI32Ty, llvmI8PtrTy,
                                                  /*isVarArg=*/true);

  // Insert the printf function into the body of the parent module.
  PatternRewriter::InsertionGuard insertGuard(rewriter);
  rewriter.setInsertionPointToStart(module.getBody());
  rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", llvmFnType);
  return SymbolRefAttr::get("printf", context);
}
```
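
The loop nest that `toy.print` lowers to behaves roughly like the plain C++ below. This is a sketch, not generated code. `formatTensor` is a hypothetical name, and the data is assumed to be a contiguous row-major 2-D buffer. The function appends to a `std::string` through `snprintf` instead of calling `printf`, so that its output can be checked. We also assume a per-element format string of `"%f "` followed by a newline per row. That choice is consistent with the 4-byte `@frmt_spec` global and the `putchar(10)` calls in the LLVM IR listings later in this chapter.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <string>

// A plain C++ analogue (hypothetical, not generated code) of the loop nest
// that toy.print lowers to: one "%f " per element and a newline after each
// row. Output goes to a std::string instead of stdout so it can be checked.
std::string formatTensor(const double *data, int64_t rows, int64_t cols) {
  assert(rows >= 0 && cols >= 0);
  assert((data != nullptr || rows * cols == 0) && "missing data");
  std::string out;
  // 512 bytes comfortably fits any double printed with "%f " (about 320 chars max).
  char buf[512];
  for (int64_t i = 0; i < rows; ++i) {   // outer loop: rows
    for (int64_t j = 0; j < cols; ++j) { // inner loop: columns
      std::snprintf(buf, sizeof(buf), "%f ", data[i * cols + j]); // row-major
      out += buf;
    }
    out += '\n'; // newline after each row
  }
  return out;
}
```

The JIT example at the end of this chapter uses a 2x2 input. For that input, `formatTensor` yields `"1.000000 2.000000 \n3.000000 4.000000 \n"`. The trailing spaces are simply invisible in terminal output.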

Now that the lowering for the `toy.print` operation has been defined, we can
specify the components necessary for the lowering. These are largely the same as
the components defined in the [previous chapter](Ch-5.md).

### Conversion Target

For this conversion, aside from the top-level module, we will be lowering
everything to the LLVM dialect.

```c++
  mlir::ConversionTarget target(getContext());
  target.addLegalDialect<mlir::LLVM::LLVMDialect>();
  target.addLegalOp<mlir::ModuleOp, mlir::ModuleTerminatorOp>();
```
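
The two calls above set legality at different granularities. `addLegalDialect` makes every operation of a dialect legal, while `addLegalOp` marks individual operations as legal. The self-contained C++ sketch below models that rule with a hypothetical `ToyTarget` struct, which is not MLIR's `ConversionTarget`. It assumes an op's dialect is the part of its name before the first `.`.

```cpp
#include <cassert>
#include <set>
#include <string>

// Hypothetical model of a conversion target (NOT mlir::ConversionTarget).
// An op is legal if it is listed individually (like addLegalOp) or if its
// whole dialect is legal (like addLegalDialect).
struct ToyTarget {
  std::set<std::string> legalDialects; // e.g. {"llvm"}
  std::set<std::string> legalOps;      // e.g. {"module", "module_terminator"}

  bool isLegal(const std::string &opName) const {
    assert(!opName.empty());
    if (legalOps.count(opName))
      return true;
    // Assumption: the dialect is the part of the name before the first '.'.
    auto dot = opName.find('.');
    return dot != std::string::npos &&
           legalDialects.count(opName.substr(0, dot)) > 0;
  }
};
```

Under this model, `llvm.add` and `module` are legal, while `toy.print` and `affine.for` are not. That is why the conversion patterns must rewrite the latter.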

### Type Converter

This lowering will also transform the MemRef types currently being operated on
into a representation in LLVM. To perform this conversion, we use a
`TypeConverter` as part of the lowering. This converter specifies how one type
maps to another. This is necessary now that we are performing more complicated
lowerings involving block arguments. Given that we don't have any
Toy-dialect-specific types that need to be lowered, the default converter is
enough for our use case.

```c++
  LLVMTypeConverter typeConverter(&getContext());
```
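
The lowered code later in this chapter shows that a ranked MemRef becomes a descriptor struct, `{ double*, i64, [2 x i64], [2 x i64] }`, and not a bare pointer. The C++ struct below is a sketch that mirrors that layout. The field names and the `makeRowMajor` helper are our own hypothetical labels. Reading the fields as data pointer, offset, sizes, and strides is our interpretation of the listing; this chapter does not spell it out.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Sketch of a 2-D memref descriptor with the same layout as the LLVM struct
// `{ double*, i64, [2 x i64], [2 x i64] }`. Field names are our own labels.
struct MemRefDescriptor2D {
  double *data;                   // pointer to the element buffer (field 0)
  int64_t offset;                 // index of the first element in `data`
  std::array<int64_t, 2> sizes;   // extent of each dimension
  std::array<int64_t, 2> strides; // elements to skip per step in each dimension
};

// Hypothetical helper: describe a contiguous, row-major rows x cols buffer.
MemRefDescriptor2D makeRowMajor(double *data, int64_t rows, int64_t cols) {
  assert(rows >= 0 && cols >= 0);
  // Moving one row skips `cols` elements; moving one column skips 1.
  return MemRefDescriptor2D{data, /*offset=*/0, {rows, cols}, {cols, 1}};
}
```

For the 3x2 result of the working example, `makeRowMajor(buf, 3, 2)` gives strides `{2, 1}`. These are the same constants that appear in the index arithmetic of the lowered loop body.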

### Conversion Patterns

Now that the conversion target has been defined, we need to provide the patterns
used for lowering. At this point in the compilation process, we have a
combination of `toy`, `affine`, and `std` operations. Luckily, the `std` and
`affine` dialects already provide the set of patterns needed to transform them
into the LLVM dialect. The `loop` operations created by our `toy.print` lowering
are handled by the loop-to-std patterns. These patterns allow for lowering the
IR in multiple stages by relying on
[transitive lowering](../../../getting_started/Glossary.md#transitive-lowering).

```c++
  mlir::OwningRewritePatternList patterns;
  mlir::populateAffineToStdConversionPatterns(patterns, &getContext());
  mlir::populateLoopToStdConversionPatterns(patterns, &getContext());
  mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);

  // The only remaining operation to lower from the `toy` dialect is the
  // PrintOp.
  patterns.insert<PrintOpLowering>(&getContext());
```

### Full Lowering

We want to completely lower to LLVM, so we use a full conversion
(`applyFullConversion`). This ensures that only legal operations will remain
after the conversion.

```c++
  mlir::ModuleOp module = getOperation();
  if (mlir::failed(mlir::applyFullConversion(module, target, patterns)))
    signalPassFailure();
```
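
A full conversion and the partial conversion used in the previous chapter differ in which leftover operations they tolerate. The sketch below is a simplified, hypothetical model: the `Legality` enum and both functions are ours, not MLIR's. A full conversion succeeds only if every remaining operation is legal. A partial conversion also accepts operations the target says nothing about, and fails only on explicitly illegal ones.

```cpp
#include <cassert>
#include <vector>

// Simplified, hypothetical model of the success criteria (not MLIR code).
// Each operation left after conversion is classified by the target as
// legal, explicitly illegal, or unknown (the target says nothing about it).
enum class Legality { Legal, Illegal, Unknown };

// Full conversion: succeeds only if every remaining op is legal.
bool fullConversionSucceeds(const std::vector<Legality> &remaining) {
  for (Legality l : remaining)
    if (l != Legality::Legal)
      return false;
  return true;
}

// Partial conversion: unknown ops may stay; only illegal ones cause failure.
bool partialConversionSucceeds(const std::vector<Legality> &remaining) {
  for (Legality l : remaining)
    if (l == Legality::Illegal)
      return false;
  return true;
}
```

The target above declares only the LLVM dialect and the module operations legal. Any `affine` or `std` operation that survives is therefore "unknown", which is enough to make the full conversion fail.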

Looking back at our current working example:

```mlir
func @main() {
  %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
  %2 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
  %3 = toy.mul %2, %2 : tensor<3x2xf64>
  toy.print %3 : tensor<3x2xf64>
  toy.return
}
```

We can now lower down to the LLVM dialect, which produces the following code:

```mlir
llvm.func @free(!llvm<"i8*">)
llvm.func @printf(!llvm<"i8*">, ...) -> !llvm.i32
llvm.func @malloc(!llvm.i64) -> !llvm<"i8*">
llvm.func @main() {
  %0 = llvm.mlir.constant(1.000000e+00 : f64) : !llvm.double
  %1 = llvm.mlir.constant(2.000000e+00 : f64) : !llvm.double

  ...

^bb16:
  %221 = llvm.extractvalue %25[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %222 = llvm.mlir.constant(0 : index) : !llvm.i64
  %223 = llvm.mlir.constant(2 : index) : !llvm.i64
  %224 = llvm.mul %214, %223 : !llvm.i64
  %225 = llvm.add %222, %224 : !llvm.i64
  %226 = llvm.mlir.constant(1 : index) : !llvm.i64
  %227 = llvm.mul %219, %226 : !llvm.i64
  %228 = llvm.add %225, %227 : !llvm.i64
  %229 = llvm.getelementptr %221[%228] : (!llvm<"double*">, !llvm.i64) -> !llvm<"double*">
  %230 = llvm.load %229 : !llvm<"double*">
  %231 = llvm.call @printf(%207, %230) : (!llvm<"i8*">, !llvm.double) -> !llvm.i32
  %232 = llvm.add %219, %218 : !llvm.i64
  llvm.br ^bb15(%232 : !llvm.i64)

  ...

^bb18:
  %235 = llvm.extractvalue %65[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %236 = llvm.bitcast %235 : !llvm<"double*"> to !llvm<"i8*">
  llvm.call @free(%236) : (!llvm<"i8*">) -> ()
  %237 = llvm.extractvalue %45[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %238 = llvm.bitcast %237 : !llvm<"double*"> to !llvm<"i8*">
  llvm.call @free(%238) : (!llvm<"i8*">) -> ()
  %239 = llvm.extractvalue %25[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %240 = llvm.bitcast %239 : !llvm<"double*"> to !llvm<"i8*">
  llvm.call @free(%240) : (!llvm<"i8*">) -> ()
  llvm.return
}
```
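
The arithmetic in `^bb16` is a strided address computation for the 3x2 result. Its offset (`0`) and strides (`2` and `1`) are baked in as constants. The plain C++ sketch below mirrors it. `linearIndex` and `loadElement` are hypothetical helpers, and we assume `%214` and `%219` are the outer and inner loop indices. That assumption is consistent with `%219` being incremented before branching back to `^bb15`.

```cpp
#include <cassert>
#include <cstdint>

// Mirrors the address arithmetic in ^bb16 (a sketch): element (i, j) of a
// strided 2-D memref lives at offset + i * stride0 + j * stride1.
int64_t linearIndex(int64_t i, int64_t j, int64_t offset, int64_t stride0,
                    int64_t stride1) {
  int64_t rowPart = offset + i * stride0; // %224 = mul, %225 = add
  return rowPart + j * stride1;           // %227 = mul, %228 = add
}

// Hypothetical helper for the 3x2 result: offset 0 and strides {2, 1}, as in
// the constants %222, %223 and %226 of the listing.
double loadElement(const double *data, int64_t i, int64_t j) {
  assert(data != nullptr && i >= 0 && i < 3 && j >= 0 && j < 2);
  // getelementptr + load
  return data[linearIndex(i, j, /*offset=*/0, /*stride0=*/2, /*stride1=*/1)];
}
```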

See [Conversion to the LLVM IR Dialect](../../ConversionToLLVMDialect.md) for
more in-depth details on lowering to the LLVM dialect.

## CodeGen: Getting Out of MLIR

At this point we are right at the cusp of code generation. We can generate code
in the LLVM dialect, so now we just need to export to LLVM IR and set up a JIT to
run it.

### Emitting LLVM IR

Now that our module consists only of operations in the LLVM dialect, we can
export it to LLVM IR. To do this programmatically, we can invoke the following
utility:

```c++
  llvm::LLVMContext llvmContext;
  std::unique_ptr<llvm::Module> llvmModule =
      mlir::translateModuleToLLVMIR(module, llvmContext);
  if (!llvmModule) {
    /* ... an error was encountered ... */
  }
```

Exporting our module to LLVM IR generates:

```llvm
define void @main() {
  ...

102:
  %103 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %8, 0
  %104 = mul i64 %96, 2
  %105 = add i64 0, %104
  %106 = mul i64 %100, 1
  %107 = add i64 %105, %106
  %108 = getelementptr double, double* %103, i64 %107
  %109 = load double, double* %108
  %110 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double %109)
  %111 = add i64 %100, 1
  br label %99

  ...

115:
  %116 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %24, 0
  %117 = bitcast double* %116 to i8*
  call void @free(i8* %117)
  %118 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %16, 0
  %119 = bitcast double* %118 to i8*
  call void @free(i8* %119)
  %120 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %8, 0
  %121 = bitcast double* %120 to i8*
  call void @free(i8* %121)
  ret void
}
```

If we enable optimization on the generated LLVM IR, we can trim this down quite
a bit:

```llvm
define void @main() {
  %0 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 1.000000e+00)
  %1 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 1.600000e+01)
  %putchar = tail call i32 @putchar(i32 10)
  %2 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 4.000000e+00)
  %3 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 2.500000e+01)
  %putchar.1 = tail call i32 @putchar(i32 10)
  %4 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 9.000000e+00)
  %5 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 3.600000e+01)
  %putchar.2 = tail call i32 @putchar(i32 10)
  ret void
}
```
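
The optimizer has folded the entire computation into constants. The six `printf` calls print `1`, `16`, `4`, `25`, `9`, and `36`, with a `putchar(10)` newline after each row. These are the values of the working example. The plain C++ sketch below checks this by transposing the 2x3 constant and then squaring it element-wise, as `toy.mul %2, %2` does. The function name `transposeAndSquare` is hypothetical.

```cpp
#include <array>
#include <cassert>

// Plain C++ sketch of the working example: %2 = transpose(%0) turns the 2x3
// constant into a 3x2 matrix, and %3 = toy.mul %2, %2 squares it element-wise.
std::array<double, 6> transposeAndSquare(const std::array<double, 6> &a) {
  // `a` is 2x3 row-major; the result is 3x2 row-major.
  std::array<double, 6> t{};
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 3; ++j)
      t[j * 2 + i] = a[i * 3 + j]; // t[j][i] = a[i][j]
  for (double &x : t)
    x *= x; // element-wise product of the transpose with itself
  return t;
}
```

Calling `transposeAndSquare({1, 2, 3, 4, 5, 6})` returns `{1, 16, 4, 25, 9, 36}`. That is the same order in which the optimized IR prints the values.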

The full code listing for dumping LLVM IR can be found in
`examples/toy/Ch6/toyc.cpp` in the `dumpLLVMIR()` function:

```c++
int dumpLLVMIR(mlir::ModuleOp module) {
  // Translate the module, which contains the LLVM dialect, to LLVM IR. Use a
  // fresh LLVM IR context. (Note that LLVM is not thread-safe and any
  // concurrent use of a context requires external locking.)
  llvm::LLVMContext llvmContext;
  auto llvmModule = mlir::translateModuleToLLVMIR(module, llvmContext);
  if (!llvmModule) {
    llvm::errs() << "Failed to emit LLVM IR\n";
    return -1;
  }

  // Initialize LLVM targets.
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();
  mlir::ExecutionEngine::setupTargetTriple(llvmModule.get());

  // Optionally run an optimization pipeline over the llvm module.
  // `EnableOpt` is a command-line option defined elsewhere in the file.
  auto optPipeline = mlir::makeOptimizingTransformer(
      /*optLevel=*/EnableOpt ? 3 : 0, /*sizeLevel=*/0,
      /*targetMachine=*/nullptr);
  if (auto err = optPipeline(llvmModule.get())) {
    llvm::errs() << "Failed to optimize LLVM IR " << err << "\n";
    return -1;
  }
  llvm::errs() << *llvmModule << "\n";
  return 0;
}
```

### Setting up a JIT

Setting up a JIT to run the module containing the LLVM dialect can be done using
the `mlir::ExecutionEngine` infrastructure. This is a utility wrapper around
LLVM's JIT that takes an MLIR module in the LLVM dialect as input. The full code
listing for setting up the JIT can be found in `examples/toy/Ch6/toyc.cpp` in
the `runJit()` function:

```c++
int runJit(mlir::ModuleOp module) {
  // Initialize LLVM targets.
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();

  // An optimization pipeline to use within the execution engine.
  auto optPipeline = mlir::makeOptimizingTransformer(
      /*optLevel=*/EnableOpt ? 3 : 0, /*sizeLevel=*/0,
      /*targetMachine=*/nullptr);

  // Create an MLIR execution engine. The execution engine eagerly JIT-compiles
  // the module.
  auto maybeEngine = mlir::ExecutionEngine::create(module, optPipeline);
  assert(maybeEngine && "failed to construct an execution engine");
  auto &engine = maybeEngine.get();

  // Invoke the JIT-compiled function. `invoke` returns an llvm::Error, which
  // evaluates to true when the invocation failed.
  auto invocationResult = engine->invoke("main");
  if (invocationResult) {
    llvm::errs() << "JIT invocation failed\n";
    return -1;
  }

  return 0;
}
```

You can play around with it from the build directory:

```shell
$ echo 'def main() { print([[1, 2], [3, 4]]); }' | ./bin/toyc-ch6 -emit=jit
1.000000 2.000000
3.000000 4.000000
```

You can also play with `-emit=mlir`, `-emit=mlir-affine`, `-emit=mlir-llvm`, and
`-emit=llvm` to compare the various levels of IR involved. Also try options like
[`--print-ir-after-all`](../../PassManagement.md#ir-printing) to track the
evolution of the IR throughout the pipeline.

The example code used throughout this section can be found in
`test/Examples/Toy/Ch6/llvm-lowering.mlir`.

So far, we have worked with primitive data types. In the
[next chapter](Ch-7.md), we will add a composite `struct` type.