You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
257 lines
12 KiB
257 lines
12 KiB
/*
|
|
* Copyright (C) 2017 The Android Open Source Project
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#define LOG_TAG "Operations"
|
|
|
|
#include "RNN.h"
|
|
|
|
#include <vector>
|
|
|
|
#include "CpuExecutor.h"
|
|
#include "CpuOperationUtils.h"
|
|
#include "Tracing.h"
|
|
|
|
namespace android {
|
|
namespace nn {
|
|
|
|
// Resolves and caches pointers to every operand the RNN cell uses, so that
// Eval() can run without looking the operation's operand indices up again.
//
// operation: the RNN operation whose inputs/outputs are resolved.
// operands:  runtime operand table indexed by the operation's operand indices.
RNN::RNN(const Operation& operation, RunTimeOperandInfo* operands) {
    NNTRACE_TRANS("RNN::RNN");

    // Tensor inputs.
    input_ = GetInput(operation, operands, kInputTensor);
    weights_ = GetInput(operation, operands, kWeightsTensor);
    recurrent_weights_ = GetInput(operation, operands, kRecurrentWeightsTensor);
    hidden_state_in_ = GetInput(operation, operands, kHiddenStateInTensor);
    bias_ = GetInput(operation, operands, kBiasTensor);

    // Scalar input: the fused activation, stored as an int32 and reinterpreted
    // as an ActivationFn value.
    RunTimeOperandInfo& activationOperand = operands[operation.inputs[kActivationParam]];
    activation_ = static_cast<ActivationFn>(getScalarData<int32_t>(activationOperand));

    // Tensor outputs.
    hidden_state_out_ = GetOutput(operation, operands, kHiddenStateOutTensor);
    output_ = GetOutput(operation, operands, kOutputTensor);
}
|
|
|
|
bool RNN::Prepare(const Operation& operation, RunTimeOperandInfo* operands, Shape* hiddenStateShape,
|
|
Shape* outputShape) {
|
|
NNTRACE_TRANS("RNN::Prepare");
|
|
// Check we have all the inputs and outputs we need.
|
|
const int num_inputs = NumInputsWithValues(operation, operands);
|
|
NN_CHECK(num_inputs == 6);
|
|
NN_CHECK_EQ(NumOutputs(operation), 2);
|
|
|
|
const RunTimeOperandInfo* input = GetInput(operation, operands, kInputTensor);
|
|
const RunTimeOperandInfo* input_weights = GetInput(operation, operands, kWeightsTensor);
|
|
const RunTimeOperandInfo* recurrent_weights =
|
|
GetInput(operation, operands, kRecurrentWeightsTensor);
|
|
const RunTimeOperandInfo* bias = GetInput(operation, operands, kBiasTensor);
|
|
|
|
// Check all the parameters of tensor match within themselves and match the
|
|
// input configuration.
|
|
const uint32_t batch_size = SizeOfDimension(input, 0);
|
|
const uint32_t num_units = SizeOfDimension(input_weights, 0);
|
|
NN_CHECK_EQ(SizeOfDimension(input, 1), SizeOfDimension(input_weights, 1));
|
|
NN_CHECK_EQ(SizeOfDimension(input_weights, 0), SizeOfDimension(bias, 0));
|
|
NN_CHECK_EQ(SizeOfDimension(recurrent_weights, 0), SizeOfDimension(bias, 0));
|
|
NN_CHECK_EQ(SizeOfDimension(recurrent_weights, 1), SizeOfDimension(bias, 0));
|
|
|
|
const Shape& inputShape = input->shape();
|
|
|
|
// Resize state.
|
|
hiddenStateShape->type = inputShape.type;
|
|
hiddenStateShape->dimensions = {batch_size, num_units};
|
|
|
|
// Resize output.
|
|
outputShape->type = inputShape.type;
|
|
outputShape->dimensions = {batch_size, num_units};
|
|
|
|
return true;
|
|
}
|
|
|
|
bool RNN::Eval() {
|
|
switch (input_->type) {
|
|
case OperandType::TENSOR_FLOAT16: {
|
|
RNNStep<_Float16>(reinterpret_cast<_Float16*>(input_->buffer), input_->shape(),
|
|
reinterpret_cast<_Float16*>(hidden_state_in_->buffer),
|
|
reinterpret_cast<_Float16*>(bias_->buffer),
|
|
reinterpret_cast<_Float16*>(weights_->buffer), weights_->shape(),
|
|
reinterpret_cast<_Float16*>(recurrent_weights_->buffer),
|
|
recurrent_weights_->shape(), activation_,
|
|
reinterpret_cast<_Float16*>(output_->buffer));
|
|
memcpy(hidden_state_out_->buffer, output_->buffer,
|
|
sizeof(_Float16) * getNumberOfElements(output_->shape()));
|
|
break;
|
|
}
|
|
case OperandType::TENSOR_FLOAT32: {
|
|
RNNStep<float>(reinterpret_cast<float*>(input_->buffer), input_->shape(),
|
|
reinterpret_cast<float*>(hidden_state_in_->buffer),
|
|
reinterpret_cast<float*>(bias_->buffer),
|
|
reinterpret_cast<float*>(weights_->buffer), weights_->shape(),
|
|
reinterpret_cast<float*>(recurrent_weights_->buffer),
|
|
recurrent_weights_->shape(), activation_,
|
|
reinterpret_cast<float*>(output_->buffer));
|
|
memcpy(hidden_state_out_->buffer, output_->buffer,
|
|
sizeof(float) * getNumberOfElements(output_->shape()));
|
|
break;
|
|
}
|
|
default: {
|
|
LOG(ERROR) << "Unsupported data type: " << static_cast<int>(input_->type);
|
|
return false;
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
template <typename T>
|
|
bool RNN::RNNStep(const T* inputData, const Shape& inputShape, const T* hiddenStateInputData,
|
|
const T* biasData, const T* weightsData, const Shape& weightsShape,
|
|
const T* recurrentWeightsData, const Shape& recurrentWeightsShape,
|
|
const int32_t activation, T* outputData) {
|
|
NNTRACE_COMP("RNN::Eval");
|
|
|
|
Shape dummyShape;
|
|
uint32_t numUnits = weightsShape.dimensions[0];
|
|
return RNNStep<T>(inputData, inputShape, /*auxInputData=*/nullptr, /*auxInputShape=*/dummyShape,
|
|
hiddenStateInputData, biasData, weightsData, weightsShape,
|
|
/*auxWeightsData=*/nullptr, /*auxWeightsShape=*/dummyShape,
|
|
recurrentWeightsData, recurrentWeightsShape, activation,
|
|
/*outputBatchStride=*/numUnits, /*outputBatchOffset=*/0, outputData);
|
|
}
|
|
|
|
// A more general version of the RNNStep function.
|
|
// Auxiliary input is treated as if it was concatenated to a regular input and
|
|
// the result was multiplied by the weights matrix which was also concatenated
|
|
// with auxiliary weights.
|
|
template <typename T>
|
|
bool RNN::RNNStep(const T* inputData, const Shape& inputShape, const T* auxInputData,
|
|
const Shape& auxInputShape, const T* hiddenStateInputData, const T* biasData,
|
|
const T* weightsData, const Shape& weightsShape, const T* auxWeightsData,
|
|
const Shape& auxWeightsShape, const T* recurrentWeightsData,
|
|
const Shape& recurrentWeightsShape, const int32_t activation,
|
|
const uint32_t outputBatchStride, const uint32_t outputBatchOffset, T* outputData,
|
|
T* hiddenStateOutput) {
|
|
NNTRACE_COMP("RNN::Eval");
|
|
|
|
const uint32_t batch_size = inputShape.dimensions[0];
|
|
const uint32_t num_units = weightsShape.dimensions[0];
|
|
const uint32_t input_size = inputShape.dimensions[1];
|
|
const uint32_t input_weights_stride = weightsShape.dimensions[1];
|
|
const uint32_t recurrent_weights_stride = recurrentWeightsShape.dimensions[1];
|
|
|
|
uint32_t aux_input_size = 0;
|
|
uint32_t aux_input_weights_stride = 0;
|
|
bool hasAuxInput = (auxInputData != nullptr);
|
|
if (hasAuxInput) {
|
|
aux_input_size = auxInputShape.dimensions[1];
|
|
aux_input_weights_stride = auxWeightsShape.dimensions[1];
|
|
}
|
|
|
|
// For each batch
|
|
for (uint32_t b = 0; b < batch_size; b++) {
|
|
// Initialize the pointer to input, output and bias.
|
|
const T* input_ptr_batch = inputData + b * input_size;
|
|
const T* hidden_state_in_ptr_batch = hiddenStateInputData + b * num_units;
|
|
const T* aux_input_ptr_batch = nullptr;
|
|
if (hasAuxInput) {
|
|
aux_input_ptr_batch = auxInputData + b * aux_input_size;
|
|
}
|
|
T* output_ptr_batch = outputData + b * outputBatchStride + outputBatchOffset;
|
|
|
|
// Initialize input_weights and recurrent_weights.
|
|
const T* input_weights_ptr = weightsData;
|
|
const T* recurrent_weights_ptr = recurrentWeightsData;
|
|
const T* aux_input_weights_ptr = nullptr;
|
|
if (hasAuxInput) {
|
|
aux_input_weights_ptr = auxWeightsData;
|
|
}
|
|
|
|
// Output = bias
|
|
for (uint32_t o = 0; o < num_units; o++) {
|
|
output_ptr_batch[o] = biasData[o];
|
|
}
|
|
|
|
// Output += input * input_weights
|
|
for (uint32_t o = 0; o < num_units; o++) {
|
|
for (uint32_t i = 0; i < input_size; i++) {
|
|
output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i];
|
|
}
|
|
input_weights_ptr += input_weights_stride;
|
|
}
|
|
|
|
if (hasAuxInput) {
|
|
// Output += aux_input * aux_input_weights
|
|
for (uint32_t o = 0; o < num_units; o++) {
|
|
for (uint32_t i = 0; i < input_size; i++) {
|
|
output_ptr_batch[o] += aux_input_ptr_batch[i] * aux_input_weights_ptr[i];
|
|
}
|
|
aux_input_weights_ptr += aux_input_weights_stride;
|
|
}
|
|
}
|
|
|
|
// Output += recurrent_weights * hidden_state
|
|
for (uint32_t o = 0; o < num_units; o++) {
|
|
for (uint32_t h = 0; h < num_units; h++) {
|
|
output_ptr_batch[o] += hidden_state_in_ptr_batch[h] * recurrent_weights_ptr[h];
|
|
}
|
|
recurrent_weights_ptr += recurrent_weights_stride;
|
|
}
|
|
|
|
// Output = activation(Output)
|
|
for (uint32_t o = 0; o < num_units; o++) {
|
|
output_ptr_batch[o] =
|
|
(ActivationFunctor(static_cast<ActivationFn>(activation)))(output_ptr_batch[o]);
|
|
if (hiddenStateOutput != nullptr) {
|
|
*hiddenStateOutput = output_ptr_batch[o];
|
|
++hiddenStateOutput;
|
|
}
|
|
}
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
// Explicit instantiations of both RNNStep overloads for the two element types
// that RNN::Eval() dispatches on (_Float16 and float). The templates are
// defined in this .cpp, so these force the definitions to be emitted here
// (presumably RNN.h only declares them — confirm before removing any).
// Parameter names and qualifiers mirror the definitions above. Top-level
// const does not change the instantiated signature.
template bool RNN::RNNStep<_Float16>(const _Float16* inputData, const Shape& inputShape,
                                     const _Float16* hiddenStateInputData,
                                     const _Float16* biasData, const _Float16* weightsData,
                                     const Shape& weightsShape,
                                     const _Float16* recurrentWeightsData,
                                     const Shape& recurrentWeightsShape,
                                     const int32_t activation, _Float16* outputData);
template bool RNN::RNNStep<_Float16>(const _Float16* inputData, const Shape& inputShape,
                                     const _Float16* auxInputData, const Shape& auxInputShape,
                                     const _Float16* hiddenStateInputData,
                                     const _Float16* biasData, const _Float16* weightsData,
                                     const Shape& weightsShape, const _Float16* auxWeightsData,
                                     const Shape& auxWeightsShape,
                                     const _Float16* recurrentWeightsData,
                                     const Shape& recurrentWeightsShape,
                                     const int32_t activation, const uint32_t outputBatchStride,
                                     const uint32_t outputBatchOffset, _Float16* outputData,
                                     _Float16* hiddenStateOutput);
template bool RNN::RNNStep<float>(const float* inputData, const Shape& inputShape,
                                  const float* hiddenStateInputData, const float* biasData,
                                  const float* weightsData, const Shape& weightsShape,
                                  const float* recurrentWeightsData,
                                  const Shape& recurrentWeightsShape, const int32_t activation,
                                  float* outputData);
template bool RNN::RNNStep<float>(const float* inputData, const Shape& inputShape,
                                  const float* auxInputData, const Shape& auxInputShape,
                                  const float* hiddenStateInputData, const float* biasData,
                                  const float* weightsData, const Shape& weightsShape,
                                  const float* auxWeightsData, const Shape& auxWeightsShape,
                                  const float* recurrentWeightsData,
                                  const Shape& recurrentWeightsShape, const int32_t activation,
                                  const uint32_t outputBatchStride,
                                  const uint32_t outputBatchOffset, float* outputData,
                                  float* hiddenStateOutput);
|
|
|
|
} // namespace nn
|
|
} // namespace android
|