/* * Copyright (C) 2020 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include "CpuExecutor.h" #include "OperationsUtils.h" #ifdef NN_INCLUDE_CPU_IMPLEMENTATION #include "QuantUtils.h" #endif // NN_INCLUDE_CPU_IMPLEMENTATION namespace android { namespace nn { namespace qlstm { namespace { // Inputs constexpr uint32_t kNumInputs = 32; constexpr uint32_t kInputTensor = 0; // Input weight tensors of size: [numUnits, inputSize]. constexpr uint32_t kInputToInputWeightsTensor = 1; constexpr uint32_t kInputToForgetWeightsTensor = 2; constexpr uint32_t kInputToCellWeightsTensor = 3; constexpr uint32_t kInputToOutputWeightsTensor = 4; // Recurrent weight tensors of size [numUnits, outputSize]. constexpr uint32_t kRecurrentToInputWeightsTensor = 5; constexpr uint32_t kRecurrentToForgetWeightsTensor = 6; constexpr uint32_t kRecurrentToCellWeightsTensor = 7; constexpr uint32_t kRecurrentToOutputWeightsTensor = 8; // For peephole (optional). // Cell to input/forget/output weights of size [numUnits]. constexpr uint32_t kCellToInputWeightsTensor = 9; constexpr uint32_t kCellToForgetWeightsTensor = 10; constexpr uint32_t kCellToOutputWeightsTensor = 11; // Gates bias tensors of size [numUnits]. constexpr uint32_t kInputGateBiasTensor = 12; constexpr uint32_t kForgetGateBiasTensor = 13; constexpr uint32_t kCellGateBiasTensor = 14; constexpr uint32_t kOutputGateBiasTensor = 15; // Projection weight tensor of size [outputSize, numUnits]. constexpr uint32_t kProjectionWeightsTensor = 16; // Projection bias tensor of size [outputSize]. constexpr uint32_t kProjectionBiasTensor = 17; // Output from the previous time step, as tensor // of size [numBatches, outputSize]. constexpr uint32_t kPrevOutputTensor = 18; // Cell state from the previous time step, as tensor // of size [numBatches, numUnits]. constexpr uint32_t kPrevCellStateTensor = 19; // Layer normalization tensors of size [numUnits]. constexpr uint32_t kInputLayerNormTensor = 20; constexpr uint32_t kForgetLayerNormTensor = 21; constexpr uint32_t kCellLayerNormTensor = 22; constexpr uint32_t kOutputLayerNormTensor = 23; // Clipping. constexpr uint32_t kCellClip = 24; constexpr uint32_t kProjectionClip = 25; // Scales of the result of matmul, i.e. input to layer normalization. constexpr uint32_t kInputIntermediateScale = 26; constexpr uint32_t kForgetIntermediateScale = 27; constexpr uint32_t kCellIntermediateScale = 28; constexpr uint32_t kOutputIntermediateScale = 29; // Zero point and scale of hidden state. constexpr uint32_t kHiddenStateZeroPoint = 30; constexpr uint32_t kHiddenStateScale = 31; // Outputs: constexpr uint32_t kNumOutputs = 3; constexpr uint32_t kOutputStateOutTensor = 0; constexpr uint32_t kCellStateOutTensor = 1; constexpr uint32_t kOutputTensor = 2; inline bool hasTensor(IOperationExecutionContext* context, const uint32_t tensor) { return context->getInputBuffer(tensor) != nullptr; } } // namespace Result validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); std::vector inExpectedTypes; // Input. inExpectedTypes.push_back(OperandType::TENSOR_QUANT8_ASYMM_SIGNED); // Input-to-* and recurrent-to-* weights. for (int i = 0; i < 8; ++i) { inExpectedTypes.push_back(OperandType::TENSOR_QUANT8_SYMM); } // Cell-to-* weights. for (int i = 0; i < 3; ++i) { inExpectedTypes.push_back(OperandType::TENSOR_QUANT16_SYMM); } // Gate biases. for (int i = 0; i < 4; ++i) { inExpectedTypes.push_back(OperandType::TENSOR_INT32); } // Projection. inExpectedTypes.push_back(OperandType::TENSOR_QUANT8_SYMM); inExpectedTypes.push_back(OperandType::TENSOR_INT32); // Previous output. inExpectedTypes.push_back(OperandType::TENSOR_QUANT8_ASYMM_SIGNED); // Previous cell state. inExpectedTypes.push_back(OperandType::TENSOR_QUANT16_SYMM); // Layer norm weights for (int i = 0; i < 4; ++i) { inExpectedTypes.push_back(OperandType::TENSOR_QUANT16_SYMM); } // Cell/projection clipping and scales of intermediate results at the 4 gates. for (int i = 0; i < 6; ++i) { inExpectedTypes.push_back(OperandType::FLOAT32); } // Zero point and scale of the hidden state. inExpectedTypes.push_back(OperandType::INT32); inExpectedTypes.push_back(OperandType::FLOAT32); NN_RET_CHECK(validateInputTypes(context, inExpectedTypes)); std::vector outExpectedTypes; // Output state (out). outExpectedTypes.push_back(OperandType::TENSOR_QUANT8_ASYMM_SIGNED); // Cell state (out). outExpectedTypes.push_back(OperandType::TENSOR_QUANT16_SYMM); // Output. outExpectedTypes.push_back(OperandType::TENSOR_QUANT8_ASYMM_SIGNED); NN_RET_CHECK(validateOutputTypes(context, outExpectedTypes)); return Version::ANDROID_R; } bool prepare(IOperationExecutionContext* context) { // Check that none of the required inputs are omitted const std::vector requiredTensorInputs = { kInputTensor, kInputToForgetWeightsTensor, kInputToCellWeightsTensor, kInputToOutputWeightsTensor, kRecurrentToForgetWeightsTensor, kRecurrentToCellWeightsTensor, kRecurrentToOutputWeightsTensor, kForgetGateBiasTensor, kCellGateBiasTensor, kOutputGateBiasTensor, kPrevOutputTensor, kPrevCellStateTensor, }; for (const int tensor : requiredTensorInputs) { NN_RET_CHECK(!context->isOmittedInput(tensor)) << "required input " << tensor << " is omitted"; } const Shape inputShape = context->getInputShape(kInputTensor); const uint32_t inputRank = getNumberOfDimensions(inputShape); NN_RET_CHECK_EQ(inputRank, 2) << "Invalid input tensor rank: " << inputRank; const uint32_t batchSize = getSizeOfDimension(inputShape, 0); const uint32_t inputSize = getSizeOfDimension(inputShape, 1); const Shape inputToOutputShape = context->getInputShape(kInputToOutputWeightsTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(inputToOutputShape), 2); NN_RET_CHECK_EQ(getSizeOfDimension(inputToOutputShape, 1), inputSize); const uint32_t numUnits = getSizeOfDimension(inputToOutputShape, 0); const Shape recurrentToOutputShape = context->getInputShape(kRecurrentToOutputWeightsTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToOutputShape), 2); NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToOutputShape, 0), numUnits); const uint32_t outputSize = getSizeOfDimension(recurrentToOutputShape, 1); if (hasTensor(context, kInputToInputWeightsTensor)) { const Shape inputToInputShape = context->getInputShape(kInputToInputWeightsTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(inputToInputShape), 2); NN_RET_CHECK_EQ(getSizeOfDimension(inputToInputShape, 0), numUnits); NN_RET_CHECK_EQ(getSizeOfDimension(inputToInputShape, 1), inputSize); } const Shape inputToForgetShape = context->getInputShape(kInputToForgetWeightsTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(inputToForgetShape), 2); NN_RET_CHECK_EQ(getSizeOfDimension(inputToForgetShape, 0), numUnits); NN_RET_CHECK_EQ(getSizeOfDimension(inputToForgetShape, 1), inputSize); const Shape inputToCellShape = context->getInputShape(kInputToCellWeightsTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(inputToCellShape), 2); NN_RET_CHECK_EQ(getSizeOfDimension(inputToCellShape, 0), numUnits); NN_RET_CHECK_EQ(getSizeOfDimension(inputToCellShape, 1), inputSize); if (hasTensor(context, kRecurrentToInputWeightsTensor)) { const Shape recurrentToInputShape = context->getInputShape(kRecurrentToInputWeightsTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToInputShape), 2); NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToInputShape, 0), numUnits); NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToInputShape, 1), outputSize); } const Shape recurrentToForgetShape = context->getInputShape(kRecurrentToForgetWeightsTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToForgetShape), 2); NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToForgetShape, 0), numUnits); NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToForgetShape, 1), outputSize); const Shape recurrentToCellShape = context->getInputShape(kRecurrentToCellWeightsTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToCellShape), 2); NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToCellShape, 0), numUnits); NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToCellShape, 1), outputSize); // Make sure the input-gate's parameters are either all present (non-CIFG) or // not at all (CIFG). const bool cifgWeightsAllOrNone = (hasTensor(context, kInputToInputWeightsTensor) && hasTensor(context, kRecurrentToInputWeightsTensor)) || (!hasTensor(context, kInputToInputWeightsTensor) && !hasTensor(context, kRecurrentToInputWeightsTensor)); NN_RET_CHECK(cifgWeightsAllOrNone); if (hasTensor(context, kCellToInputWeightsTensor)) { const Shape cellToInputShape = context->getInputShape(kCellToInputWeightsTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(cellToInputShape), 1); NN_RET_CHECK_EQ(getSizeOfDimension(cellToInputShape, 0), numUnits); } if (hasTensor(context, kCellToForgetWeightsTensor)) { const Shape cellToForgetShape = context->getInputShape(kCellToForgetWeightsTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(cellToForgetShape), 1); NN_RET_CHECK_EQ(getSizeOfDimension(cellToForgetShape, 0), numUnits); } if (hasTensor(context, kCellToOutputWeightsTensor)) { const Shape cellToOutputShape = context->getInputShape(kCellToOutputWeightsTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(cellToOutputShape), 1); NN_RET_CHECK_EQ(getSizeOfDimension(cellToOutputShape, 0), numUnits); } // Making sure the peephole weights are there all or none. const bool cifgUsed = !hasTensor(context, kInputToInputWeightsTensor); const bool peepholeWeightsAllOrNone = ((hasTensor(context, kCellToInputWeightsTensor) || cifgUsed) && hasTensor(context, kCellToForgetWeightsTensor) && hasTensor(context, kCellToOutputWeightsTensor)) || (!hasTensor(context, kCellToInputWeightsTensor) && !hasTensor(context, kCellToForgetWeightsTensor) && !hasTensor(context, kCellToOutputWeightsTensor)); NN_RET_CHECK(peepholeWeightsAllOrNone); if (!cifgUsed) { NN_RET_CHECK(hasTensor(context, kInputGateBiasTensor)); const Shape inputGateBiasShape = context->getInputShape(kInputGateBiasTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(inputGateBiasShape), 1); NN_RET_CHECK_EQ(getSizeOfDimension(inputGateBiasShape, 0), numUnits); } else { NN_RET_CHECK(!hasTensor(context, kInputGateBiasTensor)) << "Input gate bias tensor is present when CIFG is used"; } const Shape forgetGateBiasShape = context->getInputShape(kForgetGateBiasTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(forgetGateBiasShape), 1); NN_RET_CHECK_EQ(getSizeOfDimension(forgetGateBiasShape, 0), numUnits); const Shape cellGateBiasShape = context->getInputShape(kCellGateBiasTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(cellGateBiasShape), 1); NN_RET_CHECK_EQ(getSizeOfDimension(cellGateBiasShape, 0), numUnits); const Shape outputGateBiasShape = context->getInputShape(kOutputGateBiasTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(outputGateBiasShape), 1); NN_RET_CHECK_EQ(getSizeOfDimension(outputGateBiasShape, 0), numUnits); if (hasTensor(context, kProjectionWeightsTensor)) { const Shape projectionShape = context->getInputShape(kProjectionWeightsTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(projectionShape), 2); NN_RET_CHECK_EQ(getSizeOfDimension(projectionShape, 0), outputSize); NN_RET_CHECK_EQ(getSizeOfDimension(projectionShape, 1), numUnits); } if (hasTensor(context, kProjectionBiasTensor)) { const Shape projectionBiasShape = context->getInputShape(kProjectionBiasTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(projectionBiasShape), 1); NN_RET_CHECK_EQ(getSizeOfDimension(projectionBiasShape, 0), outputSize); } const Shape outputStateShape = context->getInputShape(kPrevOutputTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(outputStateShape), 2); NN_RET_CHECK_EQ(getSizeOfDimension(outputStateShape, 0), batchSize); NN_RET_CHECK_EQ(getSizeOfDimension(outputStateShape, 1), outputSize); const Shape cellStateShape = context->getInputShape(kPrevCellStateTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(cellStateShape), 2); NN_RET_CHECK_EQ(getSizeOfDimension(cellStateShape, 0), batchSize); NN_RET_CHECK_EQ(getSizeOfDimension(cellStateShape, 1), numUnits); if (hasTensor(context, kInputLayerNormTensor)) { const Shape inputLayerNormShape = context->getInputShape(kInputLayerNormTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(inputLayerNormShape), 1); NN_RET_CHECK_EQ(getSizeOfDimension(inputLayerNormShape, 0), numUnits); } if (hasTensor(context, kForgetLayerNormTensor)) { const Shape forgetLayerNormShape = context->getInputShape(kForgetLayerNormTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(forgetLayerNormShape), 1); NN_RET_CHECK_EQ(getSizeOfDimension(forgetLayerNormShape, 0), numUnits); } if (hasTensor(context, kCellLayerNormTensor)) { const Shape cellLayerNormShape = context->getInputShape(kCellLayerNormTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(cellLayerNormShape), 1); NN_RET_CHECK_EQ(getSizeOfDimension(cellLayerNormShape, 0), numUnits); } if (hasTensor(context, kOutputLayerNormTensor)) { const Shape outputLayerNormShape = context->getInputShape(kOutputLayerNormTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(outputLayerNormShape), 1); NN_RET_CHECK_EQ(getSizeOfDimension(outputLayerNormShape, 0), numUnits); } if (cifgUsed) { NN_RET_CHECK(!hasTensor(context, kInputLayerNormTensor)) << "Input layer norm weights tensor is present when CIFG is used"; const bool layerNormWeightsAllOrNoneCifg = (hasTensor(context, kForgetLayerNormTensor) && hasTensor(context, kCellLayerNormTensor) && hasTensor(context, kOutputLayerNormTensor)) || (!hasTensor(context, kForgetLayerNormTensor) && !hasTensor(context, kCellLayerNormTensor) && !hasTensor(context, kOutputLayerNormTensor)); NN_RET_CHECK(layerNormWeightsAllOrNoneCifg); } else { const bool layerNormWeightsAllOrNone = (hasTensor(context, kInputLayerNormTensor) && hasTensor(context, kForgetLayerNormTensor) && hasTensor(context, kCellLayerNormTensor) && hasTensor(context, kOutputLayerNormTensor)) || (!hasTensor(context, kInputLayerNormTensor) && !hasTensor(context, kForgetLayerNormTensor) && !hasTensor(context, kCellLayerNormTensor) && !hasTensor(context, kOutputLayerNormTensor)); NN_RET_CHECK(layerNormWeightsAllOrNone); } const Shape prevOutputShape = context->getInputShape(kPrevOutputTensor); Shape outputShape = context->getOutputShape(kOutputTensor); outputShape.dimensions = prevOutputShape.dimensions; const Shape prevCellStateShape = context->getInputShape(kPrevCellStateTensor); Shape cellStateOutShape = context->getOutputShape(kCellStateOutTensor); cellStateOutShape.dimensions = prevCellStateShape.dimensions; return context->setOutputShape(kOutputStateOutTensor, outputShape) && context->setOutputShape(kCellStateOutTensor, cellStateOutShape) && context->setOutputShape(kOutputTensor, outputShape); } #ifdef NN_INCLUDE_CPU_IMPLEMENTATION bool execute(IOperationExecutionContext* context) { // Gets the inputs. const Shape inputShape = context->getInputShape(kInputTensor); const Shape inputToInputWeightsShape = context->getInputShape(kInputToInputWeightsTensor); const Shape recurrentToInputWeightsShape = context->getInputShape(kRecurrentToInputWeightsTensor); const Shape cellToInputShape = context->getInputShape(kCellToInputWeightsTensor); const Shape inputLayerNormShape = context->getInputShape(kInputLayerNormTensor); const Shape inputToForgetWeightsShape = context->getInputShape(kInputToForgetWeightsTensor); const Shape recurrentToForgetWeightsShape = context->getInputShape(kRecurrentToForgetWeightsTensor); const Shape cellToForgetShape = context->getInputShape(kCellToForgetWeightsTensor); const Shape forgetLayerNormShape = context->getInputShape(kForgetLayerNormTensor); const Shape inputToCellWeightsShape = context->getInputShape(kInputToCellWeightsTensor); const Shape recurrentToCellWeightsShape = context->getInputShape(kRecurrentToCellWeightsTensor); const Shape cellLayerNormShape = context->getInputShape(kCellLayerNormTensor); const Shape inputToOutputWeightsShape = context->getInputShape(kInputToOutputWeightsTensor); const Shape recurrentToOutputWeightsShape = context->getInputShape(kRecurrentToOutputWeightsTensor); const Shape cellToOutputShape = context->getInputShape(kCellToOutputWeightsTensor); const Shape outputLayerNormShape = context->getInputShape(kOutputLayerNormTensor); const Shape projectionWeightsShape = context->getInputShape(kProjectionWeightsTensor); const Shape prevOutputShape = context->getInputShape(kPrevOutputTensor); const Shape prevCellStateShape = context->getInputShape(kPrevCellStateTensor); const uint32_t batchSize = inputShape.dimensions[0]; const uint32_t inputSize = inputShape.dimensions[1]; const uint32_t numUnits = inputToOutputWeightsShape.dimensions[0]; const uint32_t outputSize = recurrentToOutputWeightsShape.dimensions[1]; const float cellClip = context->getInputValue(kCellClip); const float projectionClip = context->getInputValue(kProjectionClip); const float inputIntermediateScale = context->getInputValue(kInputIntermediateScale); const float forgetIntermediateScale = context->getInputValue(kForgetIntermediateScale); const float cellIntermediateScale = context->getInputValue(kCellIntermediateScale); const float outputIntermediateScale = context->getInputValue(kOutputIntermediateScale); const int8_t hiddenStateZeroPoint = context->getInputValue(kHiddenStateZeroPoint); const float hiddenStateScale = context->getInputValue(kHiddenStateScale); const int8_t* inputBuffer = reinterpret_cast(context->getInputBuffer(kInputTensor)); const int8_t* inputToInputWeightsBuffer = reinterpret_cast(context->getInputBuffer(kInputToInputWeightsTensor)); const bool useCifg = (inputToInputWeightsBuffer == nullptr); const int8_t* recurrentToInputWeightsBuffer = reinterpret_cast( context->getInputBuffer(kRecurrentToInputWeightsTensor)); const int16_t* cellToInputBuffer = reinterpret_cast(context->getInputBuffer(kCellToInputWeightsTensor)); const int16_t* inputLayerNormBuffer = reinterpret_cast(context->getInputBuffer(kInputLayerNormTensor)); const int32_t* inputBiasBuffer = reinterpret_cast(context->getInputBuffer(kInputGateBiasTensor)); const int8_t* inputToForgetWeightsBuffer = reinterpret_cast(context->getInputBuffer(kInputToForgetWeightsTensor)); const int8_t* recurrentToForgetWeightsBuffer = reinterpret_cast( context->getInputBuffer(kRecurrentToForgetWeightsTensor)); const int16_t* cellToForgetBuffer = reinterpret_cast(context->getInputBuffer(kCellToForgetWeightsTensor)); const int16_t* forgetLayerNormBuffer = reinterpret_cast(context->getInputBuffer(kForgetLayerNormTensor)); const int32_t* forgetBiasBuffer = reinterpret_cast(context->getInputBuffer(kForgetGateBiasTensor)); const int8_t* inputToCellWeightsBuffer = reinterpret_cast(context->getInputBuffer(kInputToCellWeightsTensor)); const int8_t* recurrentToCellWeightsBuffer = reinterpret_cast(context->getInputBuffer(kRecurrentToCellWeightsTensor)); const int16_t* cellLayerNormBuffer = reinterpret_cast(context->getInputBuffer(kCellLayerNormTensor)); const int32_t* cellBiasBuffer = reinterpret_cast(context->getInputBuffer(kCellGateBiasTensor)); const int8_t* inputToOutputWeightsBuffer = reinterpret_cast(context->getInputBuffer(kInputToOutputWeightsTensor)); const int8_t* recurrentToOutputWeightsBuffer = reinterpret_cast( context->getInputBuffer(kRecurrentToOutputWeightsTensor)); const int16_t* cellToOutputBuffer = reinterpret_cast(context->getInputBuffer(kCellToOutputWeightsTensor)); const int16_t* outputLayerNormBuffer = reinterpret_cast(context->getInputBuffer(kOutputLayerNormTensor)); const int32_t* outputBiasBuffer = reinterpret_cast(context->getInputBuffer(kOutputGateBiasTensor)); const int8_t* projectionWeightsBuffer = reinterpret_cast(context->getInputBuffer(kProjectionWeightsTensor)); const int32_t* projectionBiasBuffer = reinterpret_cast(context->getInputBuffer(kProjectionBiasTensor)); const int8_t* prevOutputBuffer = reinterpret_cast(context->getInputBuffer(kPrevOutputTensor)); const int16_t* prevCellStateBuffer = reinterpret_cast(context->getInputBuffer(kPrevCellStateTensor)); uint8_t* outputStateBuffer = reinterpret_cast(context->getOutputBuffer(kOutputStateOutTensor)); int16_t* cellStateBuffer = reinterpret_cast(context->getOutputBuffer(kCellStateOutTensor)); int8_t* outputBuffer = reinterpret_cast(context->getOutputBuffer(kOutputTensor)); // Calculates and decomposes effective scales. // This is for optimizing the matmul calculation. int cellShift; NN_RET_CHECK(CheckedLog2(prevCellStateShape.scale, &cellShift)); NN_RET_CHECK(cellShift <= -9); int32_t inputToInputEffectiveScaleA; int32_t inputToInputEffectiveScaleB; int32_t recurrentToInputEffectiveScaleA; int32_t recurrentToInputEffectiveScaleB; int32_t cellToInputEffectiveScaleA; int32_t cellToInputEffectiveScaleB; if (!useCifg) { const float inputToInputEffectiveScale = inputToInputWeightsShape.scale * inputShape.scale / inputIntermediateScale; NN_RET_CHECK(QuantizeMultiplier(inputToInputEffectiveScale, &inputToInputEffectiveScaleA, &inputToInputEffectiveScaleB)); const float recurrentToInputEffectiveScale = recurrentToInputWeightsShape.scale * prevOutputShape.scale / inputIntermediateScale; NN_RET_CHECK(QuantizeMultiplier(recurrentToInputEffectiveScale, &recurrentToInputEffectiveScaleA, &recurrentToInputEffectiveScaleB)); if (cellToInputBuffer != nullptr) { const float cellToInputEffectiveScale = std::pow(2, cellShift) * cellToInputShape.scale / inputIntermediateScale; NN_RET_CHECK(QuantizeMultiplier(cellToInputEffectiveScale, &cellToInputEffectiveScaleA, &cellToInputEffectiveScaleB)); } } int32_t inputLayerNormScaleA; int32_t inputLayerNormScaleB; if (inputLayerNormBuffer != nullptr) { NN_RET_CHECK(QuantizeMultiplier(inputLayerNormShape.scale, &inputLayerNormScaleA, &inputLayerNormScaleB)); } const float inputToForgetEffectiveScale = inputToForgetWeightsShape.scale * inputShape.scale / forgetIntermediateScale; int32_t inputToForgetEffectiveScaleA; int32_t inputToForgetEffectiveScaleB; NN_RET_CHECK(QuantizeMultiplier(inputToForgetEffectiveScale, &inputToForgetEffectiveScaleA, &inputToForgetEffectiveScaleB)); const float recurrentToForgetEffectiveScale = recurrentToForgetWeightsShape.scale * prevOutputShape.scale / forgetIntermediateScale; int32_t recurrentToForgetEffectiveScaleA; int32_t recurrentToForgetEffectiveScaleB; NN_RET_CHECK(QuantizeMultiplier(recurrentToForgetEffectiveScale, &recurrentToForgetEffectiveScaleA, &recurrentToForgetEffectiveScaleB)); int32_t cellToForgetEffectiveScaleA; int32_t cellToForgetEffectiveScaleB; if (cellToForgetBuffer != nullptr) { const float cellToForgetEffectiveScale = std::pow(2, cellShift) * cellToForgetShape.scale / forgetIntermediateScale; NN_RET_CHECK(QuantizeMultiplier(cellToForgetEffectiveScale, &cellToForgetEffectiveScaleA, &cellToForgetEffectiveScaleB)); } int32_t forgetLayerNormScaleA; int32_t forgetLayerNormScaleB; if (forgetLayerNormBuffer != nullptr) { NN_RET_CHECK(QuantizeMultiplier(forgetLayerNormShape.scale, &forgetLayerNormScaleA, &forgetLayerNormScaleB)); } const float inputToCellEffectiveScale = inputToCellWeightsShape.scale * inputShape.scale / cellIntermediateScale; int32_t inputToCellEffectiveScaleA; int32_t inputToCellEffectiveScaleB; NN_RET_CHECK(QuantizeMultiplier(inputToCellEffectiveScale, &inputToCellEffectiveScaleA, &inputToCellEffectiveScaleB)); const float recurrentToCellEffectiveScale = recurrentToCellWeightsShape.scale * prevOutputShape.scale / cellIntermediateScale; int32_t recurrentToCellEffectiveScaleA; int32_t recurrentToCellEffectiveScaleB; NN_RET_CHECK(QuantizeMultiplier(recurrentToCellEffectiveScale, &recurrentToCellEffectiveScaleA, &recurrentToCellEffectiveScaleB)); int32_t cellLayerNormScaleA; int32_t cellLayerNormScaleB; if (cellLayerNormBuffer != nullptr) { NN_RET_CHECK(QuantizeMultiplier(cellLayerNormShape.scale, &cellLayerNormScaleA, &cellLayerNormScaleB)); } const float inputToOutputEffectiveScale = inputToOutputWeightsShape.scale * inputShape.scale / outputIntermediateScale; int32_t inputToOutputEffectiveScaleA; int32_t inputToOutputEffectiveScaleB; NN_RET_CHECK(QuantizeMultiplier(inputToOutputEffectiveScale, &inputToOutputEffectiveScaleA, &inputToOutputEffectiveScaleB)); const float recurrentToOutputEffectiveScale = recurrentToOutputWeightsShape.scale * prevOutputShape.scale / outputIntermediateScale; int32_t recurrentToOutputEffectiveScaleA; int32_t recurrentToOutputEffectiveScaleB; NN_RET_CHECK(QuantizeMultiplier(recurrentToOutputEffectiveScale, &recurrentToOutputEffectiveScaleA, &recurrentToOutputEffectiveScaleB)); int32_t cellToOutputEffectiveScaleA; int32_t cellToOutputEffectiveScaleB; if (cellToOutputBuffer != nullptr) { const float cellToOutputEffectiveScale = std::pow(2, cellShift) * cellToOutputShape.scale / outputIntermediateScale; NN_RET_CHECK(QuantizeMultiplier(cellToOutputEffectiveScale, &cellToOutputEffectiveScaleA, &cellToOutputEffectiveScaleB)); } int32_t outputLayerNormScaleA; int32_t outputLayerNormScaleB; if (outputLayerNormBuffer != nullptr) { NN_RET_CHECK(QuantizeMultiplier(outputLayerNormShape.scale, &outputLayerNormScaleA, &outputLayerNormScaleB)); } const float hiddenStateEffectiveScale = std::pow(2, -15) / hiddenStateScale * std::pow(2, -15); int32_t hiddenStateEffectiveScaleA; int32_t hiddenStateEffectiveScaleB; NN_RET_CHECK(QuantizeMultiplier(hiddenStateEffectiveScale, &hiddenStateEffectiveScaleA, &hiddenStateEffectiveScaleB)); int32_t projectionEffectiveScaleA; int32_t projectionEffectiveScaleB; if (projectionWeightsBuffer != nullptr) { const float projectionEffectiveScale = projectionWeightsShape.scale * hiddenStateScale / prevOutputShape.scale; NN_RET_CHECK(QuantizeMultiplier(projectionEffectiveScale, &projectionEffectiveScaleA, &projectionEffectiveScaleB)); } // Calculates quantized clipping parameters. int16_t quantizedCellClip = 0; if (cellClip > 0.0) { quantizedCellClip = static_cast( std::min(std::max(cellClip / prevCellStateShape.scale, -32768.0f), 32767.0f)); } int8_t quantizedProjectionClip = 0; if (projectionClip > 0.0) { quantizedProjectionClip = static_cast( std::min(std::max(projectionClip / projectionWeightsShape.scale, -128.0f), 127.0f)); } // Calculates effective bias. // This is for optimizing the matmul calculation. std::unique_ptr inputToInputEffectiveBias; std::unique_ptr recurrentToInputEffectiveBias; if (!useCifg) { NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias( -inputShape.offset, inputToInputWeightsBuffer, inputToInputWeightsShape, /*bias=*/nullptr, &inputToInputEffectiveBias)); NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias( -prevOutputShape.offset, recurrentToInputWeightsBuffer, recurrentToInputWeightsShape, /*bias=*/nullptr, &recurrentToInputEffectiveBias)); } std::unique_ptr inputToForgetEffectiveBias; NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias( -inputShape.offset, inputToForgetWeightsBuffer, inputToForgetWeightsShape, /*bias=*/nullptr, &inputToForgetEffectiveBias)); std::unique_ptr recurrentToForgetEffectiveBias; NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias( -prevOutputShape.offset, recurrentToForgetWeightsBuffer, recurrentToForgetWeightsShape, /*bias=*/nullptr, &recurrentToForgetEffectiveBias)); std::unique_ptr inputToCellEffectiveBias; NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias( -inputShape.offset, inputToCellWeightsBuffer, inputToCellWeightsShape, /*bias=*/nullptr, &inputToCellEffectiveBias)); std::unique_ptr recurrentToCellEffectiveBias; NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias( -prevOutputShape.offset, recurrentToCellWeightsBuffer, recurrentToCellWeightsShape, /*bias=*/nullptr, &recurrentToCellEffectiveBias)); std::unique_ptr inputToOutputEffectiveBias; NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias( -inputShape.offset, inputToOutputWeightsBuffer, inputToOutputWeightsShape, /*bias=*/nullptr, &inputToOutputEffectiveBias)); std::unique_ptr recurrentToOutputEffectiveBias; NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias( -prevOutputShape.offset, recurrentToOutputWeightsBuffer, recurrentToOutputWeightsShape, /*bias=*/nullptr, &recurrentToOutputEffectiveBias)); std::unique_ptr projectionEffectiveBias; if (projectionBiasBuffer != nullptr) { NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias( hiddenStateZeroPoint, projectionWeightsBuffer, projectionWeightsShape, projectionBiasBuffer, &projectionEffectiveBias)); } // Temporary buffers. std::vector inputGateBuffer(batchSize * numUnits); std::vector forgetGateBuffer(batchSize * numUnits); std::vector cellGateBuffer(batchSize * numUnits); std::vector outputGateBuffer(batchSize * numUnits); std::vector buffer8(batchSize * numUnits); // To avoid overflow when calculating layer norm. const int32_t inputInvLargeValue = std::min(1, static_cast(10000 * inputLayerNormShape.scale)); const int32_t forgetInvLargeValue = std::min(1, static_cast(10000 * forgetLayerNormShape.scale)); const int32_t cellInvLargeValue = std::min(1, static_cast(10000 * cellLayerNormShape.scale)); const int32_t outputInvLargeValue = std::min(1, static_cast(10000 * outputLayerNormShape.scale)); // Forget gate. MatrixBatchVectorMultiplyAccumulate(inputBuffer, inputToForgetEffectiveBias.get(), inputToForgetWeightsBuffer, inputToForgetEffectiveScaleA, inputToForgetEffectiveScaleB, batchSize, inputSize, numUnits, /*outputZeroPoint=*/0, forgetGateBuffer.data()); MatrixBatchVectorMultiplyAccumulate( prevOutputBuffer, recurrentToForgetEffectiveBias.get(), recurrentToForgetWeightsBuffer, recurrentToForgetEffectiveScaleA, recurrentToForgetEffectiveScaleB, batchSize, outputSize, numUnits, /*outputZeroPoint=*/0, forgetGateBuffer.data()); if (cellToForgetBuffer != nullptr) { VectorBatchVectorCwiseProductAccumulate( cellToForgetBuffer, outputSize, cellStateBuffer, batchSize, cellToForgetEffectiveScaleA, cellToForgetEffectiveScaleB, forgetGateBuffer.data()); } if (forgetLayerNormBuffer != nullptr) { ApplyLayerNorm(forgetGateBuffer.data(), forgetLayerNormBuffer, forgetBiasBuffer, forgetLayerNormScaleA, forgetLayerNormScaleB, forgetInvLargeValue, batchSize, numUnits, forgetGateBuffer.data()); } ApplySigmoid(forgetGateBuffer.data(), batchSize, numUnits, forgetGateBuffer.data()); // Modulation gate. MatrixBatchVectorMultiplyAccumulate(inputBuffer, inputToCellEffectiveBias.get(), inputToCellWeightsBuffer, inputToCellEffectiveScaleA, inputToCellEffectiveScaleB, batchSize, inputSize, numUnits, /*outputZeroPoint=*/0, cellGateBuffer.data()); MatrixBatchVectorMultiplyAccumulate( prevOutputBuffer, recurrentToCellEffectiveBias.get(), recurrentToCellWeightsBuffer, recurrentToCellEffectiveScaleA, recurrentToCellEffectiveScaleB, batchSize, outputSize, numUnits, /*outputZeroPoint=*/0, cellGateBuffer.data()); if (cellLayerNormBuffer != nullptr) { ApplyLayerNorm(cellGateBuffer.data(), cellLayerNormBuffer, cellBiasBuffer, cellLayerNormScaleA, cellLayerNormScaleB, cellInvLargeValue, batchSize, numUnits, cellGateBuffer.data()); } ApplyTanh<3>(cellGateBuffer.data(), batchSize, numUnits, cellGateBuffer.data()); // Input gate. if (useCifg) { Sub1Vector(forgetGateBuffer.data(), batchSize * numUnits, inputGateBuffer.data()); } else { MatrixBatchVectorMultiplyAccumulate(inputBuffer, inputToInputEffectiveBias.get(), inputToInputWeightsBuffer, inputToInputEffectiveScaleA, inputToInputEffectiveScaleB, batchSize, inputSize, numUnits, /*outputZeroPoint=*/0, inputGateBuffer.data()); MatrixBatchVectorMultiplyAccumulate( prevOutputBuffer, recurrentToInputEffectiveBias.get(), recurrentToInputWeightsBuffer, recurrentToInputEffectiveScaleA, recurrentToInputEffectiveScaleB, batchSize, outputSize, numUnits, /*outputZeroPoint=*/0, inputGateBuffer.data()); if (cellToInputBuffer != nullptr) { VectorBatchVectorCwiseProductAccumulate( cellToInputBuffer, outputSize, cellStateBuffer, batchSize, cellToInputEffectiveScaleA, cellToInputEffectiveScaleB, inputGateBuffer.data()); } if (inputLayerNormBuffer != nullptr) { ApplyLayerNorm(inputGateBuffer.data(), inputLayerNormBuffer, inputBiasBuffer, inputLayerNormScaleA, inputLayerNormScaleB, inputInvLargeValue, batchSize, numUnits, inputGateBuffer.data()); } ApplySigmoid(inputGateBuffer.data(), batchSize, numUnits, inputGateBuffer.data()); } // Cell. CwiseMul(forgetGateBuffer.data(), prevCellStateBuffer, batchSize, numUnits, /*shift=*/15, forgetGateBuffer.data()); CwiseMul(inputGateBuffer.data(), cellGateBuffer.data(), batchSize, numUnits, 30 + cellShift, cellGateBuffer.data()); CwiseAdd(forgetGateBuffer.data(), cellGateBuffer.data(), batchSize, numUnits, cellStateBuffer); if (quantizedCellClip > 0) { CwiseClipping(cellStateBuffer, quantizedCellClip, batchSize, numUnits); } // Output gate. MatrixBatchVectorMultiplyAccumulate(inputBuffer, inputToOutputEffectiveBias.get(), inputToOutputWeightsBuffer, inputToOutputEffectiveScaleA, inputToOutputEffectiveScaleB, batchSize, inputSize, numUnits, /*outputZeroPoint=*/0, outputGateBuffer.data()); MatrixBatchVectorMultiplyAccumulate( prevOutputBuffer, recurrentToOutputEffectiveBias.get(), recurrentToOutputWeightsBuffer, recurrentToOutputEffectiveScaleA, recurrentToOutputEffectiveScaleB, batchSize, outputSize, numUnits, /*outputZeroPoint=*/0, outputGateBuffer.data()); if (cellToOutputBuffer != nullptr) { VectorBatchVectorCwiseProductAccumulate( cellToOutputBuffer, outputSize, cellStateBuffer, batchSize, cellToOutputEffectiveScaleA, cellToOutputEffectiveScaleB, outputGateBuffer.data()); } if (outputLayerNormBuffer != nullptr) { ApplyLayerNorm(outputGateBuffer.data(), outputLayerNormBuffer, outputBiasBuffer, outputLayerNormScaleA, outputLayerNormScaleB, outputInvLargeValue, batchSize, numUnits, outputGateBuffer.data()); } ApplySigmoid(outputGateBuffer.data(), batchSize, numUnits, outputGateBuffer.data()); // Hidden. ApplyTanh(cellShift + 15, cellStateBuffer, batchSize, numUnits, inputGateBuffer.data()); CwiseMul(outputGateBuffer.data(), inputGateBuffer.data(), hiddenStateEffectiveScaleA, hiddenStateEffectiveScaleB, batchSize, numUnits, hiddenStateZeroPoint, buffer8.data()); // Projection. if (projectionWeightsBuffer != nullptr) { memset(outputBuffer, 0, batchSize * outputSize * sizeof(int8_t)); MatrixBatchVectorMultiplyAccumulate(buffer8.data(), projectionEffectiveBias.get(), projectionWeightsBuffer, projectionEffectiveScaleA, projectionEffectiveScaleB, batchSize, numUnits, outputSize, prevOutputShape.offset, outputBuffer); if (quantizedProjectionClip > 0) { CwiseClipping(outputBuffer, quantizedProjectionClip, batchSize, outputSize); } } else { std::copy_n(buffer8.data(), batchSize * outputSize, outputBuffer); } // Copy output to output state out. for (unsigned int i = 0; i < batchSize * outputSize; ++i) { outputStateBuffer[i] = outputBuffer[i]; } return true; } #endif // NN_INCLUDE_CPU_IMPLEMENTATION } // namespace qlstm NN_REGISTER_OPERATION(QUANTIZED_LSTM, "QUANTIZED_LSTM", qlstm::validate, qlstm::prepare, qlstm::execute, .allowOmittedOperand = true); } // namespace nn } // namespace android