/* * Copyright (C) 2017 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATIONS_LSTM_H #define ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATIONS_LSTM_H #include #include #include #include #include "ActivationFunctor.h" #include "nnapi/Types.h" namespace android { namespace nn { struct LSTMParams { TfLiteFusedActivation activation; float cell_clip; float proj_clip; bool use_cifg; bool use_peephole; bool use_layer_norm; bool use_projection_weight; bool use_projection_bias; bool merge_outputs; bool time_major; bool output_state; }; struct RunTimeOperandInfo; struct Shape; class LSTMCell { public: LSTMCell(const Operation& operation, RunTimeOperandInfo* operands); bool Prepare(const Operation& operation, RunTimeOperandInfo* operands, Shape* scratchShape, Shape* outputStateShape, Shape* cellStateShape, Shape* outputShape); bool Eval(); // Input Tensors of size {n_batch, n_input} static constexpr int kInputTensor = 0; // Input weight tensors of size: {n_cell, n_input} static constexpr int kInputToInputWeightsTensor = 1; // Optional static constexpr int kInputToForgetWeightsTensor = 2; static constexpr int kInputToCellWeightsTensor = 3; static constexpr int kInputToOutputWeightsTensor = 4; // Recurrent weight tensors of size {n_cell, n_output} static constexpr int kRecurrentToInputWeightsTensor = 5; // Optional static constexpr int kRecurrentToForgetWeightsTensor = 6; static constexpr int kRecurrentToCellWeightsTensor = 7; static constexpr int kRecurrentToOutputWeightsTensor = 8; // Peephole weights tensors of size {n_cell}, representing a diagonal matrix. static constexpr int kCellToInputWeightsTensor = 9; // Optional static constexpr int kCellToForgetWeightsTensor = 10; // Optional static constexpr int kCellToOutputWeightsTensor = 11; // Optional // Gates bias tensors of size {n_cell} static constexpr int kInputGateBiasTensor = 12; // Optional static constexpr int kForgetGateBiasTensor = 13; static constexpr int kCellGateBiasTensor = 14; static constexpr int kOutputGateBiasTensor = 15; // Projection weight tensor of size {n_output, n_cell} static constexpr int kProjectionWeightsTensor = 16; // Optional // Projection bias tensor of size {n_output} static constexpr int kProjectionBiasTensor = 17; // Optional static constexpr int kOutputStateInTensor = 18; static constexpr int kCellStateInTensor = 19; static constexpr int kActivationParam = 20; static constexpr int kCellClipParam = 21; static constexpr int kProjClipParam = 22; // Layer norm weights tensors of size {n_cell}, representing a diagonal matrix. static constexpr int kInputLayerNormWeightsTensor = 23; static constexpr int kForgetLayerNormWeightsTensor = 24; static constexpr int kCellLayerNormWeightsTensor = 25; static constexpr int kOutputLayerNormWeightsTensor = 26; // Output tensors. static constexpr int kScratchBufferTensor = 0; static constexpr int kOutputStateOutTensor = 1; static constexpr int kCellStateOutTensor = 2; static constexpr int kOutputTensor = 3; static bool LSTMEvalFloat32( const LSTMParams& params, const float* input_buffer, const Shape& input_shape, const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer, const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer, const Shape& input_to_output_weights_shape, const float* recurrent_to_input_weights_buffer, const float* recurrent_to_forget_weights_buffer, const float* recurrent_to_cell_weights_buffer, const float* recurrent_to_output_weights_buffer, const Shape& recurrent_to_output_weights_shape, const float* cell_to_input_weights_buffer, const float* cell_to_forget_weights_buffer, const float* cell_to_output_weights_buffer, const float* aux_input_buffer, const float* aux_input_to_input_weights, const float* aux_input_to_forget_weights, const float* aux_input_to_cell_weights, const float* aux_input_to_output_weights, const float* input_gate_bias_buffer, const float* forget_gate_bias_buffer, const float* cell_bias_buffer, const float* output_gate_bias_buffer, const float* projection_weights_buffer, const float* projection_bias_buffer, const float* output_state_in_buffer, const float* cell_state_in_buffer, const float* input_layer_norm_weights_buffer, const float* forget_layer_norm_weights_buffer, const float* cell_layer_norm_weights_buffer, const float* output_layer_norm_weights_buffer, float* output_state_out_buffer, float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer, bool timeMajor = true, bool forwardSequence = true); static bool LSTMEvalFloat16( const LSTMParams& params, const _Float16* input_buffer, const Shape& input_shape, const _Float16* input_to_input_weights_buffer, const _Float16* input_to_forget_weights_buffer, const _Float16* input_to_cell_weights_buffer, const _Float16* input_to_output_weights_buffer, const Shape& input_to_output_weights_shape, const _Float16* recurrent_to_input_weights_buffer, const _Float16* recurrent_to_forget_weights_buffer, const _Float16* recurrent_to_cell_weights_buffer, const _Float16* recurrent_to_output_weights_buffer, const Shape& recurrent_to_output_weights_shape, const _Float16* cell_to_input_weights_buffer, const _Float16* cell_to_forget_weights_buffer, const _Float16* cell_to_output_weights_buffer, const _Float16* aux_input_buffer, const _Float16* aux_input_to_input_weights, const _Float16* aux_input_to_forget_weights, const _Float16* aux_input_to_cell_weights, const _Float16* aux_input_to_output_weights, const _Float16* input_gate_bias_buffer, const _Float16* forget_gate_bias_buffer, const _Float16* cell_bias_buffer, const _Float16* output_gate_bias_buffer, const _Float16* projection_weights_buffer, const _Float16* projection_bias_buffer, const _Float16* output_state_in_buffer, const _Float16* cell_state_in_buffer, const _Float16* input_layer_norm_weights_buffer, const _Float16* forget_layer_norm_weights_buffer, const _Float16* cell_layer_norm_weights_buffer, const _Float16* output_layer_norm_weights_buffer, _Float16* output_state_out_buffer, _Float16* cell_state_out_buffer, _Float16* output_buffer, _Float16* scratch_buffer_buffer, bool timeMajor = true, bool forwardSequence = true); static bool LSTMStep( const LSTMParams& params, const float* input_buffer, const Shape& input_shape, const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer, const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer, const Shape& input_to_output_weights_shape, const float* recurrent_to_input_weights_buffer, const float* recurrent_to_forget_weights_buffer, const float* recurrent_to_cell_weights_buffer, const float* recurrent_to_output_weights_buffer, const Shape& recurrent_to_output_weights_shape, const float* cell_to_input_weights_buffer, const float* cell_to_forget_weights_buffer, const float* cell_to_output_weights_buffer, const float* aux_input_buffer, const float* aux_input_to_input_weights, const float* aux_input_to_forget_weights, const float* aux_input_to_cell_weights, const float* aux_input_to_output_weights, const float* input_gate_bias_buffer, const float* forget_gate_bias_buffer, const float* cell_bias_buffer, const float* output_gate_bias_buffer, const float* projection_weights_buffer, const float* projection_bias_buffer, const float* output_state_in_buffer, const float* cell_state_in_buffer, const float* input_layer_norm_weights_buffer, const float* forget_layer_norm_weights_buffer, const float* cell_layer_norm_weights_buffer, const float* output_layer_norm_weights_buffer, float* output_state_out_buffer, float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer); static bool CheckInputTensorDimensions( const RunTimeOperandInfo* input_, const RunTimeOperandInfo* input_to_input_weights, const RunTimeOperandInfo* input_to_forget_weights, const RunTimeOperandInfo* input_to_cell_weights, const RunTimeOperandInfo* input_to_output_weights, const RunTimeOperandInfo* recurrent_to_input_weights, const RunTimeOperandInfo* recurrent_to_forget_weights, const RunTimeOperandInfo* recurrent_to_cell_weights, const RunTimeOperandInfo* recurrent_to_output_weights, const RunTimeOperandInfo* cell_to_input_weights, const RunTimeOperandInfo* cell_to_forget_weights, const RunTimeOperandInfo* cell_to_output_weights, const RunTimeOperandInfo* input_gate_bias, const RunTimeOperandInfo* forget_gate_bias, const RunTimeOperandInfo* cell_bias, const RunTimeOperandInfo* output_gate_bias, const RunTimeOperandInfo* projection_weights, const RunTimeOperandInfo* projection_bias, const RunTimeOperandInfo* input_layer_norm_weights, const RunTimeOperandInfo* forget_layer_norm_weights, const RunTimeOperandInfo* cell_layer_norm_weights, const RunTimeOperandInfo* output_layer_norm_weights, uint32_t n_input, uint32_t n_output, uint32_t n_cell, LSTMParams* params); private: LSTMParams params_; const RunTimeOperandInfo* input_; const RunTimeOperandInfo* input_to_input_weights_; const RunTimeOperandInfo* input_to_forget_weights_; const RunTimeOperandInfo* input_to_cell_weights_; const RunTimeOperandInfo* input_to_output_weights_; const RunTimeOperandInfo* recurrent_to_input_weights_; const RunTimeOperandInfo* recurrent_to_forget_weights_; const RunTimeOperandInfo* recurrent_to_cell_weights_; const RunTimeOperandInfo* recurrent_to_output_weights_; const RunTimeOperandInfo* cell_to_input_weights_; const RunTimeOperandInfo* cell_to_forget_weights_; const RunTimeOperandInfo* cell_to_output_weights_; const RunTimeOperandInfo* input_gate_bias_; const RunTimeOperandInfo* forget_gate_bias_; const RunTimeOperandInfo* cell_bias_; const RunTimeOperandInfo* output_gate_bias_; const RunTimeOperandInfo* projection_weights_; const RunTimeOperandInfo* projection_bias_; const RunTimeOperandInfo* output_state_in_; const RunTimeOperandInfo* cell_state_in_; const RunTimeOperandInfo* input_layer_norm_weights_; const RunTimeOperandInfo* forget_layer_norm_weights_; const RunTimeOperandInfo* cell_layer_norm_weights_; const RunTimeOperandInfo* output_layer_norm_weights_; RunTimeOperandInfo* output_state_out_; RunTimeOperandInfo* cell_state_out_; RunTimeOperandInfo* output_; RunTimeOperandInfo* scratch_buffer_; }; } // namespace nn } // namespace android #endif // ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATIONS_LSTM_H