You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
123 lines
4.3 KiB
123 lines
4.3 KiB
/*
|
|
* Copyright (C) 2018 The Android Open Source Project
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include "utils/tflite/encoder_common.h"
|
|
|
|
#include "tensorflow/lite/kernels/kernel_util.h"
|
|
#include "tensorflow/lite/string_util.h"
|
|
|
|
namespace libtextclassifier3 {
|
|
|
|
TfLiteIntArray* CreateIntArray(const std::initializer_list<int>& values) {
|
|
TfLiteIntArray* array_size = TfLiteIntArrayCreate(values.size());
|
|
int index = 0;
|
|
for (const int size : values) {
|
|
array_size->data[index++] = size;
|
|
}
|
|
return array_size;
|
|
}
|
|
|
|
TfLiteStatus CopyValuesToTensorAndPadOrTruncate(
|
|
const TfLiteTensor& in, const std::vector<int>& encoding_end_offsets,
|
|
int start_offset, TfLiteContext* context, TfLiteTensor* out) {
|
|
TF_LITE_ENSURE_EQ(context, in.dims->size, kEncoderInputRank);
|
|
TF_LITE_ENSURE_EQ(context, in.dims->data[0], kEncoderBatchSize);
|
|
const int output_size = out->dims->data[1];
|
|
int output_offset = 0;
|
|
for (int value_index = 0;
|
|
value_index < encoding_end_offsets.size() && output_offset < output_size;
|
|
++value_index) {
|
|
// Calculate how many elements need to be set with this value.
|
|
// The low bound depends on the offset from the beginning. If this is 0, it
|
|
// means that this value it truncated.
|
|
// The upper bound depends on how many elements are in the output tensor.
|
|
const int from_this_element =
|
|
std::min(std::max(0, encoding_end_offsets[value_index] - start_offset -
|
|
output_offset),
|
|
output_size - output_offset);
|
|
if (from_this_element == 0) {
|
|
continue;
|
|
}
|
|
|
|
switch (in.type) {
|
|
case kTfLiteInt32: {
|
|
std::fill(out->data.i32 + output_offset,
|
|
out->data.i32 + output_offset + from_this_element,
|
|
in.data.i32[value_index]);
|
|
} break;
|
|
case kTfLiteFloat32: {
|
|
std::fill(out->data.f + output_offset,
|
|
out->data.f + output_offset + from_this_element,
|
|
in.data.f[value_index]);
|
|
} break;
|
|
default:
|
|
context->ReportError(
|
|
(context), __FILE__ " Not supported attribute type %d", in.type);
|
|
return kTfLiteError;
|
|
}
|
|
output_offset += from_this_element;
|
|
}
|
|
// Do final padding.
|
|
switch (in.type) {
|
|
case kTfLiteInt32: {
|
|
const int32_t value =
|
|
(output_offset > 0) ? out->data.i32[output_offset - 1] : 0;
|
|
std::fill(out->data.i32 + output_offset, out->data.i32 + output_size,
|
|
value);
|
|
} break;
|
|
case kTfLiteFloat32: {
|
|
const float value =
|
|
(output_offset > 0) ? out->data.f[output_offset - 1] : 0;
|
|
std::fill(out->data.f + output_offset, out->data.f + output_size, value);
|
|
} break;
|
|
default:
|
|
break;
|
|
}
|
|
return kTfLiteOk;
|
|
}
|
|
|
|
TfLiteStatus ResizeOutputTensor(const int max_output_length,
|
|
TfLiteTensor* tensor, TfLiteContext* context) {
|
|
TF_LITE_ENSURE_OK(
|
|
context, context->ResizeTensor(
|
|
context, tensor,
|
|
CreateIntArray({kEncoderBatchSize, max_output_length})));
|
|
return kTfLiteOk;
|
|
}
|
|
|
|
int CopyDataToTensorAndPadOrTruncate(const int32_t max_output_length,
|
|
const std::vector<int32_t>& data,
|
|
const int32_t padding_value,
|
|
TfLiteTensor* output_tensor) {
|
|
const int num_skip =
|
|
std::max(0, static_cast<int>(data.size()) - max_output_length);
|
|
int output_offset = 0;
|
|
int32_t* output_buffer = output_tensor->data.i32;
|
|
for (int i = num_skip; i < data.size(); ++i, ++output_offset) {
|
|
output_buffer[output_offset] = data[i];
|
|
}
|
|
|
|
// Do padding.
|
|
for (; output_offset < max_output_length; ++output_offset) {
|
|
output_buffer[output_offset] = padding_value;
|
|
}
|
|
|
|
// Return number of skipped entries from the beginning.
|
|
return num_skip;
|
|
}
|
|
|
|
} // namespace libtextclassifier3
|