/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATIONS_LSTM_H
#define ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATIONS_LSTM_H

#include <tensorflow/lite/kernels/internal/tensor_utils.h>

#include <algorithm>
#include <cmath>
#include <vector>

#include "ActivationFunctor.h"
#include "nnapi/Types.h"

namespace android {
namespace nn {

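// Parameters controlling one LSTM evaluation. The activation and the two clipping thresholds
// come from the operation's scalar inputs; the use_* flags record which optional features
// (CIFG, peephole, layer normalization, projection) the provided operands enable. The
// merge_outputs, time_major, and output_state fields appear to be used only by the
// sequence-based LSTM variants that share this struct.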
struct LSTMParams {
    TfLiteFusedActivation activation;
    float cell_clip;
    float proj_clip;
    bool use_cifg;
    bool use_peephole;
    bool use_layer_norm;
    bool use_projection_weight;
    bool use_projection_bias;
    bool merge_outputs;
    bool time_major;
    bool output_state;
};

struct RunTimeOperandInfo;
struct Shape;

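// Reference CPU implementation of the NNAPI LSTM operation. Typical use: construct with the
// operation and its operand table, call Prepare() to validate the operand shapes and compute
// the scratch/state/output shapes, then call Eval() to run the cell.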
class LSTMCell {
   public:
    LSTMCell(const Operation& operation, RunTimeOperandInfo* operands);

    bool Prepare(const Operation& operation, RunTimeOperandInfo* operands, Shape* scratchShape,
                 Shape* outputStateShape, Shape* cellStateShape, Shape* outputShape);
    bool Eval();

    // Input Tensors of size {n_batch, n_input}
    static constexpr int kInputTensor = 0;

    // Input weight tensors of size: {n_cell, n_input}
    static constexpr int kInputToInputWeightsTensor = 1;  // Optional
    static constexpr int kInputToForgetWeightsTensor = 2;
    static constexpr int kInputToCellWeightsTensor = 3;
    static constexpr int kInputToOutputWeightsTensor = 4;

    // Recurrent weight tensors of size {n_cell, n_output}
    static constexpr int kRecurrentToInputWeightsTensor = 5;  // Optional
    static constexpr int kRecurrentToForgetWeightsTensor = 6;
    static constexpr int kRecurrentToCellWeightsTensor = 7;
    static constexpr int kRecurrentToOutputWeightsTensor = 8;

    // Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
    static constexpr int kCellToInputWeightsTensor = 9;    // Optional
    static constexpr int kCellToForgetWeightsTensor = 10;  // Optional
    static constexpr int kCellToOutputWeightsTensor = 11;  // Optional

    // Gates bias tensors of size {n_cell}
    static constexpr int kInputGateBiasTensor = 12;  // Optional
    static constexpr int kForgetGateBiasTensor = 13;
    static constexpr int kCellGateBiasTensor = 14;
    static constexpr int kOutputGateBiasTensor = 15;

    // Projection weight tensor of size {n_output, n_cell}
    static constexpr int kProjectionWeightsTensor = 16;  // Optional
    // Projection bias tensor of size {n_output}
    static constexpr int kProjectionBiasTensor = 17;  // Optional

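    // State inputs carried over from the previous time step: output (hidden) state of size
    // {n_batch, n_output} and cell state of size {n_batch, n_cell}.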
    static constexpr int kOutputStateInTensor = 18;
    static constexpr int kCellStateInTensor = 19;

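    // Scalar parameters: fused activation function, cell state clipping threshold, and
    // projection clipping threshold (a threshold of 0.0 disables clipping).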
    static constexpr int kActivationParam = 20;
    static constexpr int kCellClipParam = 21;
    static constexpr int kProjClipParam = 22;

    // Layer norm weights tensors of size {n_cell}, representing a diagonal matrix.
    static constexpr int kInputLayerNormWeightsTensor = 23;
    static constexpr int kForgetLayerNormWeightsTensor = 24;
    static constexpr int kCellLayerNormWeightsTensor = 25;
    static constexpr int kOutputLayerNormWeightsTensor = 26;

    // Output tensors.
    static constexpr int kScratchBufferTensor = 0;
    static constexpr int kOutputStateOutTensor = 1;
    static constexpr int kCellStateOutTensor = 2;
    static constexpr int kOutputTensor = 3;

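    // Evaluates the LSTM over float32 buffers, presumably by iterating LSTMStep over the
    // sequence. Only three Shape arguments are passed explicitly; the remaining dimensions
    // (n_batch, n_cell, n_output, max_time) can be derived from them. timeMajor selects the
    // input layout and forwardSequence the traversal direction for the sequence LSTM variants.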
    static bool LSTMEvalFloat32(
            const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
            const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
            const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer,
            const Shape& input_to_output_weights_shape,
            const float* recurrent_to_input_weights_buffer,
            const float* recurrent_to_forget_weights_buffer,
            const float* recurrent_to_cell_weights_buffer,
            const float* recurrent_to_output_weights_buffer,
            const Shape& recurrent_to_output_weights_shape,
            const float* cell_to_input_weights_buffer, const float* cell_to_forget_weights_buffer,
            const float* cell_to_output_weights_buffer, const float* aux_input_buffer,
            const float* aux_input_to_input_weights, const float* aux_input_to_forget_weights,
            const float* aux_input_to_cell_weights, const float* aux_input_to_output_weights,
            const float* input_gate_bias_buffer, const float* forget_gate_bias_buffer,
            const float* cell_bias_buffer, const float* output_gate_bias_buffer,
            const float* projection_weights_buffer, const float* projection_bias_buffer,
            const float* output_state_in_buffer, const float* cell_state_in_buffer,
            const float* input_layer_norm_weights_buffer,
            const float* forget_layer_norm_weights_buffer,
            const float* cell_layer_norm_weights_buffer,
            const float* output_layer_norm_weights_buffer, float* output_state_out_buffer,
            float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer,
            bool timeMajor = true, bool forwardSequence = true);

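    // Same as LSTMEvalFloat32, but with all tensor data in _Float16.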
    static bool LSTMEvalFloat16(
            const LSTMParams& params, const _Float16* input_buffer, const Shape& input_shape,
            const _Float16* input_to_input_weights_buffer,
            const _Float16* input_to_forget_weights_buffer,
            const _Float16* input_to_cell_weights_buffer,
            const _Float16* input_to_output_weights_buffer,
            const Shape& input_to_output_weights_shape,
            const _Float16* recurrent_to_input_weights_buffer,
            const _Float16* recurrent_to_forget_weights_buffer,
            const _Float16* recurrent_to_cell_weights_buffer,
            const _Float16* recurrent_to_output_weights_buffer,
            const Shape& recurrent_to_output_weights_shape,
            const _Float16* cell_to_input_weights_buffer,
            const _Float16* cell_to_forget_weights_buffer,
            const _Float16* cell_to_output_weights_buffer, const _Float16* aux_input_buffer,
            const _Float16* aux_input_to_input_weights, const _Float16* aux_input_to_forget_weights,
            const _Float16* aux_input_to_cell_weights, const _Float16* aux_input_to_output_weights,
            const _Float16* input_gate_bias_buffer, const _Float16* forget_gate_bias_buffer,
            const _Float16* cell_bias_buffer, const _Float16* output_gate_bias_buffer,
            const _Float16* projection_weights_buffer, const _Float16* projection_bias_buffer,
            const _Float16* output_state_in_buffer, const _Float16* cell_state_in_buffer,
            const _Float16* input_layer_norm_weights_buffer,
            const _Float16* forget_layer_norm_weights_buffer,
            const _Float16* cell_layer_norm_weights_buffer,
            const _Float16* output_layer_norm_weights_buffer, _Float16* output_state_out_buffer,
            _Float16* cell_state_out_buffer, _Float16* output_buffer,
            _Float16* scratch_buffer_buffer, bool timeMajor = true, bool forwardSequence = true);

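    // Computes a single LSTM time step over float32 buffers. As a sketch, the standard cell this
    // parameterization describes (exact ordering and optimizations are up to the implementation):
    //   i = sigmoid(W_i x + R_i h_prev + P_i .* c_prev + b_i)   // omitted under CIFG
    //   f = sigmoid(W_f x + R_f h_prev + P_f .* c_prev + b_f)
    //   g = activation(W_c x + R_c h_prev + b_c)
    //   c = clip(f .* c_prev + i .* g, cell_clip)               // under CIFG, i = 1 - f
    //   o = sigmoid(W_o x + R_o h_prev + P_o .* c + b_o)
    //   h = o .* activation(c), optionally projected and clipped by proj_clip
    // where .* is elementwise multiplication; the P_* peephole and layer-norm terms apply only
    // when the corresponding optional weights are provided.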
    static bool LSTMStep(
            const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
            const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
            const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer,
            const Shape& input_to_output_weights_shape,
            const float* recurrent_to_input_weights_buffer,
            const float* recurrent_to_forget_weights_buffer,
            const float* recurrent_to_cell_weights_buffer,
            const float* recurrent_to_output_weights_buffer,
            const Shape& recurrent_to_output_weights_shape,
            const float* cell_to_input_weights_buffer, const float* cell_to_forget_weights_buffer,
            const float* cell_to_output_weights_buffer, const float* aux_input_buffer,
            const float* aux_input_to_input_weights, const float* aux_input_to_forget_weights,
            const float* aux_input_to_cell_weights, const float* aux_input_to_output_weights,
            const float* input_gate_bias_buffer, const float* forget_gate_bias_buffer,
            const float* cell_bias_buffer, const float* output_gate_bias_buffer,
            const float* projection_weights_buffer, const float* projection_bias_buffer,
            const float* output_state_in_buffer, const float* cell_state_in_buffer,
            const float* input_layer_norm_weights_buffer,
            const float* forget_layer_norm_weights_buffer,
            const float* cell_layer_norm_weights_buffer,
            const float* output_layer_norm_weights_buffer, float* output_state_out_buffer,
            float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer);

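    // Checks that the operand dimensions are consistent with each other and with n_input,
    // n_output, and n_cell; given the LSTMParams out-pointer, it presumably also records which
    // optional features are in use.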
    static bool CheckInputTensorDimensions(
            const RunTimeOperandInfo* input_, const RunTimeOperandInfo* input_to_input_weights,
            const RunTimeOperandInfo* input_to_forget_weights,
            const RunTimeOperandInfo* input_to_cell_weights,
            const RunTimeOperandInfo* input_to_output_weights,
            const RunTimeOperandInfo* recurrent_to_input_weights,
            const RunTimeOperandInfo* recurrent_to_forget_weights,
            const RunTimeOperandInfo* recurrent_to_cell_weights,
            const RunTimeOperandInfo* recurrent_to_output_weights,
            const RunTimeOperandInfo* cell_to_input_weights,
            const RunTimeOperandInfo* cell_to_forget_weights,
            const RunTimeOperandInfo* cell_to_output_weights,
            const RunTimeOperandInfo* input_gate_bias, const RunTimeOperandInfo* forget_gate_bias,
            const RunTimeOperandInfo* cell_bias, const RunTimeOperandInfo* output_gate_bias,
            const RunTimeOperandInfo* projection_weights, const RunTimeOperandInfo* projection_bias,
            const RunTimeOperandInfo* input_layer_norm_weights,
            const RunTimeOperandInfo* forget_layer_norm_weights,
            const RunTimeOperandInfo* cell_layer_norm_weights,
            const RunTimeOperandInfo* output_layer_norm_weights, uint32_t n_input,
            uint32_t n_output, uint32_t n_cell, LSTMParams* params);

   private:
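    // Operand table entries resolved at construction time; they mirror the input and output
    // tensor indices defined above.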
    LSTMParams params_;
    const RunTimeOperandInfo* input_;

    const RunTimeOperandInfo* input_to_input_weights_;
    const RunTimeOperandInfo* input_to_forget_weights_;
    const RunTimeOperandInfo* input_to_cell_weights_;
    const RunTimeOperandInfo* input_to_output_weights_;

    const RunTimeOperandInfo* recurrent_to_input_weights_;
    const RunTimeOperandInfo* recurrent_to_forget_weights_;
    const RunTimeOperandInfo* recurrent_to_cell_weights_;
    const RunTimeOperandInfo* recurrent_to_output_weights_;

    const RunTimeOperandInfo* cell_to_input_weights_;
    const RunTimeOperandInfo* cell_to_forget_weights_;
    const RunTimeOperandInfo* cell_to_output_weights_;

    const RunTimeOperandInfo* input_gate_bias_;
    const RunTimeOperandInfo* forget_gate_bias_;
    const RunTimeOperandInfo* cell_bias_;
    const RunTimeOperandInfo* output_gate_bias_;

    const RunTimeOperandInfo* projection_weights_;
    const RunTimeOperandInfo* projection_bias_;

    const RunTimeOperandInfo* output_state_in_;
    const RunTimeOperandInfo* cell_state_in_;

    const RunTimeOperandInfo* input_layer_norm_weights_;
    const RunTimeOperandInfo* forget_layer_norm_weights_;
    const RunTimeOperandInfo* cell_layer_norm_weights_;
    const RunTimeOperandInfo* output_layer_norm_weights_;

    RunTimeOperandInfo* output_state_out_;
    RunTimeOperandInfo* cell_state_out_;
    RunTimeOperandInfo* output_;

    RunTimeOperandInfo* scratch_buffer_;
};

} // namespace nn
} // namespace android

#endif // ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATIONS_LSTM_H