You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

379 lines
16 KiB

/*
* Copyright (C) 2017 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <iostream>
#include <vector>
#include "NeuralNetworksWrapper.h"
#include "QuantizedLSTM.h"
namespace android {
namespace nn {
namespace wrapper {
namespace {
struct OperandTypeParams {
Type type;
std::vector<uint32_t> shape;
float scale;
int32_t zeroPoint;
OperandTypeParams(Type type, std::vector<uint32_t> shape, float scale, int32_t zeroPoint)
: type(type), shape(shape), scale(scale), zeroPoint(zeroPoint) {}
};
} // namespace
using ::testing::Each;
using ::testing::ElementsAreArray;
using ::testing::FloatNear;
using ::testing::Matcher;
class QuantizedLSTMOpModel {
public:
QuantizedLSTMOpModel(const std::vector<OperandTypeParams>& inputOperandTypeParams) {
std::vector<uint32_t> inputs;
for (int i = 0; i < NUM_INPUTS; ++i) {
const auto& curOTP = inputOperandTypeParams[i];
OperandType curType(curOTP.type, curOTP.shape, curOTP.scale, curOTP.zeroPoint);
inputs.push_back(model_.addOperand(&curType));
}
const uint32_t numBatches = inputOperandTypeParams[0].shape[0];
inputSize_ = inputOperandTypeParams[0].shape[0];
const uint32_t outputSize =
inputOperandTypeParams[QuantizedLSTMCell::kPrevCellStateTensor].shape[1];
outputSize_ = outputSize;
std::vector<uint32_t> outputs;
OperandType cellStateOutOperandType(Type::TENSOR_QUANT16_SYMM, {numBatches, outputSize},
1. / 2048., 0);
outputs.push_back(model_.addOperand(&cellStateOutOperandType));
OperandType outputOperandType(Type::TENSOR_QUANT8_ASYMM, {numBatches, outputSize},
1. / 128., 128);
outputs.push_back(model_.addOperand(&outputOperandType));
model_.addOperation(ANEURALNETWORKS_QUANTIZED_16BIT_LSTM, inputs, outputs);
model_.identifyInputsAndOutputs(inputs, outputs);
initializeInputData(inputOperandTypeParams[QuantizedLSTMCell::kInputTensor], &input_);
initializeInputData(inputOperandTypeParams[QuantizedLSTMCell::kPrevOutputTensor],
&prevOutput_);
initializeInputData(inputOperandTypeParams[QuantizedLSTMCell::kPrevCellStateTensor],
&prevCellState_);
cellStateOut_.resize(numBatches * outputSize, 0);
output_.resize(numBatches * outputSize, 0);
model_.finish();
}
void invoke() {
ASSERT_TRUE(model_.isValid());
Compilation compilation(&model_);
compilation.finish();
Execution execution(&compilation);
// Set all the inputs.
ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputTensor, input_),
Result::NO_ERROR);
ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputToInputWeightsTensor,
inputToInputWeights_),
Result::NO_ERROR);
ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputToForgetWeightsTensor,
inputToForgetWeights_),
Result::NO_ERROR);
ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputToCellWeightsTensor,
inputToCellWeights_),
Result::NO_ERROR);
ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputToOutputWeightsTensor,
inputToOutputWeights_),
Result::NO_ERROR);
ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kRecurrentToInputWeightsTensor,
recurrentToInputWeights_),
Result::NO_ERROR);
ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kRecurrentToForgetWeightsTensor,
recurrentToForgetWeights_),
Result::NO_ERROR);
ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kRecurrentToCellWeightsTensor,
recurrentToCellWeights_),
Result::NO_ERROR);
ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kRecurrentToOutputWeightsTensor,
recurrentToOutputWeights_),
Result::NO_ERROR);
ASSERT_EQ(
setInputTensor(&execution, QuantizedLSTMCell::kInputGateBiasTensor, inputGateBias_),
Result::NO_ERROR);
ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kForgetGateBiasTensor,
forgetGateBias_),
Result::NO_ERROR);
ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kCellGateBiasTensor, cellGateBias_),
Result::NO_ERROR);
ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kOutputGateBiasTensor,
outputGateBias_),
Result::NO_ERROR);
ASSERT_EQ(
setInputTensor(&execution, QuantizedLSTMCell::kPrevCellStateTensor, prevCellState_),
Result::NO_ERROR);
ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kPrevOutputTensor, prevOutput_),
Result::NO_ERROR);
// Set all the outputs.
ASSERT_EQ(
setOutputTensor(&execution, QuantizedLSTMCell::kCellStateOutTensor, &cellStateOut_),
Result::NO_ERROR);
ASSERT_EQ(setOutputTensor(&execution, QuantizedLSTMCell::kOutputTensor, &output_),
Result::NO_ERROR);
ASSERT_EQ(execution.compute(), Result::NO_ERROR);
// Put state outputs into inputs for the next step
prevOutput_ = output_;
prevCellState_ = cellStateOut_;
}
int inputSize() { return inputSize_; }
int outputSize() { return outputSize_; }
void setInput(const std::vector<uint8_t>& input) { input_ = input; }
void setWeightsAndBiases(std::vector<uint8_t> inputToInputWeights,
std::vector<uint8_t> inputToForgetWeights,
std::vector<uint8_t> inputToCellWeights,
std::vector<uint8_t> inputToOutputWeights,
std::vector<uint8_t> recurrentToInputWeights,
std::vector<uint8_t> recurrentToForgetWeights,
std::vector<uint8_t> recurrentToCellWeights,
std::vector<uint8_t> recurrentToOutputWeights,
std::vector<int32_t> inputGateBias,
std::vector<int32_t> forgetGateBias,
std::vector<int32_t> cellGateBias, //
std::vector<int32_t> outputGateBias) {
inputToInputWeights_ = inputToInputWeights;
inputToForgetWeights_ = inputToForgetWeights;
inputToCellWeights_ = inputToCellWeights;
inputToOutputWeights_ = inputToOutputWeights;
recurrentToInputWeights_ = recurrentToInputWeights;
recurrentToForgetWeights_ = recurrentToForgetWeights;
recurrentToCellWeights_ = recurrentToCellWeights;
recurrentToOutputWeights_ = recurrentToOutputWeights;
inputGateBias_ = inputGateBias;
forgetGateBias_ = forgetGateBias;
cellGateBias_ = cellGateBias;
outputGateBias_ = outputGateBias;
}
template <typename T>
void initializeInputData(OperandTypeParams params, std::vector<T>* vec) {
int size = 1;
for (int d : params.shape) {
size *= d;
}
vec->clear();
vec->resize(size, params.zeroPoint);
}
std::vector<uint8_t> getOutput() { return output_; }
private:
static constexpr int NUM_INPUTS = 15;
static constexpr int NUM_OUTPUTS = 2;
Model model_;
// Inputs
std::vector<uint8_t> input_;
std::vector<uint8_t> inputToInputWeights_;
std::vector<uint8_t> inputToForgetWeights_;
std::vector<uint8_t> inputToCellWeights_;
std::vector<uint8_t> inputToOutputWeights_;
std::vector<uint8_t> recurrentToInputWeights_;
std::vector<uint8_t> recurrentToForgetWeights_;
std::vector<uint8_t> recurrentToCellWeights_;
std::vector<uint8_t> recurrentToOutputWeights_;
std::vector<int32_t> inputGateBias_;
std::vector<int32_t> forgetGateBias_;
std::vector<int32_t> cellGateBias_;
std::vector<int32_t> outputGateBias_;
std::vector<int16_t> prevCellState_;
std::vector<uint8_t> prevOutput_;
// Outputs
std::vector<int16_t> cellStateOut_;
std::vector<uint8_t> output_;
int inputSize_;
int outputSize_;
template <typename T>
Result setInputTensor(Execution* execution, int tensor, const std::vector<T>& data) {
return execution->setInput(tensor, data.data(), sizeof(T) * data.size());
}
template <typename T>
Result setOutputTensor(Execution* execution, int tensor, std::vector<T>* data) {
return execution->setOutput(tensor, data->data(), sizeof(T) * data->size());
}
};
// Fixture providing a golden-output comparison helper for quantized LSTM tests.
class QuantizedLstmTest : public ::testing::Test {
   protected:
    // Feeds |input| through |lstm| one time step at a time and checks every
    // step's output against |output|.
    //
    // |input|  holds one vector per batch entry; each vector is the
    //          concatenation of that entry's per-step inputs
    //          (sequenceLength x inputSize values).
    // |output| holds the matching golden outputs, one vector per batch entry
    //          (sequenceLength x outputSize values).
    // |lstm|   model under test; its state is advanced by every step.
    //
    // Malformed arguments are reported as fatal test failures before any buffer
    // is indexed, so bad test data cannot cause out-of-bounds reads or a
    // division by zero.
    void VerifyGoldens(const std::vector<std::vector<uint8_t>>& input,
                       const std::vector<std::vector<uint8_t>>& output,
                       QuantizedLSTMOpModel* lstm) {
        ASSERT_TRUE(lstm != nullptr);

        const size_t numBatches = input.size();
        ASSERT_GT(numBatches, 0u);
        ASSERT_EQ(output.size(), numBatches);

        const int modelInputSize = lstm->inputSize();
        ASSERT_GT(modelInputSize, 0);
        const int modelOutputSize = lstm->outputSize();
        ASSERT_GT(modelOutputSize, 0);
        const size_t inputSize = static_cast<size_t>(modelInputSize);
        const size_t outputSize = static_cast<size_t>(modelOutputSize);

        const size_t inputSequenceSize = input[0].size() / inputSize;
        ASSERT_GT(inputSequenceSize, 0u);

        // Every batch entry must provide data for the whole sequence.
        for (size_t b = 0; b < numBatches; ++b) {
            ASSERT_GE(input[b].size(), inputSequenceSize * inputSize) << "input batch " << b;
            ASSERT_GE(output[b].size(), inputSequenceSize * outputSize) << "golden batch " << b;
        }

        for (size_t step = 0; step < inputSequenceSize; ++step) {
            // Gather this step's slice from every batch entry into one
            // {numBatches, inputSize} buffer.
            std::vector<uint8_t> inputStep;
            inputStep.reserve(numBatches * inputSize);
            for (size_t b = 0; b < numBatches; ++b) {
                const uint8_t* batchStart = input[b].data() + step * inputSize;
                inputStep.insert(inputStep.end(), batchStart, batchStart + inputSize);
            }
            lstm->setInput(inputStep);

            // invoke() reports problems with fatal assertions that only return
            // from invoke() itself; stop here instead of comparing stale output.
            ASSERT_NO_FATAL_FAILURE(lstm->invoke());

            // Gather the matching golden slice in the same layout.
            std::vector<uint8_t> expected;
            expected.reserve(numBatches * outputSize);
            for (size_t b = 0; b < numBatches; ++b) {
                const uint8_t* goldenBatchStart = output[b].data() + step * outputSize;
                expected.insert(expected.end(), goldenBatchStart, goldenBatchStart + outputSize);
            }
            EXPECT_THAT(lstm->getOutput(), ElementsAreArray(expected));
        }
    }
};
// Inputs and weights in this test are random and the test only checks that the
// outputs are equal to outputs obtained from running TF Lite version of
// quantized LSTM on the same inputs.
TEST_F(QuantizedLstmTest, BasicQuantizedLstmTest) {
    const int numBatches = 2;
    const int inputSize = 2;
    const int outputSize = 4;

    float weightsScale = 0.00408021;
    int weightsZeroPoint = 100;

    // Operand descriptors. Operands that share a type are described once and
    // reused below.
    const OperandTypeParams inputParams(Type::TENSOR_QUANT8_ASYMM, {numBatches, inputSize},
                                        1. / 128., 128);
    const OperandTypeParams inputWeightsParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, inputSize},
                                               weightsScale, weightsZeroPoint);
    const OperandTypeParams recurrentWeightsParams(Type::TENSOR_QUANT8_ASYMM,
                                                   {outputSize, outputSize}, weightsScale,
                                                   weightsZeroPoint);
    const OperandTypeParams biasParams(Type::TENSOR_INT32, {outputSize}, weightsScale / 128., 0);
    const OperandTypeParams prevCellStateParams(Type::TENSOR_QUANT16_SYMM,
                                                {numBatches, outputSize}, 1. / 2048., 0);
    const OperandTypeParams prevOutputParams(Type::TENSOR_QUANT8_ASYMM, {numBatches, outputSize},
                                             1. / 128., 128);

    const std::vector<OperandTypeParams> operandParams = {
            inputParams,             // input
            inputWeightsParams,      // inputToInputWeights
            inputWeightsParams,      // inputToForgetWeights
            inputWeightsParams,      // inputToCellWeights
            inputWeightsParams,      // inputToOutputWeights
            recurrentWeightsParams,  // recurrentToInputWeights
            recurrentWeightsParams,  // recurrentToForgetWeights
            recurrentWeightsParams,  // recurrentToCellWeights
            recurrentWeightsParams,  // recurrentToOutputWeights
            biasParams,              // inputGateBias
            biasParams,              // forgetGateBias
            biasParams,              // cellGateBias
            biasParams,              // outputGateBias
            prevCellStateParams,     // prevCellState
            prevOutputParams,        // prevOutput
    };
    QuantizedLSTMOpModel lstm(operandParams);

    // clang-format off
    lstm.setWeightsAndBiases(
            /* inputToInputWeights */
            {146, 250, 235, 171, 10, 218, 171, 108},
            /* inputToForgetWeights */
            {24, 50, 132, 179, 158, 110, 3, 169},
            /* inputToCellWeights */
            {133, 34, 29, 49, 206, 109, 54, 183},
            /* inputToOutputWeights */
            {195, 187, 11, 99, 109, 10, 218, 48},
            /* recurrentToInputWeights */
            {254, 206, 77, 168, 71, 20, 215, 6, 223, 7, 118, 225, 59, 130, 174, 26},
            /* recurrentToForgetWeights */
            {137, 240, 103, 52, 68, 51, 237, 112, 0, 220, 89, 23, 69, 4, 207, 253},
            /* recurrentToCellWeights */
            {172, 60, 205, 65, 14, 0, 140, 168, 240, 223, 133, 56, 142, 64, 246, 216},
            /* recurrentToOutputWeights */
            {106, 214, 67, 23, 59, 158, 45, 3, 119, 132, 49, 205, 129, 218, 11, 98},
            /* inputGateBias */
            {-7876, 13488, -726, 32839},
            /* forgetGateBias */
            {9206, -46884, -11693, -38724},
            /* cellGateBias */
            {39481, 48624, 48976, -21419},
            /* outputGateBias */
            {-58999, -17050, -41852, -40538});

    // LSTM input is stored as numBatches x (sequenceLength x inputSize) vector.
    const std::vector<std::vector<uint8_t>> lstmInput = {
            {154, 166,
             166, 179,
             141, 141},
            {100, 200,
             50,  150,
             111, 222}};

    // LSTM output is stored as numBatches x (sequenceLength x outputSize) vector.
    const std::vector<std::vector<uint8_t>> lstmGoldenOutput = {
            {136, 150, 140, 115,
             140, 151, 146, 112,
             139, 153, 146, 114},
            {135, 152, 138, 112,
             136, 156, 142, 112,
             141, 154, 146, 108}};
    // clang-format on

    VerifyGoldens(lstmInput, lstmGoldenOutput, &lstm);
}
} // namespace wrapper
} // namespace nn
} // namespace android