You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
202 lines
6.4 KiB
202 lines
6.4 KiB
7 months ago
|
/*
|
||
|
* Copyright (C) 2018 The Android Open Source Project
|
||
|
*
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
*/
|
||
|
|
||
|
#include "lang_id/common/flatbuffers/model-utils.h"
|
||
|
|
||
|
#include <string.h>
|
||
|
|
||
|
#include <string>
|
||
|
|
||
|
#include "lang_id/common/lite_base/logging.h"
|
||
|
#include "lang_id/common/math/checksum.h"
|
||
|
|
||
|
namespace libtextclassifier3 {
|
||
|
namespace saft_fbs {
|
||
|
|
||
|
bool ClearlyFailsChecksum(const Model &model) {
|
||
|
if (!flatbuffers::IsFieldPresent(&model, Model::VT_CRC32)) {
|
||
|
SAFTM_LOG(WARNING)
|
||
|
<< "No CRC32, most likely an old model; skip CRC32 check";
|
||
|
return false;
|
||
|
}
|
||
|
const mobile::uint32 expected_crc32 = model.crc32();
|
||
|
const mobile::uint32 actual_crc32 = ComputeCrc2Checksum(&model);
|
||
|
if (actual_crc32 != expected_crc32) {
|
||
|
SAFTM_LOG(ERROR) << "Corrupt model: different CRC32: " << actual_crc32
|
||
|
<< " vs " << expected_crc32;
|
||
|
return true;
|
||
|
}
|
||
|
SAFTM_DLOG(INFO) << "Successfully checked CRC32 " << actual_crc32;
|
||
|
return false;
|
||
|
}
|
||
|
|
||
|
// Interprets the |num_bytes| bytes starting at |data| as a Model flatbuffer.
//
// Returns the Model pointer obtained from GetModel() on those bytes, or nullptr
// if any of the following holds:
//   - |data| is nullptr or |num_bytes| is 0;
//   - the flatbuffers verifier rejects the buffer;
//   - GetModel() yields nullptr;
//   - ClearlyFailsChecksum() reports a CRC32 mismatch.
const Model *GetVerifiedModelFromBytes(const char *data, size_t num_bytes) {
  const bool is_empty = (data == nullptr) || (num_bytes == 0);
  if (is_empty) {
    SAFTM_LOG(ERROR) << "GetModel called on an empty sequence of bytes";
    return nullptr;
  }

  const uint8_t *buffer = reinterpret_cast<const uint8_t *>(data);
  flatbuffers::Verifier verifier(buffer, num_bytes);
  if (!VerifyModelBuffer(verifier)) {
    SAFTM_LOG(ERROR) << "Not a valid Model flatbuffer";
    return nullptr;
  }

  const Model *model = GetModel(buffer);
  // Short-circuit keeps ClearlyFailsChecksum() from seeing a null model.
  const bool usable = (model != nullptr) && !ClearlyFailsChecksum(*model);
  return usable ? model : nullptr;
}
|
||
|
|
||
|
// Returns the first non-null input of |model| whose name equals |name|.
//
// Returns nullptr (after logging an error) if |model| is nullptr or if it has
// no list of inputs.  Returns nullptr without logging if no input matches.
// Inputs that are null or have a null name are skipped.
const ModelInput *GetInputByName(const Model *model, const std::string &name) {
  if (model == nullptr) {
    SAFTM_LOG(ERROR) << "GetInputByName called with model == nullptr";
    return nullptr;
  }

  const auto *input_list = model->inputs();
  if (input_list == nullptr) {
    // A well-formed model always carries the list of inputs, even when that
    // list is empty.
    SAFTM_LOG(ERROR) << "null inputs";
    return nullptr;
  }

  for (const ModelInput *candidate : *input_list) {
    if (candidate == nullptr) {
      continue;
    }
    const flatbuffers::String *candidate_name = candidate->name();
    if (candidate_name == nullptr) {
      continue;
    }
    if (candidate_name->str() == name) {
      return candidate;
    }
  }
  return nullptr;
}
|
||
|
|
||
|
// Returns a StringPiece built from the data pointer and size of the byte
// vector stored in |input| (no bytes are copied here).
//
// If |input| is nullptr or has no data vector, logs an error and returns
// StringPiece(nullptr, 0).
mobile::StringPiece GetInputBytes(const ModelInput *input) {
  // Fetch the data vector once; a null |input| is treated like missing data.
  const flatbuffers::Vector<uint8_t> *input_data =
      (input == nullptr) ? nullptr : input->data();
  if (input_data == nullptr) {
    SAFTM_LOG(ERROR) << "ModelInput has no content";
    return mobile::StringPiece(nullptr, 0);
  }
  return mobile::StringPiece(reinterpret_cast<const char *>(input_data->data()),
                             input_data->size());
}
|
||
|
|
||
|
bool FillParameters(const Model &model, mobile::TaskContext *context) {
|
||
|
if (context == nullptr) {
|
||
|
SAFTM_LOG(ERROR) << "null context";
|
||
|
return false;
|
||
|
}
|
||
|
const auto *parameters = model.parameters();
|
||
|
if (parameters == nullptr) {
|
||
|
// We should always have a list of parameters; maybe an empty one, if no
|
||
|
// parameters, but the list should be there.
|
||
|
SAFTM_LOG(ERROR) << "null list of parameters";
|
||
|
return false;
|
||
|
}
|
||
|
for (const ModelParameter *p : *parameters) {
|
||
|
if (p == nullptr) {
|
||
|
SAFTM_LOG(ERROR) << "null parameter";
|
||
|
return false;
|
||
|
}
|
||
|
if (p->name() == nullptr) {
|
||
|
SAFTM_LOG(ERROR) << "null parameter name";
|
||
|
return false;
|
||
|
}
|
||
|
const std::string name = p->name()->str();
|
||
|
if (name.empty()) {
|
||
|
SAFTM_LOG(ERROR) << "empty parameter name";
|
||
|
return false;
|
||
|
}
|
||
|
if (p->value() == nullptr) {
|
||
|
SAFTM_LOG(ERROR) << "null parameter name";
|
||
|
return false;
|
||
|
}
|
||
|
context->SetParameter(name, p->value()->str());
|
||
|
}
|
||
|
return true;
|
||
|
}
|
||
|
|
||
|
namespace {
|
||
|
// Updates |*crc| with the information from |s|. Auxiliary for
|
||
|
// ComputeCrc2Checksum.
|
||
|
//
|
||
|
// The bytes from |info| are also used to update the CRC32 checksum. |info|
|
||
|
// should be a brief tag that indicates what |s| represents. The idea is to add
|
||
|
// some structure to the information that goes into the CRC32 computation.
|
||
|
template <typename T>
|
||
|
void UpdateCrc(mobile::Crc32 *crc, const flatbuffers::Vector<T> *s,
|
||
|
mobile::StringPiece info) {
|
||
|
crc->Update("|");
|
||
|
crc->Update(info.data(), info.size());
|
||
|
crc->Update(":");
|
||
|
if (s == nullptr) {
|
||
|
crc->Update("empty");
|
||
|
} else {
|
||
|
crc->Update(reinterpret_cast<const char *>(s->data()),
|
||
|
s->size() * sizeof(T));
|
||
|
}
|
||
|
}
|
||
|
} // namespace
|
||
|
|
||
|
// Computes a CRC32 checksum over selected content of |model|: for each
// non-null parameter its name and value, and for each non-null input its name,
// type, sub-type and data.  The crc32 field of the model is not part of the
// computation.
//
// If |model| is nullptr, returns the value of a Crc32 that received no
// updates.
mobile::uint32 ComputeCrc2Checksum(const Model *model) {
  // Implementation note (from salcianu@): the first idea was to compute a
  // CRC32 checksum of the raw model bytes.  Unfortunately, the expected
  // checksum lives among those bytes too, and since we don't control the
  // flatbuffer format, we can't "arrange" for it to sit at the head / tail of
  // the buffer.  Instead, we walk |model| and feed into the CRC32 computation
  // only the parts we are interested in (which excludes the crc32 field).
  //
  // Note: storing the checksum outside the Model would be too disruptive for
  // the way we currently ship our models.
  mobile::Crc32 crc;
  if (model == nullptr) {
    return crc.Get();
  }

  // Section 1: parameters.
  crc.Update("|Parameters:");
  const auto *parameter_list = model->parameters();
  if (parameter_list != nullptr) {
    for (const ModelParameter *parameter : *parameter_list) {
      if (parameter == nullptr) {
        continue;
      }
      UpdateCrc(&crc, parameter->name(), "name");
      UpdateCrc(&crc, parameter->value(), "value");
    }
  }

  // Section 2: inputs.
  crc.Update("|Inputs:");
  const auto *input_list = model->inputs();
  if (input_list != nullptr) {
    for (const ModelInput *model_input : *input_list) {
      if (model_input == nullptr) {
        continue;
      }
      UpdateCrc(&crc, model_input->name(), "name");
      UpdateCrc(&crc, model_input->type(), "type");
      UpdateCrc(&crc, model_input->sub_type(), "sub-type");
      UpdateCrc(&crc, model_input->data(), "data");
    }
  }

  return crc.Get();
}
|
||
|
|
||
|
} // namespace saft_fbs
|
||
|
}  // namespace libtextclassifier3
|