You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
203 lines
7.2 KiB
203 lines
7.2 KiB
//===- TosaTestPasses.cpp -------------------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// Test passes to exercise TOSA helper functions.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
#include "mlir/Dialect/Tosa/IR//TosaOps.h"
|
|
#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
|
|
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
|
|
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
// Command-line name under which the quant-utils test pass is registered.
static constexpr char PASS_NAME[] = "tosa-test-quant-utils";
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::tosa;
|
|
|
|
// This transformation converts quantized uint8 to quantized int8. The
|
|
// construction of the new type invokes buildQTypeFromMinMax. Extracted from
|
|
// TOSA legalization infrastructure.
|
|
struct ConvertTosaNegateOp : public RewritePattern {
|
|
explicit ConvertTosaNegateOp(MLIRContext *context)
|
|
: RewritePattern(tosa::NegateOp::getOperationName(), 1, context) {}
|
|
LogicalResult matchAndRewrite(Operation *op,
|
|
PatternRewriter &rewriter) const override;
|
|
};
|
|
|
|
/// Rewrites `tosa.negate(const)` whose result is an unsigned 8-bit per-tensor
/// uniform-quantized ranked tensor into a new `tosa.const` + `tosa.negate`
/// pair typed with a signed 8-bit quantized type.
///
/// The signed type is produced by buildQTypeFromMinMax from the real-valued
/// range [typeRangeMin, typeRangeMax] covered by the original unsigned storage
/// type. Returns failure() without modifying the IR when the operand is not a
/// ranked tensor, the result is not a ranked uint8 per-tensor quantized tensor,
/// or the operand is not a constant.
LogicalResult
ConvertTosaNegateOp::matchAndRewrite(Operation *op,
                                     PatternRewriter &rewriter) const {
  auto tosaNegateOp = cast<tosa::NegateOp>(op);

  // Skip if the input is not a ranked tensor.
  if (!tosaNegateOp.input1().getType().isa<mlir::RankedTensorType>())
    return failure();

  // Skip if the output is not a ranked tensor.
  auto outputType =
      tosaNegateOp.getResult().getType().dyn_cast<mlir::RankedTensorType>();
  if (!outputType)
    return failure();

  // Skip if the output is not per-tensor quantized.
  auto outputElementType =
      outputType.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
  if (!outputElementType)
    return failure();

  // Skip if the output storage is not unsigned 8-bit.
  if (outputElementType.isSigned() ||
      outputElementType.getStorageTypeIntegralWidth() != 8)
    return failure();

  // Only a constant input can be rematerialized with the new type. Check this
  // before building any quantized type so non-matching ops bail out early.
  ElementsAttr inputElems;
  if (!matchPattern(tosaNegateOp.input1(), m_Constant(&inputElems)))
    return failure();

  // Real-valued range spanned by the unsigned storage type:
  //   real = (storage - zeroPoint) * scale.
  const int64_t zeroPoint = outputElementType.getZeroPoint();
  const double scale = outputElementType.getScale();
  const double typeRangeMin =
      static_cast<double>(outputElementType.getStorageTypeMin() - zeroPoint) *
      scale;
  const double typeRangeMax =
      static_cast<double>(outputElementType.getStorageTypeMax() - zeroPoint) *
      scale;
  // A storage minimum of 1 (instead of 0) is treated as a narrow-range type.
  const bool narrowRange = outputElementType.getStorageTypeMin() == 1;

  auto dstQConstType = RankedTensorType::get(
      outputType.getShape(),
      buildQTypeFromMinMax(rewriter, outputElementType.getExpressedType(),
                           rewriter.getF64FloatAttr(typeRangeMin),
                           rewriter.getF64FloatAttr(typeRangeMax),
                           rewriter.getI32IntegerAttr(
                               outputElementType.getStorageTypeIntegralWidth()),
                           0, true /* signed */,
                           rewriter.getBoolAttr(narrowRange)));

  auto newConstOp =
      rewriter.create<tosa::ConstOp>(op->getLoc(), dstQConstType, inputElems);
  auto newNegateOp = rewriter.create<tosa::NegateOp>(
      op->getLoc(), dstQConstType, newConstOp.getResult());

  rewriter.replaceOp(op, {newNegateOp.getResult()});
  return success();
}
|
|
|
|
// This transformation modifies the quantized output of a test conv2d input and
|
|
// appends a TOSA rescale after it. The rescale op requires the invocation of
|
|
// computeMultiplierAndShift. From TOSA legalization infrastructure.
|
|
struct ConvertTosaConv2DOp : public RewritePattern {
|
|
explicit ConvertTosaConv2DOp(MLIRContext *context)
|
|
: RewritePattern(tosa::Conv2DOp::getOperationName(), 1, context) {}
|
|
LogicalResult matchAndRewrite(Operation *op,
|
|
PatternRewriter &rewriter) const override;
|
|
};
|
|
|
|
/// Rewrites a `tosa.conv2d` whose input, weight and result are all ranked
/// tensors with per-tensor uniform-quantized element types into:
///   1. a `tosa.conv2d` on the same operands/attributes with an i32 result of
///      the original result shape, followed by
///   2. a `tosa.rescale` converting that i32 result back to the original
///      quantized result type.
/// The rescale's multiplier/shift pair is obtained from
/// computeMultiplierAndShift applied to (inputScale * weightScale) /
/// outputScale. Returns failure() without touching the IR if any type check
/// fails.
LogicalResult
ConvertTosaConv2DOp::matchAndRewrite(Operation *op,
                                     PatternRewriter &rewriter) const {
  auto conv = cast<tosa::Conv2DOp>(op);

  // Guard: input, weight and result must all be ranked tensors.
  auto inTensorTy = conv.input().getType().dyn_cast<mlir::RankedTensorType>();
  if (!inTensorTy)
    return failure();
  auto wtTensorTy = conv.weight().getType().dyn_cast<mlir::RankedTensorType>();
  if (!wtTensorTy)
    return failure();
  auto resTensorTy =
      conv.getResult().getType().dyn_cast<mlir::RankedTensorType>();
  if (!resTensorTy)
    return failure();

  // Guard: each must carry a per-tensor uniform quantized element type.
  auto inQTy =
      inTensorTy.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
  auto wtQTy =
      wtTensorTy.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
  auto resQTy = resTensorTy.getElementType()
                    .dyn_cast<mlir::quant::UniformQuantizedType>();
  if (!inQTy || !wtQTy || !resQTy)
    return failure();

  // Fixed-point (multiplier, shift) encoding of the combined requantization
  // scale, computed with 32-bit scale width.
  const double combinedScale =
      inQTy.getScale() * wtQTy.getScale() / resQTy.getScale();
  int32_t multiplier;
  int32_t shift;
  computeMultiplierAndShift(combinedScale, multiplier, shift, 32);

  // Step 1: the same convolution, producing an i32 tensor of the result shape.
  auto accTensorTy = RankedTensorType::get(resTensorTy.getShape(),
                                           rewriter.getIntegerType(32));
  auto accConv = rewriter.create<tosa::Conv2DOp>(
      op->getLoc(), accTensorTy, conv.input(), conv.weight(), conv.bias(),
      conv.pad(), conv.stride(), conv.dilation());

  // Step 2: rescale the i32 result to the original quantized result type,
  // using 0 as the incoming zero point and the result type's zero point as
  // the outgoing one.
  auto rescale = rewriter.create<tosa::RescaleOp>(
      op->getLoc(), resTensorTy, accConv.getResult(),
      rewriter.getI32IntegerAttr(0),
      rewriter.getI32IntegerAttr(resQTy.getZeroPoint()),
      rewriter.getI32ArrayAttr({multiplier}), rewriter.getI32ArrayAttr({shift}),
      rewriter.getBoolAttr(true), rewriter.getBoolAttr(true),
      rewriter.getBoolAttr(false));

  rewriter.replaceOp(op, {rescale.getResult()});
  return success();
}
|
|
|
|
namespace {
|
|
|
|
struct TosaTestQuantUtilAPI
|
|
: public PassWrapper<TosaTestQuantUtilAPI, FunctionPass> {
|
|
void runOnFunction() override;
|
|
};
|
|
|
|
void TosaTestQuantUtilAPI::runOnFunction() {
|
|
OwningRewritePatternList patterns;
|
|
auto *ctx = &getContext();
|
|
auto func = getFunction();
|
|
|
|
patterns.insert<ConvertTosaNegateOp>(ctx);
|
|
patterns.insert<ConvertTosaConv2DOp>(ctx);
|
|
applyPatternsAndFoldGreedily(func, std::move(patterns));
|
|
}
|
|
|
|
} // anonymous namespace
|
|
|
|
namespace mlir {
|
|
void registerTosaTestQuantUtilAPIPass() {
|
|
PassRegistration<TosaTestQuantUtilAPI>(
|
|
PASS_NAME, "TOSA Test: Exercise the APIs in QuantUtils.cpp.");
|
|
}
|
|
} // namespace mlir
|