// NOTE: removed repository-browser scrape artifacts that preceded the license
// header (topic-selection help text and the "428 lines / 16 KiB" file stats).
/*
|
|
* Copyright (C) 2017 The Android Open Source Project
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
#include <android-base/logging.h>
|
|
#include <gmock/gmock.h>
|
|
#include <gtest/gtest.h>
|
|
|
|
#include <sstream>
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
#include "LSTM.h"
|
|
#include "NeuralNetworksWrapper.h"
|
|
|
|
namespace android {
|
|
namespace nn {
|
|
namespace wrapper {
|
|
|
|
// gmock helpers used by the matchers/assertions below. ElementsAreArray is
// used in the tests; spell it out explicitly rather than relying on
// argument-dependent lookup through ::testing::Matcher.
using ::testing::Each;
using ::testing::ElementsAreArray;
using ::testing::FloatNear;
using ::testing::Matcher;
|
|
|
|
namespace {
|
|
|
|
std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
|
|
float max_abs_error = 1.e-6) {
|
|
std::vector<Matcher<float>> matchers;
|
|
matchers.reserve(values.size());
|
|
for (const float& v : values) {
|
|
matchers.emplace_back(FloatNear(v, max_abs_error));
|
|
}
|
|
return matchers;
|
|
}
|
|
|
|
} // anonymous namespace
|
|
|
|
// X-macro lists: each macro applies ACTION to one group of LSTM operand
// names, in the exact order the test model adds (and positionally wires)
// them. NOTE(review): within the weight/bias groups the Cell/Forget entries
// appear swapped relative to the canonical NNAPI LSTM operand order; because
// the paired shapes are identical the index-based wiring in Invoke() still
// routes every buffer to the right operand, but the names here are misleading
// — confirm against the LSTMCell::k*Tensor constants in LSTM.h.
#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
    ACTION(Input)                                \
    ACTION(InputToInputWeights)                  \
    ACTION(InputToCellWeights)                   \
    ACTION(InputToForgetWeights)                 \
    ACTION(InputToOutputWeights)                 \
    ACTION(RecurrentToInputWeights)              \
    ACTION(RecurrentToCellWeights)               \
    ACTION(RecurrentToForgetWeights)             \
    ACTION(RecurrentToOutputWeights)             \
    ACTION(CellToInputWeights)                   \
    ACTION(CellToForgetWeights)                  \
    ACTION(CellToOutputWeights)                  \
    ACTION(InputGateBias)                        \
    ACTION(CellGateBias)                         \
    ACTION(ForgetGateBias)                       \
    ACTION(OutputGateBias)                       \
    ACTION(ProjectionWeights)                    \
    ACTION(ProjectionBias)                       \
    ACTION(OutputStateIn)                        \
    ACTION(CellStateIn)

// Layer-normalization weight tensors, appended after the scalar parameters.
#define FOR_ALL_LAYER_NORM_WEIGHTS(ACTION) \
    ACTION(InputLayerNormWeights)          \
    ACTION(ForgetLayerNormWeights)         \
    ACTION(CellLayerNormWeights)           \
    ACTION(OutputLayerNormWeights)

// For all output and intermediate states
#define FOR_ALL_OUTPUT_TENSORS(ACTION) \
    ACTION(ScratchBuffer)              \
    ACTION(OutputStateOut)             \
    ACTION(CellStateOut)               \
    ACTION(Output)
|
|
|
|
class LayerNormLSTMOpModel {
|
|
public:
|
|
LayerNormLSTMOpModel(uint32_t n_batch, uint32_t n_input, uint32_t n_cell, uint32_t n_output,
|
|
bool use_cifg, bool use_peephole, bool use_projection_weights,
|
|
bool use_projection_bias, float cell_clip, float proj_clip,
|
|
const std::vector<std::vector<uint32_t>>& input_shapes0)
|
|
: n_input_(n_input),
|
|
n_output_(n_output),
|
|
use_cifg_(use_cifg),
|
|
use_peephole_(use_peephole),
|
|
use_projection_weights_(use_projection_weights),
|
|
use_projection_bias_(use_projection_bias),
|
|
activation_(ActivationFn::kActivationTanh),
|
|
cell_clip_(cell_clip),
|
|
proj_clip_(proj_clip) {
|
|
std::vector<uint32_t> inputs;
|
|
std::vector<std::vector<uint32_t>> input_shapes(input_shapes0);
|
|
|
|
auto it = input_shapes.begin();
|
|
|
|
// Input and weights
|
|
#define AddInput(X) \
|
|
CHECK(it != input_shapes.end()); \
|
|
OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it++); \
|
|
inputs.push_back(model_.addOperand(&X##OpndTy));
|
|
|
|
FOR_ALL_INPUT_AND_WEIGHT_TENSORS(AddInput);
|
|
|
|
// Parameters
|
|
OperandType ActivationOpndTy(Type::INT32, {});
|
|
inputs.push_back(model_.addOperand(&ActivationOpndTy));
|
|
OperandType CellClipOpndTy(Type::FLOAT32, {});
|
|
inputs.push_back(model_.addOperand(&CellClipOpndTy));
|
|
OperandType ProjClipOpndTy(Type::FLOAT32, {});
|
|
inputs.push_back(model_.addOperand(&ProjClipOpndTy));
|
|
|
|
FOR_ALL_LAYER_NORM_WEIGHTS(AddInput);
|
|
|
|
#undef AddOperand
|
|
|
|
// Output and other intermediate state
|
|
std::vector<std::vector<uint32_t>> output_shapes{
|
|
{n_batch, n_cell * (use_cifg ? 3 : 4)},
|
|
{n_batch, n_output},
|
|
{n_batch, n_cell},
|
|
{n_batch, n_output},
|
|
};
|
|
std::vector<uint32_t> outputs;
|
|
|
|
auto it2 = output_shapes.begin();
|
|
|
|
#define AddOutput(X) \
|
|
CHECK(it2 != output_shapes.end()); \
|
|
OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it2++); \
|
|
outputs.push_back(model_.addOperand(&X##OpndTy));
|
|
|
|
FOR_ALL_OUTPUT_TENSORS(AddOutput);
|
|
|
|
#undef AddOutput
|
|
|
|
model_.addOperation(ANEURALNETWORKS_LSTM, inputs, outputs);
|
|
model_.identifyInputsAndOutputs(inputs, outputs);
|
|
|
|
Input_.insert(Input_.end(), n_batch * n_input, 0.f);
|
|
OutputStateIn_.insert(OutputStateIn_.end(), n_batch * n_output, 0.f);
|
|
CellStateIn_.insert(CellStateIn_.end(), n_batch * n_cell, 0.f);
|
|
|
|
auto multiAll = [](const std::vector<uint32_t>& dims) -> uint32_t {
|
|
uint32_t sz = 1;
|
|
for (uint32_t d : dims) {
|
|
sz *= d;
|
|
}
|
|
return sz;
|
|
};
|
|
|
|
it2 = output_shapes.begin();
|
|
|
|
#define ReserveOutput(X) X##_.insert(X##_.end(), multiAll(*it2++), 0.f);
|
|
|
|
FOR_ALL_OUTPUT_TENSORS(ReserveOutput);
|
|
|
|
#undef ReserveOutput
|
|
|
|
model_.finish();
|
|
}
|
|
|
|
#define DefineSetter(X) \
|
|
void Set##X(const std::vector<float>& f) { X##_.insert(X##_.end(), f.begin(), f.end()); }
|
|
|
|
FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
|
|
FOR_ALL_LAYER_NORM_WEIGHTS(DefineSetter);
|
|
|
|
#undef DefineSetter
|
|
|
|
void ResetOutputState() {
|
|
std::fill(OutputStateIn_.begin(), OutputStateIn_.end(), 0.f);
|
|
std::fill(OutputStateOut_.begin(), OutputStateOut_.end(), 0.f);
|
|
}
|
|
|
|
void ResetCellState() {
|
|
std::fill(CellStateIn_.begin(), CellStateIn_.end(), 0.f);
|
|
std::fill(CellStateOut_.begin(), CellStateOut_.end(), 0.f);
|
|
}
|
|
|
|
void SetInput(int offset, const float* begin, const float* end) {
|
|
for (; begin != end; begin++, offset++) {
|
|
Input_[offset] = *begin;
|
|
}
|
|
}
|
|
|
|
uint32_t num_inputs() const { return n_input_; }
|
|
uint32_t num_outputs() const { return n_output_; }
|
|
|
|
const std::vector<float>& GetOutput() const { return Output_; }
|
|
|
|
void Invoke() {
|
|
ASSERT_TRUE(model_.isValid());
|
|
|
|
OutputStateIn_.swap(OutputStateOut_);
|
|
CellStateIn_.swap(CellStateOut_);
|
|
|
|
Compilation compilation(&model_);
|
|
compilation.finish();
|
|
Execution execution(&compilation);
|
|
#define SetInputOrWeight(X) \
|
|
ASSERT_EQ( \
|
|
execution.setInput(LSTMCell::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \
|
|
Result::NO_ERROR);
|
|
|
|
FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
|
|
FOR_ALL_LAYER_NORM_WEIGHTS(SetInputOrWeight);
|
|
|
|
#undef SetInputOrWeight
|
|
|
|
#define SetOutput(X) \
|
|
ASSERT_EQ( \
|
|
execution.setOutput(LSTMCell::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \
|
|
Result::NO_ERROR);
|
|
|
|
FOR_ALL_OUTPUT_TENSORS(SetOutput);
|
|
|
|
#undef SetOutput
|
|
|
|
if (use_cifg_) {
|
|
execution.setInput(LSTMCell::kInputToInputWeightsTensor, nullptr, 0);
|
|
execution.setInput(LSTMCell::kRecurrentToInputWeightsTensor, nullptr, 0);
|
|
}
|
|
|
|
if (use_peephole_) {
|
|
if (use_cifg_) {
|
|
execution.setInput(LSTMCell::kCellToInputWeightsTensor, nullptr, 0);
|
|
}
|
|
} else {
|
|
execution.setInput(LSTMCell::kCellToInputWeightsTensor, nullptr, 0);
|
|
execution.setInput(LSTMCell::kCellToForgetWeightsTensor, nullptr, 0);
|
|
execution.setInput(LSTMCell::kCellToOutputWeightsTensor, nullptr, 0);
|
|
}
|
|
|
|
if (use_projection_weights_) {
|
|
if (!use_projection_bias_) {
|
|
execution.setInput(LSTMCell::kProjectionBiasTensor, nullptr, 0);
|
|
}
|
|
} else {
|
|
execution.setInput(LSTMCell::kProjectionWeightsTensor, nullptr, 0);
|
|
execution.setInput(LSTMCell::kProjectionBiasTensor, nullptr, 0);
|
|
}
|
|
|
|
ASSERT_EQ(execution.setInput(LSTMCell::kActivationParam, &activation_, sizeof(activation_)),
|
|
Result::NO_ERROR);
|
|
ASSERT_EQ(execution.setInput(LSTMCell::kCellClipParam, &cell_clip_, sizeof(cell_clip_)),
|
|
Result::NO_ERROR);
|
|
ASSERT_EQ(execution.setInput(LSTMCell::kProjClipParam, &proj_clip_, sizeof(proj_clip_)),
|
|
Result::NO_ERROR);
|
|
|
|
ASSERT_EQ(execution.compute(), Result::NO_ERROR);
|
|
}
|
|
|
|
private:
|
|
Model model_;
|
|
// Execution execution_;
|
|
const uint32_t n_input_;
|
|
const uint32_t n_output_;
|
|
|
|
const bool use_cifg_;
|
|
const bool use_peephole_;
|
|
const bool use_projection_weights_;
|
|
const bool use_projection_bias_;
|
|
|
|
const int activation_;
|
|
const float cell_clip_;
|
|
const float proj_clip_;
|
|
|
|
#define DefineTensor(X) std::vector<float> X##_;
|
|
|
|
FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
|
|
FOR_ALL_LAYER_NORM_WEIGHTS(DefineTensor);
|
|
FOR_ALL_OUTPUT_TENSORS(DefineTensor);
|
|
|
|
#undef DefineTensor
|
|
};
|
|
|
|
// Runs a 3-timestep, 2-batch layer-norm LSTM (peephole + projection, no CIFG,
// no clipping) and checks each step's output against golden values.
TEST(LSTMOpTest, LayerNormNoCifgPeepholeProjectionNoClipping) {
    const int n_batch = 2;
    const int n_input = 5;
    // n_cell and n_output have the same size when there is no projection.
    const int n_cell = 4;
    const int n_output = 3;

    // The shape list is consumed positionally in FOR_ALL_INPUT_AND_WEIGHT_TENSORS
    // order (then FOR_ALL_LAYER_NORM_WEIGHTS order); the comments below follow
    // that order. All shapes within each group are identical, so the data set
    // via the Set* calls still reaches the correct operand indices.
    LayerNormLSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
                              /*use_cifg=*/false, /*use_peephole=*/true,
                              /*use_projection_weights=*/true,
                              /*use_projection_bias=*/false,
                              /*cell_clip=*/0.0, /*proj_clip=*/0.0,
                              {
                                      {n_batch, n_input},  // input tensor

                                      {n_cell, n_input},  // input_to_input_weight tensor
                                      {n_cell, n_input},  // input_to_cell_weight tensor
                                      {n_cell, n_input},  // input_to_forget_weight tensor
                                      {n_cell, n_input},  // input_to_output_weight tensor

                                      {n_cell, n_output},  // recurrent_to_input_weight tensor
                                      {n_cell, n_output},  // recurrent_to_cell_weight tensor
                                      {n_cell, n_output},  // recurrent_to_forget_weight tensor
                                      {n_cell, n_output},  // recurrent_to_output_weight tensor

                                      {n_cell},  // cell_to_input_weight tensor
                                      {n_cell},  // cell_to_forget_weight tensor
                                      {n_cell},  // cell_to_output_weight tensor

                                      {n_cell},  // input_gate_bias tensor
                                      {n_cell},  // cell_gate_bias tensor
                                      {n_cell},  // forget_gate_bias tensor
                                      {n_cell},  // output_gate_bias tensor

                                      {n_output, n_cell},  // projection_weight tensor
                                      {0},                 // projection_bias tensor

                                      {n_batch, n_output},  // output_state_in tensor
                                      {n_batch, n_cell},    // cell_state_in tensor

                                      {n_cell},  // input_layer_norm_weights tensor
                                      {n_cell},  // forget_layer_norm_weights tensor
                                      {n_cell},  // cell_layer_norm_weights tensor
                                      {n_cell},  // output_layer_norm_weights tensor
                              });

    lstm.SetInputToInputWeights({0.5, 0.6, 0.7, -0.8, -0.9, 0.1, 0.2, 0.3, -0.4, 0.5,
                                 -0.8, 0.7, -0.6, 0.5, -0.4, -0.5, -0.4, -0.3, -0.2, -0.1});

    lstm.SetInputToForgetWeights({-0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2, -0.4, 0.3, -0.8,
                                  -0.4, 0.3, -0.5, -0.4, -0.6, 0.3, -0.4, -0.6, -0.5, -0.5});

    lstm.SetInputToCellWeights({-0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2, -0.3, -0.2, -0.6,
                                0.6, -0.1, -0.4, -0.3, -0.7, 0.7, -0.9, -0.5, 0.8, 0.6});

    lstm.SetInputToOutputWeights({-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3, -0.3, -0.8, -0.2,
                                  0.6, -0.2, 0.4, -0.7, -0.3, -0.5, 0.1, 0.5, -0.6, -0.4});

    lstm.SetInputGateBias({0.03, 0.15, 0.22, 0.38});

    lstm.SetForgetGateBias({0.1, -0.3, -0.2, 0.1});

    lstm.SetCellGateBias({-0.05, 0.72, 0.25, 0.08});

    lstm.SetOutputGateBias({0.05, -0.01, 0.2, 0.1});

    lstm.SetRecurrentToInputWeights(
            {-0.2, -0.3, 0.4, 0.1, -0.5, 0.9, -0.2, -0.3, -0.7, 0.05, -0.2, -0.6});

    lstm.SetRecurrentToCellWeights(
            {-0.3, 0.2, 0.1, -0.3, 0.8, -0.08, -0.2, 0.3, 0.8, -0.6, -0.1, 0.2});

    lstm.SetRecurrentToForgetWeights(
            {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4, 0.9, 0.3, -0.1, 0.2, 0.5, 0.2});

    lstm.SetRecurrentToOutputWeights(
            {0.3, -0.1, 0.1, -0.2, -0.5, -0.7, -0.2, -0.6, -0.1, -0.4, -0.7, -0.2});

    lstm.SetCellToInputWeights({0.05, 0.1, 0.25, 0.15});
    lstm.SetCellToForgetWeights({-0.02, -0.15, -0.25, -0.03});
    lstm.SetCellToOutputWeights({0.1, -0.1, -0.5, 0.05});

    lstm.SetProjectionWeights({-0.1, 0.2, 0.01, -0.2, 0.1, 0.5, 0.3, 0.08, 0.07, 0.2, -0.4, 0.2});

    lstm.SetInputLayerNormWeights({0.1, 0.2, 0.3, 0.5});
    lstm.SetForgetLayerNormWeights({0.2, 0.2, 0.4, 0.3});
    lstm.SetCellLayerNormWeights({0.7, 0.2, 0.3, 0.8});
    lstm.SetOutputLayerNormWeights({0.6, 0.2, 0.2, 0.5});

    const std::vector<std::vector<float>> lstm_input = {
            {// Batch0: 3 (input_sequence_size) * 5 (n_input)
             0.7, 0.8, 0.1, 0.2, 0.3,   // seq 0
             0.8, 0.1, 0.2, 0.4, 0.5,   // seq 1
             0.2, 0.7, 0.7, 0.1, 0.7},  // seq 2

            {// Batch1: 3 (input_sequence_size) * 5 (n_input)
             0.3, 0.2, 0.9, 0.8, 0.1,   // seq 0
             0.1, 0.5, 0.2, 0.4, 0.2,   // seq 1
             0.6, 0.9, 0.2, 0.5, 0.7},  // seq 2
    };

    const std::vector<std::vector<float>> lstm_golden_output = {
            {
                    // Batch0: 3 (input_sequence_size) * 3 (n_output)
                    0.0244077, 0.128027, -0.00170918,   // seq 0
                    0.0137642, 0.140751, 0.0395835,     // seq 1
                    -0.00459231, 0.155278, 0.0837377,   // seq 2
            },
            {
                    // Batch1: 3 (input_sequence_size) * 3 (n_output)
                    -0.00692428, 0.0848741, 0.063445,   // seq 0
                    -0.00403912, 0.139963, 0.072681,    // seq 1
                    0.00752706, 0.161903, 0.0561371,    // seq 2
            }};

    // Resetting cell_state and output_state
    lstm.ResetCellState();
    lstm.ResetOutputState();

    const int input_sequence_size = lstm_input[0].size() / n_input;
    for (int i = 0; i < input_sequence_size; i++) {
        // Feed timestep i of every batch, run one LSTM step, then compare the
        // whole output buffer against the golden slice for this step.
        for (int b = 0; b < n_batch; ++b) {
            const float* batch_start = lstm_input[b].data() + i * n_input;
            const float* batch_end = batch_start + n_input;

            lstm.SetInput(b * n_input, batch_start, batch_end);
        }

        lstm.Invoke();

        std::vector<float> expected;
        for (int b = 0; b < n_batch; ++b) {
            const float* golden_start = lstm_golden_output[b].data() + i * n_output;
            const float* golden_end = golden_start + n_output;
            expected.insert(expected.end(), golden_start, golden_end);
        }
        EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
    }
}
|
|
|
|
} // namespace wrapper
|
|
} // namespace nn
|
|
} // namespace android
|