/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATIONS_BIDIRECTIONAL_SEQUENCE_LSTM_H
#define ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATIONS_BIDIRECTIONAL_SEQUENCE_LSTM_H

#include <algorithm>
#include <cmath>
#include <vector>

#include "ActivationFunctor.h"
#include "LSTM.h"
#include "OperationsUtils.h"

namespace android {
namespace nn {

struct RunTimeOperandInfo;

class BidirectionalSequenceLSTM {
   public:
    BidirectionalSequenceLSTM(const Operation& operation, RunTimeOperandInfo* operands);

    bool Prepare(const Operation& operation, RunTimeOperandInfo* operands, Shape* fwOutputShape,
                 Shape* bwOutputShape, Shape* fwOutputActivationState, Shape* fwOutputCellState,
                 Shape* bwOutputActivationState, Shape* bwOutputCellState);
    bool Eval();
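
    // A minimal usage sketch inferred from the declarations above (the
    // authoritative call sites live in the NNAPI CPU executor, not in this
    // header); `operation` and `operands` are assumed to come from the model
    // being executed:
    //
    //   BidirectionalSequenceLSTM lstm(operation, operands);
    //   Shape fwOut, bwOut, fwActState, fwCellState, bwActState, bwCellState;
    //   if (lstm.Prepare(operation, operands, &fwOut, &bwOut, &fwActState,
    //                    &fwCellState, &bwActState, &bwCellState)) {
    //       lstm.Eval();  // run the forward and backward passes
    //   }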

    // Input Tensors of size {max_time, n_batch, n_input}
    static constexpr int kInputTensor = 0;

    // Forward LSTM cell tensors.
    // Input weight tensors of size: {n_cell, n_input}
    static constexpr int kFwInputToInputWeightsTensor = 1;  // Optional
    static constexpr int kFwInputToForgetWeightsTensor = 2;
    static constexpr int kFwInputToCellWeightsTensor = 3;
    static constexpr int kFwInputToOutputWeightsTensor = 4;

    // Recurrent weight tensors of size {n_cell, n_output}
    static constexpr int kFwRecurrentToInputWeightsTensor = 5;  // Optional
    static constexpr int kFwRecurrentToForgetWeightsTensor = 6;
    static constexpr int kFwRecurrentToCellWeightsTensor = 7;
    static constexpr int kFwRecurrentToOutputWeightsTensor = 8;

    // Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
    static constexpr int kFwCellToInputWeightsTensor = 9;    // Optional
    static constexpr int kFwCellToForgetWeightsTensor = 10;  // Optional
    static constexpr int kFwCellToOutputWeightsTensor = 11;  // Optional

    // Gates bias tensors of size {n_cell}
    static constexpr int kFwInputGateBiasTensor = 12;  // Optional
    static constexpr int kFwForgetGateBiasTensor = 13;
    static constexpr int kFwCellGateBiasTensor = 14;
    static constexpr int kFwOutputGateBiasTensor = 15;

    // Projection weight tensor of size {n_output, n_cell}
    static constexpr int kFwProjectionWeightsTensor = 16;  // Optional
    // Projection bias tensor of size {n_output}
    static constexpr int kFwProjectionBiasTensor = 17;  // Optional

    // Backward LSTM cell tensors.
    // Input weight tensors of size: {n_cell, n_input}
    static constexpr int kBwInputToInputWeightsTensor = 18;  // Optional
    static constexpr int kBwInputToForgetWeightsTensor = 19;
    static constexpr int kBwInputToCellWeightsTensor = 20;
    static constexpr int kBwInputToOutputWeightsTensor = 21;

    // Recurrent weight tensors of size {n_cell, n_output}
    static constexpr int kBwRecurrentToInputWeightsTensor = 22;  // Optional
    static constexpr int kBwRecurrentToForgetWeightsTensor = 23;
    static constexpr int kBwRecurrentToCellWeightsTensor = 24;
    static constexpr int kBwRecurrentToOutputWeightsTensor = 25;

    // Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
    static constexpr int kBwCellToInputWeightsTensor = 26;   // Optional
    static constexpr int kBwCellToForgetWeightsTensor = 27;  // Optional
    static constexpr int kBwCellToOutputWeightsTensor = 28;  // Optional

    // Gates bias tensors of size {n_cell}
    static constexpr int kBwInputGateBiasTensor = 29;  // Optional
    static constexpr int kBwForgetGateBiasTensor = 30;
    static constexpr int kBwCellGateBiasTensor = 31;
    static constexpr int kBwOutputGateBiasTensor = 32;

    // Projection weight tensor of size {n_output, n_cell}
    static constexpr int kBwProjectionWeightsTensor = 33;  // Optional
    // Projection bias tensor of size {n_output}
    static constexpr int kBwProjectionBiasTensor = 34;  // Optional

    // Stateful input tensors that are variables and will be modified by the Op.
    // Activation state tensors of size {n_batch, n_output}
    static constexpr int kFwInputActivationStateTensor = 35;
    // Cell state tensors of size {n_batch, n_cell}
    static constexpr int kFwInputCellStateTensor = 36;
    // Activation state tensors of size {n_batch, n_output}
    static constexpr int kBwInputActivationStateTensor = 37;
    // Cell state tensors of size {n_batch, n_cell}
    static constexpr int kBwInputCellStateTensor = 38;

    // Used as an auxiliary input and weights when stacking for the
    // tf.contrib.rnn.stack_bidirectional_rnn case (with cross links); used as
    // an input to the backward cell when stacking for the
    // tf.nn.static_bidirectional_rnn case (without cross links).
    static constexpr int kAuxInputTensor = 39;  // Optional
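    // Note (assumed from the NNAPI BIDIRECTIONAL_SEQUENCE_LSTM spec rather
    // than stated in this header): when present, the aux input has shape
    // {max_time, n_batch, aux_input_size}, and the aux weight tensors below
    // have shape {n_cell, aux_input_size}.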
    // Forward weights.
    static constexpr int kFwAuxInputToInputWeightsTensor = 40;   // Optional
    static constexpr int kFwAuxInputToForgetWeightsTensor = 41;  // Optional
    static constexpr int kFwAuxInputToCellWeightsTensor = 42;    // Optional
    static constexpr int kFwAuxInputToOutputWeightsTensor = 43;  // Optional
    // Backward weights.
    static constexpr int kBwAuxInputToInputWeightsTensor = 44;   // Optional
    static constexpr int kBwAuxInputToForgetWeightsTensor = 45;  // Optional
    static constexpr int kBwAuxInputToCellWeightsTensor = 46;    // Optional
    static constexpr int kBwAuxInputToOutputWeightsTensor = 47;  // Optional

    static constexpr int kActivationParam = 48;
    static constexpr int kCellClipParam = 49;
    static constexpr int kProjClipParam = 50;
    static constexpr int kMergeOutputsParam = 51;
    static constexpr int kTimeMajorParam = 52;
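    // The five scalar operands above configure the whole op. A reading of
    // their roles (hedged; the authoritative definitions are in the NNAPI op
    // spec): kActivationParam selects the activation function,
    // kCellClipParam and kProjClipParam clip the cell state and the projected
    // output respectively (0 means no clipping), kMergeOutputsParam
    // concatenates the forward and backward outputs into kFwOutputTensor, and
    // kTimeMajorParam chooses between {max_time, n_batch, ...} and
    // {n_batch, max_time, ...} input/output layouts.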

    // Forward layer norm weights tensors of size {n_cell}, representing a diagonal matrix.
    static constexpr int kFwInputLayerNormWeightsTensor = 53;   // Optional
    static constexpr int kFwForgetLayerNormWeightsTensor = 54;  // Optional
    static constexpr int kFwCellLayerNormWeightsTensor = 55;    // Optional
    static constexpr int kFwOutputLayerNormWeightsTensor = 56;  // Optional
    // Backward layer norm weights tensors of size {n_cell}, representing a diagonal matrix.
    static constexpr int kBwInputLayerNormWeightsTensor = 57;   // Optional
    static constexpr int kBwForgetLayerNormWeightsTensor = 58;  // Optional
    static constexpr int kBwCellLayerNormWeightsTensor = 59;    // Optional
    static constexpr int kBwOutputLayerNormWeightsTensor = 60;  // Optional

    // Output tensors.
    static constexpr int kFwOutputTensor = 0;
    static constexpr int kBwOutputTensor = 1;  // Ignored if merge_outputs is set.

    static constexpr int kFwOutputActivationStateTensor = 2;
    static constexpr int kFwOutputCellStateTensor = 3;
    static constexpr int kBwOutputActivationStateTensor = 4;
    static constexpr int kBwOutputCellStateTensor = 5;
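    // Shape note (assumed from the NNAPI op spec, not spelled out here): with
    // merge_outputs set, kFwOutputTensor carries the concatenated result of
    // shape {max_time, n_batch, fw_output_size + bw_output_size} (sequence and
    // batch dimensions swapped when time_major is false); otherwise each
    // direction writes its own {max_time, n_batch, n_output} output.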

   private:
    LSTMParams params_;
    Shape fw_scratch_shape_;
    Shape bw_scratch_shape_;

    const RunTimeOperandInfo* input_;

    const RunTimeOperandInfo* aux_input_;
    const RunTimeOperandInfo* fw_aux_input_to_input_weights_;
    const RunTimeOperandInfo* fw_aux_input_to_forget_weights_;
    const RunTimeOperandInfo* fw_aux_input_to_cell_weights_;
    const RunTimeOperandInfo* fw_aux_input_to_output_weights_;
    const RunTimeOperandInfo* bw_aux_input_to_input_weights_;
    const RunTimeOperandInfo* bw_aux_input_to_forget_weights_;
    const RunTimeOperandInfo* bw_aux_input_to_cell_weights_;
    const RunTimeOperandInfo* bw_aux_input_to_output_weights_;

    const RunTimeOperandInfo* fw_input_to_input_weights_;
    const RunTimeOperandInfo* fw_input_to_forget_weights_;
    const RunTimeOperandInfo* fw_input_to_cell_weights_;
    const RunTimeOperandInfo* fw_input_to_output_weights_;

    const RunTimeOperandInfo* fw_recurrent_to_input_weights_;
    const RunTimeOperandInfo* fw_recurrent_to_forget_weights_;
    const RunTimeOperandInfo* fw_recurrent_to_cell_weights_;
    const RunTimeOperandInfo* fw_recurrent_to_output_weights_;

    const RunTimeOperandInfo* fw_cell_to_input_weights_;
    const RunTimeOperandInfo* fw_cell_to_forget_weights_;
    const RunTimeOperandInfo* fw_cell_to_output_weights_;

    const RunTimeOperandInfo* fw_input_gate_bias_;
    const RunTimeOperandInfo* fw_forget_gate_bias_;
    const RunTimeOperandInfo* fw_cell_bias_;
    const RunTimeOperandInfo* fw_output_gate_bias_;

    const RunTimeOperandInfo* fw_projection_weights_;
    const RunTimeOperandInfo* fw_projection_bias_;

    const RunTimeOperandInfo* fw_input_layer_norm_weights_;
    const RunTimeOperandInfo* fw_forget_layer_norm_weights_;
    const RunTimeOperandInfo* fw_cell_layer_norm_weights_;
    const RunTimeOperandInfo* fw_output_layer_norm_weights_;

    const RunTimeOperandInfo* fw_activation_state_;
    const RunTimeOperandInfo* fw_cell_state_;
    RunTimeOperandInfo* fw_output_;

    const RunTimeOperandInfo* bw_input_to_input_weights_;
    const RunTimeOperandInfo* bw_input_to_forget_weights_;
    const RunTimeOperandInfo* bw_input_to_cell_weights_;
    const RunTimeOperandInfo* bw_input_to_output_weights_;

    const RunTimeOperandInfo* bw_recurrent_to_input_weights_;
    const RunTimeOperandInfo* bw_recurrent_to_forget_weights_;
    const RunTimeOperandInfo* bw_recurrent_to_cell_weights_;
    const RunTimeOperandInfo* bw_recurrent_to_output_weights_;

    const RunTimeOperandInfo* bw_cell_to_input_weights_;
    const RunTimeOperandInfo* bw_cell_to_forget_weights_;
    const RunTimeOperandInfo* bw_cell_to_output_weights_;

    const RunTimeOperandInfo* bw_input_gate_bias_;
    const RunTimeOperandInfo* bw_forget_gate_bias_;
    const RunTimeOperandInfo* bw_cell_bias_;
    const RunTimeOperandInfo* bw_output_gate_bias_;

    const RunTimeOperandInfo* bw_projection_weights_;
    const RunTimeOperandInfo* bw_projection_bias_;

    const RunTimeOperandInfo* bw_input_layer_norm_weights_;
    const RunTimeOperandInfo* bw_forget_layer_norm_weights_;
    const RunTimeOperandInfo* bw_cell_layer_norm_weights_;
    const RunTimeOperandInfo* bw_output_layer_norm_weights_;

    const RunTimeOperandInfo* bw_activation_state_;
    const RunTimeOperandInfo* bw_cell_state_;
    RunTimeOperandInfo* bw_output_;

    RunTimeOperandInfo* fw_output_activation_state_;
    RunTimeOperandInfo* fw_output_cell_state_;
    RunTimeOperandInfo* bw_output_activation_state_;
    RunTimeOperandInfo* bw_output_cell_state_;
};

}  // namespace nn
}  // namespace android

#endif  // ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATIONS_BIDIRECTIONAL_SEQUENCE_LSTM_H