/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <iostream>
#include <vector>

#include "NeuralNetworksWrapper.h"
#include "QuantizedLSTM.h"

namespace android {
namespace nn {
namespace wrapper {

namespace {

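// Bundles the type, shape, and quantization parameters (scale, zero point)
// needed to declare one NNAPI operand.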
struct OperandTypeParams {
    Type type;
    std::vector<uint32_t> shape;
    float scale;
    int32_t zeroPoint;

    OperandTypeParams(Type type, std::vector<uint32_t> shape, float scale, int32_t zeroPoint)
        : type(type), shape(shape), scale(scale), zeroPoint(zeroPoint) {}
};

} // namespace

using ::testing::Each;
using ::testing::ElementsAreArray;
using ::testing::FloatNear;
using ::testing::Matcher;

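// Builds a Model containing a single ANEURALNETWORKS_QUANTIZED_16BIT_LSTM
// operation and drives it one time step per invoke(), carrying the cell state
// and output between steps.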
class QuantizedLSTMOpModel {
   public:
    QuantizedLSTMOpModel(const std::vector<OperandTypeParams>& inputOperandTypeParams) {
        std::vector<uint32_t> inputs;

        for (int i = 0; i < NUM_INPUTS; ++i) {
            const auto& curOTP = inputOperandTypeParams[i];
            OperandType curType(curOTP.type, curOTP.shape, curOTP.scale, curOTP.zeroPoint);
            inputs.push_back(model_.addOperand(&curType));
        }

        const uint32_t numBatches = inputOperandTypeParams[0].shape[0];
        // The input tensor has shape {numBatches, inputSize}, so the input size
        // is the second dimension (shape[1], not shape[0]).
        inputSize_ = inputOperandTypeParams[0].shape[1];
        const uint32_t outputSize =
                inputOperandTypeParams[QuantizedLSTMCell::kPrevCellStateTensor].shape[1];
        outputSize_ = outputSize;

        std::vector<uint32_t> outputs;
        OperandType cellStateOutOperandType(Type::TENSOR_QUANT16_SYMM, {numBatches, outputSize},
                                            1. / 2048., 0);
        outputs.push_back(model_.addOperand(&cellStateOutOperandType));
        OperandType outputOperandType(Type::TENSOR_QUANT8_ASYMM, {numBatches, outputSize},
                                      1. / 128., 128);
        outputs.push_back(model_.addOperand(&outputOperandType));

        model_.addOperation(ANEURALNETWORKS_QUANTIZED_16BIT_LSTM, inputs, outputs);
        model_.identifyInputsAndOutputs(inputs, outputs);

        initializeInputData(inputOperandTypeParams[QuantizedLSTMCell::kInputTensor], &input_);
        initializeInputData(inputOperandTypeParams[QuantizedLSTMCell::kPrevOutputTensor],
                            &prevOutput_);
        initializeInputData(inputOperandTypeParams[QuantizedLSTMCell::kPrevCellStateTensor],
                            &prevCellState_);

        cellStateOut_.resize(numBatches * outputSize, 0);
        output_.resize(numBatches * outputSize, 0);

        model_.finish();
    }

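    // Compiles and runs one LSTM step, then feeds the resulting output and
    // cell state back in as prevOutput_/prevCellState_ for the next step.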
    void invoke() {
        ASSERT_TRUE(model_.isValid());

        Compilation compilation(&model_);
        ASSERT_EQ(compilation.finish(), Result::NO_ERROR);
        Execution execution(&compilation);

        // Set all the inputs.
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputTensor, input_),
                  Result::NO_ERROR);
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputToInputWeightsTensor,
                                 inputToInputWeights_),
                  Result::NO_ERROR);
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputToForgetWeightsTensor,
                                 inputToForgetWeights_),
                  Result::NO_ERROR);
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputToCellWeightsTensor,
                                 inputToCellWeights_),
                  Result::NO_ERROR);
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputToOutputWeightsTensor,
                                 inputToOutputWeights_),
                  Result::NO_ERROR);
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kRecurrentToInputWeightsTensor,
                                 recurrentToInputWeights_),
                  Result::NO_ERROR);
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kRecurrentToForgetWeightsTensor,
                                 recurrentToForgetWeights_),
                  Result::NO_ERROR);
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kRecurrentToCellWeightsTensor,
                                 recurrentToCellWeights_),
                  Result::NO_ERROR);
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kRecurrentToOutputWeightsTensor,
                                 recurrentToOutputWeights_),
                  Result::NO_ERROR);
        ASSERT_EQ(
                setInputTensor(&execution, QuantizedLSTMCell::kInputGateBiasTensor, inputGateBias_),
                Result::NO_ERROR);
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kForgetGateBiasTensor,
                                 forgetGateBias_),
                  Result::NO_ERROR);
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kCellGateBiasTensor, cellGateBias_),
                  Result::NO_ERROR);
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kOutputGateBiasTensor,
                                 outputGateBias_),
                  Result::NO_ERROR);
        ASSERT_EQ(
                setInputTensor(&execution, QuantizedLSTMCell::kPrevCellStateTensor, prevCellState_),
                Result::NO_ERROR);
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kPrevOutputTensor, prevOutput_),
                  Result::NO_ERROR);
        // Set all the outputs.
        ASSERT_EQ(
                setOutputTensor(&execution, QuantizedLSTMCell::kCellStateOutTensor, &cellStateOut_),
                Result::NO_ERROR);
        ASSERT_EQ(setOutputTensor(&execution, QuantizedLSTMCell::kOutputTensor, &output_),
                  Result::NO_ERROR);

        ASSERT_EQ(execution.compute(), Result::NO_ERROR);

        // Feed the state outputs back in as the state inputs for the next step.
        prevOutput_ = output_;
        prevCellState_ = cellStateOut_;
    }

    int inputSize() { return inputSize_; }

    int outputSize() { return outputSize_; }

    void setInput(const std::vector<uint8_t>& input) { input_ = input; }

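    // Stores the caller-supplied weight and bias vectors; they are bound to
    // the execution anew on every invoke().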
    void setWeightsAndBiases(std::vector<uint8_t> inputToInputWeights,
                             std::vector<uint8_t> inputToForgetWeights,
                             std::vector<uint8_t> inputToCellWeights,
                             std::vector<uint8_t> inputToOutputWeights,
                             std::vector<uint8_t> recurrentToInputWeights,
                             std::vector<uint8_t> recurrentToForgetWeights,
                             std::vector<uint8_t> recurrentToCellWeights,
                             std::vector<uint8_t> recurrentToOutputWeights,
                             std::vector<int32_t> inputGateBias,
                             std::vector<int32_t> forgetGateBias,
                             std::vector<int32_t> cellGateBias,  //
                             std::vector<int32_t> outputGateBias) {
        inputToInputWeights_ = inputToInputWeights;
        inputToForgetWeights_ = inputToForgetWeights;
        inputToCellWeights_ = inputToCellWeights;
        inputToOutputWeights_ = inputToOutputWeights;
        recurrentToInputWeights_ = recurrentToInputWeights;
        recurrentToForgetWeights_ = recurrentToForgetWeights;
        recurrentToCellWeights_ = recurrentToCellWeights;
        recurrentToOutputWeights_ = recurrentToOutputWeights;
        inputGateBias_ = inputGateBias;
        forgetGateBias_ = forgetGateBias;
        cellGateBias_ = cellGateBias;
        outputGateBias_ = outputGateBias;
    }

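    // Fills *vec with the operand's zero point, i.e. the quantized encoding of
    // the real value 0, so freshly initialized tensors represent all-zero data.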
    template <typename T>
    void initializeInputData(const OperandTypeParams& params, std::vector<T>* vec) {
        int size = 1;
        for (int d : params.shape) {
            size *= d;
        }
        vec->clear();
        vec->resize(size, params.zeroPoint);
    }

    std::vector<uint8_t> getOutput() { return output_; }

   private:
    static constexpr int NUM_INPUTS = 15;
    static constexpr int NUM_OUTPUTS = 2;

    Model model_;
    // Inputs
    std::vector<uint8_t> input_;
    std::vector<uint8_t> inputToInputWeights_;
    std::vector<uint8_t> inputToForgetWeights_;
    std::vector<uint8_t> inputToCellWeights_;
    std::vector<uint8_t> inputToOutputWeights_;
    std::vector<uint8_t> recurrentToInputWeights_;
    std::vector<uint8_t> recurrentToForgetWeights_;
    std::vector<uint8_t> recurrentToCellWeights_;
    std::vector<uint8_t> recurrentToOutputWeights_;
    std::vector<int32_t> inputGateBias_;
    std::vector<int32_t> forgetGateBias_;
    std::vector<int32_t> cellGateBias_;
    std::vector<int32_t> outputGateBias_;
    std::vector<int16_t> prevCellState_;
    std::vector<uint8_t> prevOutput_;
    // Outputs
    std::vector<int16_t> cellStateOut_;
    std::vector<uint8_t> output_;

    int inputSize_;
    int outputSize_;

    template <typename T>
    Result setInputTensor(Execution* execution, int tensor, const std::vector<T>& data) {
        return execution->setInput(tensor, data.data(), sizeof(T) * data.size());
    }

    template <typename T>
    Result setOutputTensor(Execution* execution, int tensor, std::vector<T>* data) {
        return execution->setOutput(tensor, data->data(), sizeof(T) * data->size());
    }
};

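// Test fixture. VerifyGoldens() feeds a whole input sequence through the op
// model one step at a time and checks each step's output against the goldens.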
class QuantizedLstmTest : public ::testing::Test {
   protected:
    void VerifyGoldens(const std::vector<std::vector<uint8_t>>& input,
                       const std::vector<std::vector<uint8_t>>& output,
                       QuantizedLSTMOpModel* lstm) {
        const int numBatches = input.size();
        EXPECT_GT(numBatches, 0);
        const int inputSize = lstm->inputSize();
        EXPECT_GT(inputSize, 0);
        const int inputSequenceSize = input[0].size() / inputSize;
        EXPECT_GT(inputSequenceSize, 0);
        for (int i = 0; i < inputSequenceSize; ++i) {
            // Assemble the batch-major input for time step i.
            std::vector<uint8_t> inputStep;
            for (int b = 0; b < numBatches; ++b) {
                const uint8_t* batchStart = input[b].data() + i * inputSize;
                const uint8_t* batchEnd = batchStart + inputSize;
                inputStep.insert(inputStep.end(), batchStart, batchEnd);
            }
            lstm->setInput(inputStep);
            lstm->invoke();

            const int outputSize = lstm->outputSize();
            std::vector<float> expected;
            for (int b = 0; b < numBatches; ++b) {
                const uint8_t* goldenBatchStart = output[b].data() + i * outputSize;
                const uint8_t* goldenBatchEnd = goldenBatchStart + outputSize;
                expected.insert(expected.end(), goldenBatchStart, goldenBatchEnd);
            }
            EXPECT_THAT(lstm->getOutput(), ElementsAreArray(expected));
        }
    }
};

// Inputs and weights in this test are random, and the test only checks that
// the outputs are equal to the outputs obtained from running the TF Lite
// version of quantized LSTM on the same inputs.
TEST_F(QuantizedLstmTest, BasicQuantizedLstmTest) {
    const int numBatches = 2;
    const int inputSize = 2;
    const int outputSize = 4;

    float weightsScale = 0.00408021;
    int weightsZeroPoint = 100;
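    // NNAPI quantization convention: realValue = scale * (quantized - zeroPoint).
    // The input/output operands below use scale 1/128 and zero point 128, so the
    // uint8 range [0, 255] covers roughly [-1, 1). The int32 gate biases use
    // scale inputScale * weightsScale = weightsScale / 128 with zero point 0,
    // matching the bias quantization the quantized LSTM kernel expects.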
    QuantizedLSTMOpModel lstm({
            // input
            OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {numBatches, inputSize}, 1. / 128., 128),
            // inputToInputWeights
            // inputToForgetWeights
            // inputToCellWeights
            // inputToOutputWeights
            OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, inputSize}, weightsScale,
                              weightsZeroPoint),
            OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, inputSize}, weightsScale,
                              weightsZeroPoint),
            OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, inputSize}, weightsScale,
                              weightsZeroPoint),
            OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, inputSize}, weightsScale,
                              weightsZeroPoint),
            // recurrentToInputWeights
            // recurrentToForgetWeights
            // recurrentToCellWeights
            // recurrentToOutputWeights
            OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, outputSize}, weightsScale,
                              weightsZeroPoint),
            OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, outputSize}, weightsScale,
                              weightsZeroPoint),
            OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, outputSize}, weightsScale,
                              weightsZeroPoint),
            OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, outputSize}, weightsScale,
                              weightsZeroPoint),
            // inputGateBias
            // forgetGateBias
            // cellGateBias
            // outputGateBias
            OperandTypeParams(Type::TENSOR_INT32, {outputSize}, weightsScale / 128., 0),
            OperandTypeParams(Type::TENSOR_INT32, {outputSize}, weightsScale / 128., 0),
            OperandTypeParams(Type::TENSOR_INT32, {outputSize}, weightsScale / 128., 0),
            OperandTypeParams(Type::TENSOR_INT32, {outputSize}, weightsScale / 128., 0),
            // prevCellState
            OperandTypeParams(Type::TENSOR_QUANT16_SYMM, {numBatches, outputSize}, 1. / 2048., 0),
            // prevOutput
            OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {numBatches, outputSize}, 1. / 128., 128),
    });

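    // Raw quantized weight (uint8) and bias (int32) values, matching the
    // operand types declared above.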
    lstm.setWeightsAndBiases(
            // inputToInputWeights
            {146, 250, 235, 171, 10, 218, 171, 108},
            // inputToForgetWeights
            {24, 50, 132, 179, 158, 110, 3, 169},
            // inputToCellWeights
            {133, 34, 29, 49, 206, 109, 54, 183},
            // inputToOutputWeights
            {195, 187, 11, 99, 109, 10, 218, 48},
            // recurrentToInputWeights
            {254, 206, 77, 168, 71, 20, 215, 6, 223, 7, 118, 225, 59, 130, 174, 26},
            // recurrentToForgetWeights
            {137, 240, 103, 52, 68, 51, 237, 112, 0, 220, 89, 23, 69, 4, 207, 253},
            // recurrentToCellWeights
            {172, 60, 205, 65, 14, 0, 140, 168, 240, 223, 133, 56, 142, 64, 246, 216},
            // recurrentToOutputWeights
            {106, 214, 67, 23, 59, 158, 45, 3, 119, 132, 49, 205, 129, 218, 11, 98},
            // inputGateBias
            {-7876, 13488, -726, 32839},
            // forgetGateBias
            {9206, -46884, -11693, -38724},
            // cellGateBias
            {39481, 48624, 48976, -21419},
            // outputGateBias
            {-58999, -17050, -41852, -40538});

    // LSTM input is stored as numBatches x (sequenceLength x inputSize) vector.
    std::vector<std::vector<uint8_t>> lstmInput;
    // clang-format off
    lstmInput = {{154, 166,
                  166, 179,
                  141, 141},
                 {100, 200,
                   50, 150,
                  111, 222}};
    // clang-format on

    // LSTM output is stored as numBatches x (sequenceLength x outputSize) vector.
    std::vector<std::vector<uint8_t>> lstmGoldenOutput;
    // clang-format off
    lstmGoldenOutput = {{136, 150, 140, 115,
                         140, 151, 146, 112,
                         139, 153, 146, 114},
                        {135, 152, 138, 112,
                         136, 156, 142, 112,
                         141, 154, 146, 108}};
    // clang-format on
    VerifyGoldens(lstmInput, lstmGoldenOutput, &lstm);
}

} // namespace wrapper
} // namespace nn
} // namespace android