You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
207 lines
8.1 KiB
207 lines
8.1 KiB
/*
|
|
* Copyright (C) 2018 The Android Open Source Project
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include "utils/testing/annotator.h"
|
|
|
|
#include "utils/flatbuffers/mutable.h"
|
|
#include "flatbuffers/reflection.h"
|
|
|
|
namespace libtextclassifier3 {
|
|
|
|
std::string FirstResult(const std::vector<ClassificationResult>& results) {
|
|
if (results.empty()) {
|
|
return "<INVALID RESULTS>";
|
|
}
|
|
return results[0].collection;
|
|
}
|
|
|
|
std::string ReadFile(const std::string& file_name) {
|
|
std::ifstream file_stream(file_name);
|
|
return std::string(std::istreambuf_iterator<char>(file_stream), {});
|
|
}
|
|
|
|
std::unique_ptr<RegexModel_::PatternT> MakePattern(
|
|
const std::string& collection_name, const std::string& pattern,
|
|
const bool enabled_for_classification, const bool enabled_for_selection,
|
|
const bool enabled_for_annotation, const float score,
|
|
const float priority_score) {
|
|
std::unique_ptr<RegexModel_::PatternT> result(new RegexModel_::PatternT);
|
|
result->collection_name = collection_name;
|
|
result->pattern = pattern;
|
|
// We cannot directly operate with |= on the flag, so use an int here.
|
|
int enabled_modes = ModeFlag_NONE;
|
|
if (enabled_for_annotation) enabled_modes |= ModeFlag_ANNOTATION;
|
|
if (enabled_for_classification) enabled_modes |= ModeFlag_CLASSIFICATION;
|
|
if (enabled_for_selection) enabled_modes |= ModeFlag_SELECTION;
|
|
result->enabled_modes = static_cast<ModeFlag>(enabled_modes);
|
|
result->target_classification_score = score;
|
|
result->priority_score = priority_score;
|
|
return result;
|
|
}
|
|
|
|
// Shortcut function that doesn't need to specify the priority score.
|
|
std::unique_ptr<RegexModel_::PatternT> MakePattern(
|
|
const std::string& collection_name, const std::string& pattern,
|
|
const bool enabled_for_classification, const bool enabled_for_selection,
|
|
const bool enabled_for_annotation, const float score) {
|
|
return MakePattern(collection_name, pattern, enabled_for_classification,
|
|
enabled_for_selection, enabled_for_annotation,
|
|
/*score=*/score,
|
|
/*priority_score=*/score);
|
|
}
|
|
|
|
void AddTestRegexModel(ModelT* unpacked_model) {
|
|
// Add test regex models.
|
|
unpacked_model->regex_model->patterns.push_back(MakePattern(
|
|
"person_with_age", "(Barack) (?:(Obama) )?is (\\d+) years old",
|
|
/*enabled_for_classification=*/true,
|
|
/*enabled_for_selection=*/true, /*enabled_for_annotation=*/true, 1.0));
|
|
|
|
// Use meta data to generate custom serialized entity data.
|
|
MutableFlatbufferBuilder entity_data_builder(
|
|
flatbuffers::GetRoot<reflection::Schema>(
|
|
unpacked_model->entity_data_schema.data()));
|
|
RegexModel_::PatternT* pattern =
|
|
unpacked_model->regex_model->patterns.back().get();
|
|
|
|
{
|
|
std::unique_ptr<MutableFlatbuffer> entity_data =
|
|
entity_data_builder.NewRoot();
|
|
entity_data->Set("is_alive", true);
|
|
pattern->serialized_entity_data = entity_data->Serialize();
|
|
}
|
|
pattern->capturing_group.emplace_back(new CapturingGroupT);
|
|
pattern->capturing_group.emplace_back(new CapturingGroupT);
|
|
pattern->capturing_group.emplace_back(new CapturingGroupT);
|
|
pattern->capturing_group.emplace_back(new CapturingGroupT);
|
|
// Group 0 is the full match, capturing groups starting at 1.
|
|
pattern->capturing_group[1]->entity_field_path.reset(
|
|
new FlatbufferFieldPathT);
|
|
pattern->capturing_group[1]->entity_field_path->field.emplace_back(
|
|
new FlatbufferFieldT);
|
|
pattern->capturing_group[1]->entity_field_path->field.back()->field_name =
|
|
"first_name";
|
|
pattern->capturing_group[2]->entity_field_path.reset(
|
|
new FlatbufferFieldPathT);
|
|
pattern->capturing_group[2]->entity_field_path->field.emplace_back(
|
|
new FlatbufferFieldT);
|
|
pattern->capturing_group[2]->entity_field_path->field.back()->field_name =
|
|
"last_name";
|
|
// Set `former_us_president` field if we match Obama.
|
|
{
|
|
std::unique_ptr<MutableFlatbuffer> entity_data =
|
|
entity_data_builder.NewRoot();
|
|
entity_data->Set("former_us_president", true);
|
|
pattern->capturing_group[2]->serialized_entity_data =
|
|
entity_data->Serialize();
|
|
}
|
|
pattern->capturing_group[3]->entity_field_path.reset(
|
|
new FlatbufferFieldPathT);
|
|
pattern->capturing_group[3]->entity_field_path->field.emplace_back(
|
|
new FlatbufferFieldT);
|
|
pattern->capturing_group[3]->entity_field_path->field.back()->field_name =
|
|
"age";
|
|
}
|
|
|
|
std::string CreateEmptyModel(
|
|
const std::function<void(ModelT* model)> model_update_fn) {
|
|
ModelT model;
|
|
model_update_fn(&model);
|
|
|
|
flatbuffers::FlatBufferBuilder builder;
|
|
FinishModelBuffer(builder, Model::Pack(builder, &model));
|
|
return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
|
|
builder.GetSize());
|
|
}
|
|
|
|
// Create fake entity data schema meta data.
|
|
void AddTestEntitySchemaData(ModelT* unpacked_model) {
|
|
// Cannot use object oriented API here as that is not available for the
|
|
// reflection schema.
|
|
flatbuffers::FlatBufferBuilder schema_builder;
|
|
std::vector<flatbuffers::Offset<reflection::Field>> fields = {
|
|
reflection::CreateField(
|
|
schema_builder,
|
|
/*name=*/schema_builder.CreateString("first_name"),
|
|
/*type=*/
|
|
reflection::CreateType(schema_builder,
|
|
/*base_type=*/reflection::String),
|
|
/*id=*/0,
|
|
/*offset=*/4),
|
|
reflection::CreateField(
|
|
schema_builder,
|
|
/*name=*/schema_builder.CreateString("is_alive"),
|
|
/*type=*/
|
|
reflection::CreateType(schema_builder,
|
|
/*base_type=*/reflection::Bool),
|
|
/*id=*/1,
|
|
/*offset=*/6),
|
|
reflection::CreateField(
|
|
schema_builder,
|
|
/*name=*/schema_builder.CreateString("last_name"),
|
|
/*type=*/
|
|
reflection::CreateType(schema_builder,
|
|
/*base_type=*/reflection::String),
|
|
/*id=*/2,
|
|
/*offset=*/8),
|
|
reflection::CreateField(
|
|
schema_builder,
|
|
/*name=*/schema_builder.CreateString("age"),
|
|
/*type=*/
|
|
reflection::CreateType(schema_builder,
|
|
/*base_type=*/reflection::Int),
|
|
/*id=*/3,
|
|
/*offset=*/10),
|
|
reflection::CreateField(
|
|
schema_builder,
|
|
/*name=*/schema_builder.CreateString("former_us_president"),
|
|
/*type=*/
|
|
reflection::CreateType(schema_builder,
|
|
/*base_type=*/reflection::Bool),
|
|
/*id=*/4,
|
|
/*offset=*/12)};
|
|
std::vector<flatbuffers::Offset<reflection::Enum>> enums;
|
|
std::vector<flatbuffers::Offset<reflection::Object>> objects = {
|
|
reflection::CreateObject(
|
|
schema_builder,
|
|
/*name=*/schema_builder.CreateString("EntityData"),
|
|
/*fields=*/
|
|
schema_builder.CreateVectorOfSortedTables(&fields))};
|
|
schema_builder.Finish(reflection::CreateSchema(
|
|
schema_builder, schema_builder.CreateVectorOfSortedTables(&objects),
|
|
schema_builder.CreateVectorOfSortedTables(&enums),
|
|
/*(unused) file_ident=*/0,
|
|
/*(unused) file_ext=*/0,
|
|
/*root_table*/ objects[0]));
|
|
|
|
unpacked_model->entity_data_schema.assign(
|
|
schema_builder.GetBufferPointer(),
|
|
schema_builder.GetBufferPointer() + schema_builder.GetSize());
|
|
}
|
|
|
|
AnnotatedSpan MakeAnnotatedSpan(CodepointSpan span,
|
|
const std::string& collection,
|
|
const float score,
|
|
AnnotatedSpan::Source source) {
|
|
AnnotatedSpan result;
|
|
result.span = span;
|
|
result.classification.push_back({collection, score});
|
|
result.source = source;
|
|
return result;
|
|
}
|
|
|
|
} // namespace libtextclassifier3
|