You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
297 lines
12 KiB
297 lines
12 KiB
//===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// \file
|
|
// This file implements the TOSA Specification:
|
|
// https://developer.mlplatform.org/w/tosa/
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/Transforms/FoldUtils.h"
|
|
#include "mlir/Transforms/InliningUtils.h"
|
|
#include "mlir/Transforms/RegionUtils.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::tosa;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Tosa dialect structs and interface includes.
|
|
//===----------------------------------------------------------------------===//
|
|
#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
|
|
#include "mlir/Dialect/Tosa/IR/TosaStructs.cpp.inc"
|
|
|
|
namespace {
|
|
//===----------------------------------------------------------------------===//
|
|
// Dialect Function Inliner Interface.
|
|
//===----------------------------------------------------------------------===//
|
|
struct TosaInlinerInterface : public DialectInlinerInterface {
|
|
using DialectInlinerInterface::DialectInlinerInterface;
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// Analysis Hooks.
|
|
//===--------------------------------------------------------------------===//
|
|
|
|
/// All operations can be inlined by default.
|
|
bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
|
|
BlockAndValueMapping &map) const final {
|
|
return true;
|
|
}
|
|
|
|
/// All regions with If and While parent operators can be inlined.
|
|
bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
|
|
BlockAndValueMapping &map) const final {
|
|
return (isa<tosa::IfOp>(dest->getParentOp()) ||
|
|
isa<tosa::WhileOp>(dest->getParentOp()));
|
|
}
|
|
};
|
|
} // end anonymous namespace
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TOSA control flow support.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Returns the while loop body.
|
|
Region &tosa::WhileOp::getLoopBody() { return body(); }
|
|
|
|
/// Returns true when `value` lives in a region that is not nested inside
/// (and is not itself) the body region of this while loop.
bool tosa::WhileOp::isDefinedOutsideOfLoop(Value value) {
  Region &loopBody = body();
  const bool definedInsideBody = loopBody.isAncestor(value.getParentRegion());
  return !definedInsideBody;
}
|
|
|
|
/// Hoists each operation in `ops`, in the order given, so that it is placed
/// immediately before this while op. Always reports success; an empty list
/// is a no-op.
LogicalResult WhileOp::moveOutOfLoop(ArrayRef<mlir::Operation *> ops) {
  Operation *loopOp = getOperation();
  for (Operation *hoisted : ops)
    hoisted->moveBefore(loopOp);
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Tosa dialect initialization.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Sets up the dialect: registers the operation list pulled in from
/// TosaOps.cpp.inc (selected by GET_OP_LIST) and then attaches the TOSA
/// inliner interface.
void TosaDialect::initialize() {
  // The template argument list is supplied entirely by the included file.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
      >();

  addInterfaces<TosaInlinerInterface>();
}
|
|
|
|
/// Materializes `value` as a tosa::ConstOp of type `type` at `loc`.
///
/// Only ElementsAttr values are supported by TOSA constants (unlike the
/// standard dialect constant, which accepts any attribute); for every other
/// attribute kind no operation is created and nullptr is returned.
Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                            Type type, Location loc) {
  auto elements = value.dyn_cast<ElementsAttr>();
  if (!elements)
    return nullptr;
  return builder.create<tosa::ConstOp>(loc, type, elements);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Operator Folders.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds the constant to the attribute it carries. A constant takes no
/// operands, which is checked by assertion.
OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.empty() && "constant has no operands");
  auto constantValue = valueAttr();
  return constantValue;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TOSA Operator Verifiers.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
template <typename T>
|
|
static LogicalResult verifyConvOp(T op) {
|
|
// All TOSA conv ops have an input() and weight().
|
|
auto inputType = op.input().getType().template dyn_cast<RankedTensorType>();
|
|
auto weightType = op.weight().getType().template dyn_cast<RankedTensorType>();
|
|
|
|
// Must be ranked tensor types
|
|
if (!inputType || !weightType)
|
|
return failure();
|
|
|
|
auto inputQType =
|
|
inputType.getElementType().template isa<mlir::quant::QuantizedType>();
|
|
auto weightQType =
|
|
weightType.getElementType().template isa<mlir::quant::QuantizedType>();
|
|
|
|
// Either both must be quantized or both unquantized.
|
|
if (inputQType != weightQType)
|
|
return failure();
|
|
|
|
// Quantized type must have constructed the quantizationattr, and unquantized
|
|
// types should not have a quantizationattr.
|
|
if ((inputQType && !op.quantization_info()) ||
|
|
(!inputQType && op.quantization_info()))
|
|
return failure();
|
|
|
|
return success();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TOSA Operator Quantization Builders.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// This builder is called on all convolution operators except TransposeConv,
|
|
/// which has specialized output shape semantics. The builder also defines the
|
|
/// bitwidth of the output given the bit width of the input & weight content.
|
|
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
|
|
Type outputType, Value input, Value weight,
|
|
Value bias, ArrayAttr pad,
|
|
ArrayAttr stride, ArrayAttr dilation) {
|
|
|
|
result.addOperands({input, weight, bias});
|
|
result.addAttribute("pad", pad);
|
|
result.addAttribute("stride", stride);
|
|
result.addAttribute("dilation", dilation);
|
|
|
|
auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
|
|
if (quantAttr) {
|
|
result.addAttribute("quantization_info", quantAttr);
|
|
result.addTypes(
|
|
buildConvOpResultTypeInfo(builder, outputType, input, weight));
|
|
} else {
|
|
result.addTypes(outputType);
|
|
}
|
|
}
|
|
|
|
/// Handles tosa.transpose_conv2d which has outpad and output shape attributes.
|
|
static void
|
|
buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
|
|
Type outputType, Value input, Value weight,
|
|
Value bias, ArrayAttr outpad, ArrayAttr stride,
|
|
ArrayAttr dilation, ArrayAttr outputShape) {
|
|
result.addOperands({input, weight, bias});
|
|
result.addAttribute("out_pad", outpad);
|
|
result.addAttribute("stride", stride);
|
|
result.addAttribute("dilation", dilation);
|
|
result.addAttribute("out_shape", outputShape);
|
|
auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
|
|
|
|
if (quantAttr) {
|
|
result.addAttribute("quantization_info", quantAttr);
|
|
result.addTypes(
|
|
buildConvOpResultTypeInfo(builder, outputType, input, weight));
|
|
} else {
|
|
result.addTypes(outputType);
|
|
}
|
|
}
|
|
|
|
/// The tosa.fully_connected op has its own builder as it does not have
|
|
/// strides/dilation/padding.
|
|
static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result,
|
|
Type outputType, Value input, Value weight,
|
|
Value bias) {
|
|
|
|
result.addOperands({input, weight, bias});
|
|
auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
|
|
if (quantAttr) {
|
|
result.addAttribute("quantization_info", quantAttr);
|
|
result.addTypes(
|
|
buildConvOpResultTypeInfo(builder, outputType, input, weight));
|
|
} else {
|
|
result.addTypes(outputType);
|
|
}
|
|
}
|
|
|
|
/// The tosa.matmul op is also intended to be generated where a fully_connected
|
|
/// op must be constructed where the weight is not a constant. In this case,
|
|
/// the fully_connected op must be expressed using matmul.
|
|
/// TODO: Add link to the leglization document explaining this.
|
|
static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
|
|
OperationState &result, Type outputType,
|
|
Value a, Value b) {
|
|
result.addOperands({a, b});
|
|
auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);
|
|
|
|
if (quantAttr) {
|
|
result.addAttribute("quantization_info", quantAttr);
|
|
|
|
auto inputType = a.getType().dyn_cast<RankedTensorType>();
|
|
assert(inputType && "Input must be a ranked tensor type!");
|
|
|
|
auto inputQType = inputType.getElementType()
|
|
.dyn_cast<mlir::quant::UniformQuantizedType>();
|
|
assert(inputQType && "Tensor must have quantized datatype!");
|
|
|
|
unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
|
|
|
|
auto outputShapedType = outputType.dyn_cast<RankedTensorType>();
|
|
assert(outputShapedType && "Output must be a ranked tensor type");
|
|
|
|
auto outputShape = outputShapedType.getShape();
|
|
|
|
IntegerType accElementType;
|
|
if (inputBits == 16)
|
|
accElementType = builder.getIntegerType(48);
|
|
else
|
|
accElementType = builder.getI32Type();
|
|
auto accType = RankedTensorType::get(outputShape, accElementType);
|
|
result.addTypes(accType);
|
|
} else {
|
|
result.addTypes(outputType);
|
|
}
|
|
}
|
|
|
|
/// Both the tosa.avg_pool2d and unary ops use the same UnaruOpQuantizationAttr
|
|
/// but avg_pool operator has its own builder as it has additional parameters
|
|
/// not part of the unary ops.
|
|
static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder,
|
|
OperationState &result,
|
|
Type outputType, Value input,
|
|
ArrayAttr kernel, ArrayAttr stride,
|
|
ArrayAttr pad) {
|
|
result.addOperands(input);
|
|
result.addAttribute("kernel", kernel);
|
|
result.addAttribute("stride", stride);
|
|
result.addAttribute("pad", pad);
|
|
auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
|
|
if (quantAttr)
|
|
result.addAttribute("quantization_info", quantAttr);
|
|
result.types.push_back(outputType);
|
|
}
|
|
|
|
/// This builder is called on single-parameter unary operators that have scale
|
|
/// relationship between their input and output, expressed by the
|
|
/// UnaryOpQuantizationAttr.
|
|
static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
|
|
OperationState &result, Type outputType,
|
|
Value input) {
|
|
result.addOperands(input);
|
|
auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
|
|
if (quantAttr)
|
|
result.addAttribute("quantization_info", quantAttr);
|
|
result.types.push_back(outputType);
|
|
}
|
|
|
|
/// This builder is called on TOSA pad operator that needs to create its own
|
|
/// OptionalAttr quantization_attr parameter to scale the padding values
|
|
/// correctly.
|
|
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
|
|
Type outputType, Value input,
|
|
Value paddings) {
|
|
result.addOperands({input, paddings});
|
|
auto quantAttr = buildPadOpQuantizationAttr(builder, input);
|
|
if (quantAttr)
|
|
result.addAttribute("quantization_info", quantAttr);
|
|
result.types.push_back(outputType);
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TOSA Operator Definitions.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
|