/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "Operations"

#include <vector>

#include "IndexedShapeWrapper.h"
#include "OperationResolver.h"
#include "OperationsUtils.h"

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
#include "LSTM.h"
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

namespace android {
namespace nn {
namespace unidirectional_sequence_lstm {

// Inputs
constexpr uint32_t kNumInputs = 28;

// Input tensor of size {max_time, n_batch, n_input}
constexpr uint32_t kInputTensor = 0;

// Input weight tensors of size: {n_cell, n_input}
constexpr uint32_t kInputToInputWeightsTensor = 1;  // Optional
constexpr uint32_t kInputToForgetWeightsTensor = 2;
constexpr uint32_t kInputToCellWeightsTensor = 3;
constexpr uint32_t kInputToOutputWeightsTensor = 4;

// Recurrent weight tensors of size {n_cell, n_output}
constexpr uint32_t kRecurrentToInputWeightsTensor = 5;  // Optional
constexpr uint32_t kRecurrentToForgetWeightsTensor = 6;
constexpr uint32_t kRecurrentToCellWeightsTensor = 7;
constexpr uint32_t kRecurrentToOutputWeightsTensor = 8;

// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
constexpr uint32_t kCellToInputWeightsTensor = 9;    // Optional
constexpr uint32_t kCellToForgetWeightsTensor = 10;  // Optional
constexpr uint32_t kCellToOutputWeightsTensor = 11;  // Optional

// Gates bias tensors of size {n_cell}
constexpr uint32_t kInputGateBiasTensor = 12;  // Optional
constexpr uint32_t kForgetGateBiasTensor = 13;
constexpr uint32_t kCellGateBiasTensor = 14;
constexpr uint32_t kOutputGateBiasTensor = 15;

// Projection weight tensor of size {n_output, n_cell}
constexpr uint32_t kProjectionWeightsTensor = 16;  // Optional
// Projection bias tensor of size {n_output}
constexpr uint32_t kProjectionBiasTensor = 17;  // Optional

// Input from the output of the previous step, tensor of size {batch_size, n_output}
constexpr uint32_t kOutputStateInTensor = 18;
// Input from the cell state of the previous step, tensor of size {batch_size, n_cell}
constexpr uint32_t kCellStateInTensor = 19;

constexpr uint32_t kActivationParam = 20;
constexpr uint32_t kCellClipParam = 21;
constexpr uint32_t kProjClipParam = 22;
constexpr uint32_t kTimeMajorParam = 23;

// Layer norm weights tensors of size {n_cell}, representing a diagonal matrix.
constexpr uint32_t kInputLayerNormWeightsTensor = 24;   // Optional
constexpr uint32_t kForgetLayerNormWeightsTensor = 25;  // Optional
constexpr uint32_t kCellLayerNormWeightsTensor = 26;    // Optional
constexpr uint32_t kOutputLayerNormWeightsTensor = 27;  // Optional

// Output tensors.
constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kNumOutputsWithState = 3;

constexpr uint32_t kOutputTensor = 0;
constexpr uint32_t kOutputStateOutTensor = 1;
constexpr uint32_t kCellStateOutTensor = 2;
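
// The operands marked "Optional" above may be omitted by the caller; at execution
// time an omitted tensor shows up as a null input buffer (see hasTensor() below),
// which is how the CIFG, peephole, projection, and layer-normalization variants
// are distinguished.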

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
namespace {

inline bool hasTensor(IOperationExecutionContext* context, const uint32_t tensor) {
    return context->getInputBuffer(tensor) != nullptr;
}

inline bool isTimeMajor(IOperationExecutionContext* context) {
    return context->getInputValue<bool>(kTimeMajorParam);
}
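
// Derives the LSTM variant flags (CIFG, peephole, projection, layer normalization)
// from the presence of the corresponding optional weight tensors, and reads the
// scalar activation and clipping parameters.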
template <typename T>
inline LSTMParams getLSTMParams(IOperationExecutionContext* context) {
    LSTMParams params;
    params.activation =
            static_cast<TfLiteFusedActivation>(context->getInputValue<int32_t>(kActivationParam));
    params.cell_clip = static_cast<float>(context->getInputValue<T>(kCellClipParam));
    params.proj_clip = static_cast<float>(context->getInputValue<T>(kProjClipParam));
    params.use_cifg = !hasTensor(context, kInputToInputWeightsTensor);
    params.use_peephole = hasTensor(context, kCellToOutputWeightsTensor);
    params.use_layer_norm = hasTensor(context, kOutputLayerNormWeightsTensor);
    params.use_projection_weight = hasTensor(context, kProjectionWeightsTensor);
    params.use_projection_bias = hasTensor(context, kProjectionBiasTensor);
    return params;
}

}  // namespace
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION
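
// Validates operand counts and types. Both TENSOR_FLOAT32 and TENSOR_FLOAT16 graphs
// are accepted; the scalar cell/projection clip parameters must use the matching
// float type. Requesting the optional state outputs raises the minimum supported
// version from Android Q to Android R.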
Result<Version> validate(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    const uint32_t numOutputs = context->getNumOutputs();
    NN_RET_CHECK(numOutputs == kNumOutputs || numOutputs == kNumOutputsWithState);
    const OperandType inputType = context->getInputType(kInputTensor);
    std::vector<OperandType> inExpectedTypes;
    std::vector<OperandType> outExpectedTypes;
    if (inputType == OperandType::TENSOR_FLOAT32) {
        inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::INT32, OperandType::FLOAT32,
                           OperandType::FLOAT32, OperandType::BOOL,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32};
        outExpectedTypes = {OperandType::TENSOR_FLOAT32};
    } else if (inputType == OperandType::TENSOR_FLOAT16) {
        inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::INT32, OperandType::FLOAT16,
                           OperandType::FLOAT16, OperandType::BOOL,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16};
        outExpectedTypes = {OperandType::TENSOR_FLOAT16};
    } else {
        NN_RET_CHECK_FAIL()
                << "Unsupported input operand type for UNIDIRECTIONAL_SEQUENCE_LSTM op: "
                << inputType;
    }
    Version minVersionSupported = Version::ANDROID_Q;
    if (context->getNumOutputs() == kNumOutputsWithState) {
        minVersionSupported = Version::ANDROID_R;
        outExpectedTypes.insert(outExpectedTypes.end(), {inputType, inputType});
    }
    NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
    NN_RET_CHECK(validateOutputTypes(context, outExpectedTypes));
    return minVersionSupported;
}

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
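// Checks that all required inputs are present and that the shapes of the provided
// tensors are mutually consistent, then sets the output shape(s).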
bool prepare(IOperationExecutionContext* context) {
    // Check that none of the required inputs are omitted
    const std::vector<int> requiredInputs = {
            kInputTensor,
            kInputToForgetWeightsTensor,
            kInputToCellWeightsTensor,
            kInputToOutputWeightsTensor,
            kRecurrentToForgetWeightsTensor,
            kRecurrentToCellWeightsTensor,
            kRecurrentToOutputWeightsTensor,
            kForgetGateBiasTensor,
            kCellGateBiasTensor,
            kOutputGateBiasTensor,
            kOutputStateInTensor,
            kCellStateInTensor,
            kActivationParam,
            kCellClipParam,
            kProjClipParam,
            kTimeMajorParam,
    };
    for (const int requiredInput : requiredInputs) {
        NN_RET_CHECK(!context->isOmittedInput(requiredInput))
                << "required input " << requiredInput << " is omitted";
    }

    const Shape inputShape = context->getInputShape(kInputTensor);
    const uint32_t inputRank = getNumberOfDimensions(inputShape);
    NN_RET_CHECK_EQ(inputRank, 3) << "Invalid input tensor rank: " << inputRank;
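
    // The input is laid out as {max_time, n_batch, n_input} when time-major, or
    // {n_batch, max_time, n_input} otherwise, as selected by kTimeMajorParam.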
    const uint32_t maxTime = getSizeOfDimension(inputShape, isTimeMajor(context) ? 0 : 1);
    const uint32_t batchSize = getSizeOfDimension(inputShape, isTimeMajor(context) ? 1 : 0);
    const uint32_t inputSize = getSizeOfDimension(inputShape, inputRank - 1);

    const Shape inputToOutputShape = context->getInputShape(kInputToOutputWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToOutputShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToOutputShape, 1), inputSize);
    const uint32_t numCells = getSizeOfDimension(inputToOutputShape, 0);

    const Shape recurrentToOutputShape = context->getInputShape(kRecurrentToOutputWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToOutputShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToOutputShape, 0), numCells);
    const uint32_t outputSize = getSizeOfDimension(recurrentToOutputShape, 1);

    if (hasTensor(context, kInputToInputWeightsTensor)) {
        const Shape inputToInputShape = context->getInputShape(kInputToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputToInputShape), 2);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputToInputShape, 0), numCells);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputToInputShape, 1), inputSize);
    }

    const Shape inputToForgetShape = context->getInputShape(kInputToForgetWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToForgetShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToForgetShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToForgetShape, 1), inputSize);
    const Shape inputToCellShape = context->getInputShape(kInputToCellWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToCellShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToCellShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToCellShape, 1), inputSize);

    if (hasTensor(context, kRecurrentToInputWeightsTensor)) {
        const Shape recurrentToInputShape = context->getInputShape(kRecurrentToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToInputShape), 2);
        NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToInputShape, 0), numCells);
        NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToInputShape, 1), outputSize);
    }

    const Shape recurrentToForgetShape = context->getInputShape(kRecurrentToForgetWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToForgetShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToForgetShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToForgetShape, 1), outputSize);
    const Shape recurrentToCellShape = context->getInputShape(kRecurrentToCellWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToCellShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToCellShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToCellShape, 1), outputSize);

    // We make sure the input-gate's parameters are either both present (regular
    // LSTM) or not at all (CIFG-LSTM).
    const bool cifgWeightsAllOrNone = (hasTensor(context, kInputToInputWeightsTensor) &&
                                       hasTensor(context, kRecurrentToInputWeightsTensor)) ||
                                      (!hasTensor(context, kInputToInputWeightsTensor) &&
                                       !hasTensor(context, kRecurrentToInputWeightsTensor));
    NN_RET_CHECK(cifgWeightsAllOrNone);

    if (hasTensor(context, kCellToInputWeightsTensor)) {
        const Shape cellToInputShape = context->getInputShape(kCellToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToInputShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToInputShape, 0), numCells);
    }

    if (hasTensor(context, kCellToForgetWeightsTensor)) {
        const Shape cellToForgetShape = context->getInputShape(kCellToForgetWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToForgetShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToForgetShape, 0), numCells);
    }

    if (hasTensor(context, kCellToOutputWeightsTensor)) {
        const Shape cellToOutputShape = context->getInputShape(kCellToOutputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToOutputShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToOutputShape, 0), numCells);
    }

    // Making sure the peephole weights are there all or none.
    const bool cifgUsed = !hasTensor(context, kInputToInputWeightsTensor);
    const bool peepholeWeightsAllOrNone =
            ((hasTensor(context, kCellToInputWeightsTensor) || cifgUsed) &&
             hasTensor(context, kCellToForgetWeightsTensor) &&
             hasTensor(context, kCellToOutputWeightsTensor)) ||
            (!hasTensor(context, kCellToInputWeightsTensor) &&
             !hasTensor(context, kCellToForgetWeightsTensor) &&
             !hasTensor(context, kCellToOutputWeightsTensor));
    NN_RET_CHECK(peepholeWeightsAllOrNone);

    if (!cifgUsed) {
        NN_RET_CHECK(hasTensor(context, kInputGateBiasTensor));
        const Shape inputGateBiasShape = context->getInputShape(kInputGateBiasTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputGateBiasShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputGateBiasShape, 0), numCells);
    } else {
        NN_RET_CHECK(!hasTensor(context, kInputGateBiasTensor))
                << "Input gate bias tensor is present when CIFG is used";
    }

    const Shape forgetGateBiasShape = context->getInputShape(kForgetGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(forgetGateBiasShape), 1);
    NN_RET_CHECK_EQ(getSizeOfDimension(forgetGateBiasShape, 0), numCells);
    const Shape cellGateBiasShape = context->getInputShape(kCellGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(cellGateBiasShape), 1);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellGateBiasShape, 0), numCells);
    const Shape outputGateBiasShape = context->getInputShape(kOutputGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(outputGateBiasShape), 1);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputGateBiasShape, 0), numCells);

    if (hasTensor(context, kProjectionWeightsTensor)) {
        const Shape projectionShape = context->getInputShape(kProjectionWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(projectionShape), 2);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionShape, 0), outputSize);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionShape, 1), numCells);
    }

    if (hasTensor(context, kProjectionBiasTensor)) {
        const Shape projectionBiasShape = context->getInputShape(kProjectionBiasTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(projectionBiasShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionBiasShape, 0), outputSize);
    }

    const Shape outputStateShape = context->getInputShape(kOutputStateInTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(outputStateShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputStateShape, 0), batchSize);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputStateShape, 1), outputSize);
    const Shape cellStateShape = context->getInputShape(kCellStateInTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(cellStateShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellStateShape, 0), batchSize);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellStateShape, 1), numCells);

    if (hasTensor(context, kInputLayerNormWeightsTensor)) {
        const Shape inputLayerNormShape = context->getInputShape(kInputLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputLayerNormShape, 0), numCells);
    }

    if (hasTensor(context, kForgetLayerNormWeightsTensor)) {
        const Shape forgetLayerNormShape = context->getInputShape(kForgetLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(forgetLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(forgetLayerNormShape, 0), numCells);
    }

    if (hasTensor(context, kCellLayerNormWeightsTensor)) {
        const Shape cellLayerNormShape = context->getInputShape(kCellLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellLayerNormShape, 0), numCells);
    }

    if (hasTensor(context, kOutputLayerNormWeightsTensor)) {
        const Shape outputLayerNormShape = context->getInputShape(kOutputLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(outputLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(outputLayerNormShape, 0), numCells);
    }
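
    // Layer normalization weights must be provided for all gates or for none;
    // under CIFG the input gate (and its layer norm weights) must be absent.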
    if (cifgUsed) {
        NN_RET_CHECK(!hasTensor(context, kInputLayerNormWeightsTensor))
                << "Input layer norm weights tensor is present when CIFG is used";
        const bool layerNormWeightsAllOrNoneCifg =
                (hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 hasTensor(context, kCellLayerNormWeightsTensor) &&
                 hasTensor(context, kOutputLayerNormWeightsTensor)) ||
                (!hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 !hasTensor(context, kCellLayerNormWeightsTensor) &&
                 !hasTensor(context, kOutputLayerNormWeightsTensor));
        NN_RET_CHECK(layerNormWeightsAllOrNoneCifg);
    } else {
        const bool layerNormWeightsAllOrNone =
                (hasTensor(context, kInputLayerNormWeightsTensor) &&
                 hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 hasTensor(context, kCellLayerNormWeightsTensor) &&
                 hasTensor(context, kOutputLayerNormWeightsTensor)) ||
                (!hasTensor(context, kInputLayerNormWeightsTensor) &&
                 !hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 !hasTensor(context, kCellLayerNormWeightsTensor) &&
                 !hasTensor(context, kOutputLayerNormWeightsTensor));
        NN_RET_CHECK(layerNormWeightsAllOrNone);
    }
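
    // The output has the same shape as the input, with the last dimension replaced
    // by n_output.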
    Shape outputShape = context->getInputShape(kInputTensor);
    outputShape.dimensions[2] = outputSize;
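
    // When the state outputs are requested (3 outputs), they must not be omitted and
    // take the shapes {batchSize, outputSize} and {batchSize, numCells}.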
    if (context->getNumOutputs() == kNumOutputsWithState) {
        NN_RET_CHECK(!context->isOmittedOutput(kOutputStateOutTensor));
        NN_RET_CHECK(!context->isOmittedOutput(kCellStateOutTensor));

        Shape outputStateOutTensor = context->getInputShape(kOutputStateInTensor);
        outputStateOutTensor.dimensions.resize(2);
        outputStateOutTensor.dimensions[0] = batchSize;
        outputStateOutTensor.dimensions[1] = outputSize;
        NN_RET_CHECK(context->setOutputShape(kOutputStateOutTensor, outputStateOutTensor));

        Shape cellStateOutTensor = context->getInputShape(kCellStateInTensor);
        cellStateOutTensor.dimensions.resize(2);
        cellStateOutTensor.dimensions[0] = batchSize;
        cellStateOutTensor.dimensions[1] = numCells;
        NN_RET_CHECK(context->setOutputShape(kCellStateOutTensor, cellStateOutTensor));
    }

    return context->setOutputShape(kOutputTensor, outputShape);
}
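
// Runs the reference LSTM evaluation over the whole sequence. When the caller did
// not request the state outputs, temporary buffers stand in for them; the scratch
// buffer holds three gate buffers per step under CIFG and four otherwise.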
bool execute(IOperationExecutionContext* context) {
    const auto outputStateSize = getNumberOfElements(context->getInputShape(kOutputStateInTensor));
    const auto cellStateSize = getNumberOfElements(context->getInputShape(kCellStateInTensor));
    const bool use_cifg = !hasTensor(context, kInputToInputWeightsTensor);
    const auto scratchSize = use_cifg ? 3 * cellStateSize : 4 * cellStateSize;
    const bool useStateOutTensors = (context->getNumOutputs() == kNumOutputsWithState);

    const OperandType inputType = context->getInputType(kInputTensor);
    switch (inputType) {
        case OperandType::TENSOR_FLOAT32: {
            // Initialize empty vectors and resize below only if needed
            std::vector<float> outputStateOutBuffer;
            std::vector<float> cellStateOutBuffer;
            float* outputStateOut;
            float* cellStateOut;
            if (useStateOutTensors) {
                outputStateOut = context->getOutputBuffer<float>(kOutputStateOutTensor);
                cellStateOut = context->getOutputBuffer<float>(kCellStateOutTensor);
            } else {
                outputStateOutBuffer.resize(outputStateSize);
                cellStateOutBuffer.resize(cellStateSize);
                outputStateOut = outputStateOutBuffer.data();
                cellStateOut = cellStateOutBuffer.data();
            }
            std::vector<float> scratchBuffer(scratchSize);
            LSTMCell::LSTMEvalFloat32(
                    getLSTMParams<float>(context), context->getInputBuffer<float>(kInputTensor),
                    context->getInputShape(kInputTensor),
                    context->getInputBuffer<float>(kInputToInputWeightsTensor),
                    context->getInputBuffer<float>(kInputToForgetWeightsTensor),
                    context->getInputBuffer<float>(kInputToCellWeightsTensor),
                    context->getInputBuffer<float>(kInputToOutputWeightsTensor),
                    context->getInputShape(kInputToOutputWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToInputWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToForgetWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToCellWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToOutputWeightsTensor),
                    context->getInputShape(kRecurrentToOutputWeightsTensor),
                    context->getInputBuffer<float>(kCellToInputWeightsTensor),
                    context->getInputBuffer<float>(kCellToForgetWeightsTensor),
                    context->getInputBuffer<float>(kCellToOutputWeightsTensor),
                    /*aux_input_buffer=*/nullptr,
                    /*aux_input_to_input_weights_buffer=*/nullptr,
                    /*aux_input_to_forget_weights_buffer=*/nullptr,
                    /*aux_input_to_cell_weights_buffer=*/nullptr,
                    /*aux_input_to_output_weights_buffer=*/nullptr,
                    context->getInputBuffer<float>(kInputGateBiasTensor),
                    context->getInputBuffer<float>(kForgetGateBiasTensor),
                    context->getInputBuffer<float>(kCellGateBiasTensor),
                    context->getInputBuffer<float>(kOutputGateBiasTensor),
                    context->getInputBuffer<float>(kProjectionWeightsTensor),
                    context->getInputBuffer<float>(kProjectionBiasTensor),
                    context->getInputBuffer<float>(kOutputStateInTensor),
                    context->getInputBuffer<float>(kCellStateInTensor),
                    context->getInputBuffer<float>(kInputLayerNormWeightsTensor),
                    context->getInputBuffer<float>(kForgetLayerNormWeightsTensor),
                    context->getInputBuffer<float>(kCellLayerNormWeightsTensor),
                    context->getInputBuffer<float>(kOutputLayerNormWeightsTensor), outputStateOut,
                    cellStateOut, context->getOutputBuffer<float>(kOutputTensor),
                    scratchBuffer.data(), isTimeMajor(context));
        } break;
        case OperandType::TENSOR_FLOAT16: {
            // Initialize empty vectors and resize below only if needed
            std::vector<_Float16> outputStateOutBuffer;
            std::vector<_Float16> cellStateOutBuffer;
            _Float16* outputStateOut;
            _Float16* cellStateOut;
            if (useStateOutTensors) {
                outputStateOut = context->getOutputBuffer<_Float16>(kOutputStateOutTensor);
                cellStateOut = context->getOutputBuffer<_Float16>(kCellStateOutTensor);
            } else {
                outputStateOutBuffer.resize(outputStateSize);
                cellStateOutBuffer.resize(cellStateSize);
                outputStateOut = outputStateOutBuffer.data();
                cellStateOut = cellStateOutBuffer.data();
            }
            std::vector<_Float16> scratchBuffer(scratchSize);
            LSTMCell::LSTMEvalFloat16(
                    getLSTMParams<_Float16>(context),
                    context->getInputBuffer<_Float16>(kInputTensor),
                    context->getInputShape(kInputTensor),
                    context->getInputBuffer<_Float16>(kInputToInputWeightsTensor),
                    context->getInputBuffer<_Float16>(kInputToForgetWeightsTensor),
                    context->getInputBuffer<_Float16>(kInputToCellWeightsTensor),
                    context->getInputBuffer<_Float16>(kInputToOutputWeightsTensor),
                    context->getInputShape(kInputToOutputWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToInputWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToForgetWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToCellWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToOutputWeightsTensor),
                    context->getInputShape(kRecurrentToOutputWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellToInputWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellToForgetWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellToOutputWeightsTensor),
                    /*aux_input_buffer=*/nullptr,
                    /*aux_input_to_input_weights_buffer=*/nullptr,
                    /*aux_input_to_forget_weights_buffer=*/nullptr,
                    /*aux_input_to_cell_weights_buffer=*/nullptr,
                    /*aux_input_to_output_weights_buffer=*/nullptr,
                    context->getInputBuffer<_Float16>(kInputGateBiasTensor),
                    context->getInputBuffer<_Float16>(kForgetGateBiasTensor),
                    context->getInputBuffer<_Float16>(kCellGateBiasTensor),
                    context->getInputBuffer<_Float16>(kOutputGateBiasTensor),
                    context->getInputBuffer<_Float16>(kProjectionWeightsTensor),
                    context->getInputBuffer<_Float16>(kProjectionBiasTensor),
                    context->getInputBuffer<_Float16>(kOutputStateInTensor),
                    context->getInputBuffer<_Float16>(kCellStateInTensor),
                    context->getInputBuffer<_Float16>(kInputLayerNormWeightsTensor),
                    context->getInputBuffer<_Float16>(kForgetLayerNormWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellLayerNormWeightsTensor),
                    context->getInputBuffer<_Float16>(kOutputLayerNormWeightsTensor),
                    outputStateOut, cellStateOut, context->getOutputBuffer<_Float16>(kOutputTensor),
                    scratchBuffer.data(), isTimeMajor(context));
        } break;
        default: {
            LOG(ERROR) << "Unsupported data type: " << static_cast<int>(inputType);
            return false;
        }
    }
    return true;
}
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

}  // namespace unidirectional_sequence_lstm

NN_REGISTER_OPERATION(UNIDIRECTIONAL_SEQUENCE_LSTM, "UNIDIRECTIONAL_SEQUENCE_LSTM",
                      unidirectional_sequence_lstm::validate, unidirectional_sequence_lstm::prepare,
                      unidirectional_sequence_lstm::execute, .allowOmittedOperand = true);

}  // namespace nn
}  // namespace android