You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

297 lines
12 KiB

//===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file
// This file implements the TOSA Specification:
// https://developer.mlplatform.org/w/tosa/
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/RegionUtils.h"
using namespace mlir;
using namespace mlir::tosa;
//===----------------------------------------------------------------------===//
// Tosa dialect structs and interface includes.
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaStructs.cpp.inc"
namespace {
//===----------------------------------------------------------------------===//
// Dialect Function Inliner Interface.
//===----------------------------------------------------------------------===//
struct TosaInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
//===--------------------------------------------------------------------===//
// Analysis Hooks.
//===--------------------------------------------------------------------===//
/// All operations can be inlined by default.
bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
BlockAndValueMapping &map) const final {
return true;
}
/// All regions with If and While parent operators can be inlined.
bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
BlockAndValueMapping &map) const final {
return (isa<tosa::IfOp>(dest->getParentOp()) ||
isa<tosa::WhileOp>(dest->getParentOp()));
}
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// TOSA control flow support.
//===----------------------------------------------------------------------===//
/// Exposes the `body` region of the while op as its loop body.
Region &tosa::WhileOp::getLoopBody() {
  Region &loopRegion = body();
  return loopRegion;
}
/// Reports whether `value` is defined outside of the loop, i.e. whether the
/// `body` region is NOT an ancestor of the region `value` is defined in.
bool tosa::WhileOp::isDefinedOutsideOfLoop(Value value) {
  Region *definingRegion = value.getParentRegion();
  const bool definedInsideBody = body().isAncestor(definingRegion);
  return !definedInsideBody;
}
/// Moves every operation in `ops`, in the order given, to the position
/// immediately before this while op. Always returns success().
LogicalResult WhileOp::moveOutOfLoop(ArrayRef<mlir::Operation *> ops) {
  Operation *whileOperation = getOperation();
  // With an empty `ops` the loop body never runs, so nothing is moved.
  for (mlir::Operation *hoisted : ops)
    hoisted->moveBefore(whileOperation);
  return success();
}
//===----------------------------------------------------------------------===//
// Tosa dialect initialization.
//===----------------------------------------------------------------------===//
/// Registers the dialect's operations and interfaces.
///
/// The template argument list of addOperations<> is supplied by including
/// TosaOps.cpp.inc with GET_OP_LIST defined; the contents of that generated
/// file are not visible here, but it is expected to expand to the list of
/// operation classes. The macro definition and the include must stay inside
/// the angle brackets, each on its own line.
void TosaDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
>();
// Attach the inlining hooks defined by TosaInlinerInterface in this file.
addInterfaces<TosaInlinerInterface>();
}
/// Materializes a constant operation for `value`.
///
/// Only ElementsAttr values are supported: for those a tosa::ConstOp with
/// result type `type` is created at `loc` through `builder`. Any other
/// attribute kind yields nullptr.
Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                            Type type, Location loc) {
  auto elements = value.dyn_cast<ElementsAttr>();
  if (!elements)
    return nullptr;
  return builder.create<tosa::ConstOp>(loc, type, elements);
}
//===----------------------------------------------------------------------===//
// Operator Folders.
//===----------------------------------------------------------------------===//
/// Folds the constant op to its `value` attribute. `operands` is asserted to
/// be empty.
OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.empty() && "constant has no operands");
  Attribute folded = valueAttr();
  return folded;
}
//===----------------------------------------------------------------------===//
// TOSA Operator Verifiers.
//===----------------------------------------------------------------------===//
/// Verifies the operand / quantization consistency shared by the TOSA
/// convolution-style operations.
///
/// `T` must provide `input()`, `weight()`, `quantization_info()` and
/// `emitOpError()`. Verification fails, with a diagnostic attached to `op`
/// describing the violated rule, when:
///   * the input or the weight is not a ranked tensor;
///   * exactly one of input / weight has a quantized element type;
///   * `quantization_info` is absent although the operands are quantized, or
///     present although they are not.
/// Otherwise returns success().
template <typename T>
static LogicalResult verifyConvOp(T op) {
  // All TOSA conv ops have an input() and weight().
  auto inputType = op.input().getType().template dyn_cast<RankedTensorType>();
  auto weightType = op.weight().getType().template dyn_cast<RankedTensorType>();

  // Both operands must be ranked tensors. Report which one is not, instead of
  // failing without any diagnostic.
  if (!inputType)
    return op.emitOpError("expects a ranked tensor for input, got ")
           << op.input().getType();
  if (!weightType)
    return op.emitOpError("expects a ranked tensor for weight, got ")
           << op.weight().getType();

  const bool inputIsQuantized =
      inputType.getElementType().template isa<mlir::quant::QuantizedType>();
  const bool weightIsQuantized =
      weightType.getElementType().template isa<mlir::quant::QuantizedType>();

  // Either both must be quantized or both unquantized.
  if (inputIsQuantized != weightIsQuantized)
    return op.emitOpError(
        "expects input and weight to be either both quantized or both "
        "unquantized");

  // Quantized operands require the quantization attribute; unquantized
  // operands must not carry one.
  const bool hasQuantInfo = static_cast<bool>(op.quantization_info());
  if (inputIsQuantized && !hasQuantInfo)
    return op.emitOpError(
        "expects quantization_info when input and weight are quantized");
  if (!inputIsQuantized && hasQuantInfo)
    return op.emitOpError(
        "expects no quantization_info when input and weight are unquantized");

  return success();
}
//===----------------------------------------------------------------------===//
// TOSA Operator Quantization Builders.
//===----------------------------------------------------------------------===//
/// Builder shared by the convolution operators other than TransposeConv
/// (which has specialized output shape semantics and its own builder).
///
/// Adds `input`, `weight` and `bias` as operands and the `pad`, `stride` and
/// `dilation` attributes. When buildConvOpQuantizationAttr yields an
/// attribute, it is stored as `quantization_info` and the result type is the
/// one computed by buildConvOpResultTypeInfo from the output, input and weight
/// types; otherwise the result type is `outputType` unchanged.
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                     Type outputType, Value input, Value weight,
                                     Value bias, ArrayAttr pad,
                                     ArrayAttr stride, ArrayAttr dilation) {
  result.addOperands({input, weight, bias});
  result.addAttribute("pad", pad);
  result.addAttribute("stride", stride);
  result.addAttribute("dilation", dilation);

  auto convQuantInfo = buildConvOpQuantizationAttr(builder, input, weight);
  if (!convQuantInfo) {
    // Unquantized: the caller-provided type is used as is.
    result.addTypes(outputType);
    return;
  }

  result.addAttribute("quantization_info", convQuantInfo);
  auto quantizedResultType =
      buildConvOpResultTypeInfo(builder, outputType, input, weight);
  result.addTypes(quantizedResultType);
}
/// Builder for tosa.transpose_conv2d, which carries `out_pad` and `out_shape`
/// attributes in addition to `stride` and `dilation`.
///
/// Adds `input`, `weight` and `bias` as operands and the four attributes
/// above. When buildConvOpQuantizationAttr yields an attribute, it is stored
/// as `quantization_info` and the result type is the one computed by
/// buildConvOpResultTypeInfo; otherwise the result type is `outputType`.
static void
buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                              Type outputType, Value input, Value weight,
                              Value bias, ArrayAttr outpad, ArrayAttr stride,
                              ArrayAttr dilation, ArrayAttr outputShape) {
  result.addOperands({input, weight, bias});
  result.addAttribute("out_pad", outpad);
  result.addAttribute("stride", stride);
  result.addAttribute("dilation", dilation);
  result.addAttribute("out_shape", outputShape);

  auto convQuantInfo = ::buildConvOpQuantizationAttr(builder, input, weight);
  if (!convQuantInfo) {
    // Unquantized: the caller-provided type is used as is.
    result.addTypes(outputType);
    return;
  }

  result.addAttribute("quantization_info", convQuantInfo);
  auto quantizedResultType =
      buildConvOpResultTypeInfo(builder, outputType, input, weight);
  result.addTypes(quantizedResultType);
}
/// Builder for tosa.fully_connected, kept separate from the convolution
/// builder because it has no strides/dilation/padding attributes.
///
/// Adds `input`, `weight` and `bias` as operands. When
/// buildConvOpQuantizationAttr yields an attribute, it is stored as
/// `quantization_info` and the result type is the one computed by
/// buildConvOpResultTypeInfo; otherwise the result type is `outputType`.
static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                   Type outputType, Value input, Value weight,
                                   Value bias) {
  result.addOperands({input, weight, bias});

  auto fcQuantInfo = ::buildConvOpQuantizationAttr(builder, input, weight);
  if (!fcQuantInfo) {
    // Unquantized: the caller-provided type is used as is.
    result.addTypes(outputType);
    return;
  }

  result.addAttribute("quantization_info", fcQuantInfo);
  auto quantizedResultType =
      buildConvOpResultTypeInfo(builder, outputType, input, weight);
  result.addTypes(quantizedResultType);
}
/// Builder for tosa.matmul. The op is also intended to be generated where a
/// fully_connected op would be needed but the weight is not a constant; in
/// that case the fully_connected must be expressed using matmul.
/// TODO: Add link to the legalization document explaining this.
///
/// Adds `a` and `b` as operands. Without a quantization attribute from
/// buildMatMulOpQuantizationAttr the result type is `outputType`. With one,
/// the attribute is stored as `quantization_info` and the result becomes a
/// ranked tensor with the shape of `outputType` and an accumulator element
/// type: a 48-bit integer when the storage width of `a`'s quantized element
/// type is 16, and i32 otherwise.
static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
                                       OperationState &result, Type outputType,
                                       Value a, Value b) {
  result.addOperands({a, b});

  auto matmulQuantInfo = ::buildMatMulOpQuantizationAttr(builder, a, b);
  if (!matmulQuantInfo) {
    result.addTypes(outputType);
    return;
  }
  result.addAttribute("quantization_info", matmulQuantInfo);

  // The accumulator width is derived from the storage width of the first
  // operand's quantized element type.
  auto lhsTensorType = a.getType().dyn_cast<RankedTensorType>();
  assert(lhsTensorType && "Input must be a ranked tensor type!");
  auto lhsQuantType = lhsTensorType.getElementType()
                          .dyn_cast<mlir::quant::UniformQuantizedType>();
  assert(lhsQuantType && "Tensor must have quantized datatype!");
  const unsigned storageBits = lhsQuantType.getStorageTypeIntegralWidth();

  auto resultTensorType = outputType.dyn_cast<RankedTensorType>();
  assert(resultTensorType && "Output must be a ranked tensor type");

  IntegerType accumulatorType =
      storageBits == 16 ? builder.getIntegerType(48) : builder.getI32Type();
  result.addTypes(
      RankedTensorType::get(resultTensorType.getShape(), accumulatorType));
}
/// Builder for tosa.avg_pool2d. It obtains its quantization attribute from
/// the same buildUnaryOpQuantizationAttr helper as the unary ops, but needs
/// its own builder because of the additional `kernel`, `stride` and `pad`
/// attributes.
///
/// Adds `input` as the operand and the three attributes above; attaches
/// `quantization_info` only when the helper yields an attribute. The result
/// type is always `outputType`.
static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder,
                                          OperationState &result,
                                          Type outputType, Value input,
                                          ArrayAttr kernel, ArrayAttr stride,
                                          ArrayAttr pad) {
  result.addOperands(input);
  result.addAttribute("kernel", kernel);
  result.addAttribute("stride", stride);
  result.addAttribute("pad", pad);

  if (auto poolQuantInfo =
          buildUnaryOpQuantizationAttr(builder, input, outputType))
    result.addAttribute("quantization_info", poolQuantInfo);

  result.addTypes(outputType);
}
/// Builder for single-parameter unary operators whose input and output are
/// related by a scale, expressed through the UnaryOpQuantizationAttr.
///
/// Adds `input` as the operand and attaches `quantization_info` only when
/// buildUnaryOpQuantizationAttr yields an attribute. The result type is
/// always `outputType`.
static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
                                      OperationState &result, Type outputType,
                                      Value input) {
  result.addOperands(input);

  if (auto unaryQuantInfo =
          buildUnaryOpQuantizationAttr(builder, input, outputType))
    result.addAttribute("quantization_info", unaryQuantInfo);

  result.addTypes(outputType);
}
/// Builder for the TOSA pad operator, which creates its own optional
/// `quantization_info` attribute so that padding values can be scaled
/// correctly.
///
/// Adds `input` and `paddings` as operands and attaches `quantization_info`
/// only when buildPadOpQuantizationAttr yields an attribute. The result type
/// is always `outputType`.
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                    Type outputType, Value input,
                                    Value paddings) {
  result.addOperands({input, paddings});

  if (auto padQuantInfo = buildPadOpQuantizationAttr(builder, input))
    result.addAttribute("quantization_info", padQuantInfo);

  result.addTypes(outputType);
}
//===----------------------------------------------------------------------===//
// TOSA Operator Definitions.
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"