/* * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "annotator/annotator_test-include.h" #include #include #include #include #include "annotator/annotator.h" #include "annotator/collections.h" #include "annotator/model_generated.h" #include "annotator/test-utils.h" #include "annotator/types-test-util.h" #include "annotator/types.h" #include "utils/grammar/utils/locale-shard-map.h" #include "utils/grammar/utils/rules.h" #include "utils/testing/annotator.h" #include "lang_id/fb_model/lang-id-from-fb.h" #include "lang_id/lang-id.h" namespace libtextclassifier3 { namespace test_internal { using ::testing::Contains; using ::testing::ElementsAre; using ::testing::ElementsAreArray; using ::testing::Eq; using ::testing::IsEmpty; using ::testing::UnorderedElementsAreArray; std::string GetTestModelPath() { return GetModelPath() + "test_model.fb"; } std::string GetModelWithVocabPath() { return GetModelPath() + "test_vocab_model.fb"; } std::string GetTestModelWithDatetimeRegEx() { std::string model_buffer = ReadFile(GetTestModelPath()); model_buffer = ModifyAnnotatorModel(model_buffer, [](ModelT* model) { model->datetime_grammar_model.reset(nullptr); }); return model_buffer; } void ExpectFirstEntityIsMoney(const std::vector& result, const std::string& currency, const std::string& amount, const int whole_part, const int decimal_part, const int nanos) { ASSERT_GT(result.size(), 0); ASSERT_GT(result[0].classification.size(), 0); 
ASSERT_EQ(result[0].classification[0].collection, "money"); const EntityData* entity_data = GetEntityData(result[0].classification[0].serialized_entity_data.data()); ASSERT_NE(entity_data, nullptr); ASSERT_NE(entity_data->money(), nullptr); EXPECT_EQ(entity_data->money()->unnormalized_currency()->str(), currency); EXPECT_EQ(entity_data->money()->unnormalized_amount()->str(), amount); EXPECT_EQ(entity_data->money()->amount_whole_part(), whole_part); EXPECT_EQ(entity_data->money()->amount_decimal_part(), decimal_part); EXPECT_EQ(entity_data->money()->nanos(), nanos); } TEST_F(AnnotatorTest, EmbeddingExecutorLoadingFails) { std::unique_ptr classifier = Annotator::FromPath(GetModelPath() + "wrong_embeddings.fb", unilib_.get(), calendarlib_.get()); EXPECT_FALSE(classifier); } void VerifyClassifyText(const Annotator* classifier) { ASSERT_TRUE(classifier); EXPECT_EQ("other", FirstResult(classifier->ClassifyText( "this afternoon Barack Obama gave a speech at", {15, 27}))); EXPECT_EQ("phone", FirstResult(classifier->ClassifyText( "Call me at (800) 123-456 today", {11, 24}))); // More lines. EXPECT_EQ("other", FirstResult(classifier->ClassifyText( "this afternoon Barack Obama gave a speech at|Visit " "www.google.com every today!|Call me at (800) 123-456 today.", {15, 27}))); EXPECT_EQ("phone", FirstResult(classifier->ClassifyText( "this afternoon Barack Obama gave a speech at|Visit " "www.google.com every today!|Call me at (800) 123-456 today.", {90, 103}))); // Single word. EXPECT_EQ("other", FirstResult(classifier->ClassifyText("obama", {0, 5}))); EXPECT_EQ("other", FirstResult(classifier->ClassifyText("asdf", {0, 4}))); // Junk. These should not crash the test. 
classifier->ClassifyText("", {0, 0}); classifier->ClassifyText("asdf", {0, 0}); classifier->ClassifyText("asdf", {0, 27}); classifier->ClassifyText("asdf", {-30, 300}); classifier->ClassifyText("asdf", {-10, -1}); classifier->ClassifyText("asdf", {100, 17}); classifier->ClassifyText("a\n\n\n\nx x x\n\n\n\n\n\n", {1, 5}); // Test invalid utf8 input. EXPECT_EQ("", FirstResult(classifier->ClassifyText( "\xf0\x9f\x98\x8b\x8b", {0, 0}))); } TEST_F(AnnotatorTest, ClassifyText) { std::unique_ptr classifier = Annotator::FromPath( GetTestModelPath(), unilib_.get(), calendarlib_.get()); VerifyClassifyText(classifier.get()); } TEST_F(AnnotatorTest, ClassifyTextLocalesAndDictionary) { std::unique_ptr classifier = Annotator::FromPath( GetTestModelPath(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); EXPECT_EQ("other", FirstResult(classifier->ClassifyText("isotope", {0, 7}))); ClassificationOptions classification_options; classification_options.detected_text_language_tags = "en"; EXPECT_EQ("dictionary", FirstResult(classifier->ClassifyText( "isotope", {0, 7}, classification_options))); classification_options.detected_text_language_tags = "uz"; EXPECT_EQ("other", FirstResult(classifier->ClassifyText( "isotope", {0, 7}, classification_options))); } TEST_F(AnnotatorTest, ClassifyTextUseVocabAnnotatorWithoutVocabModel) { std::unique_ptr classifier = Annotator::FromPath( GetTestModelPath(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); ClassificationOptions classification_options; classification_options.detected_text_language_tags = "en"; classification_options.use_vocab_annotator = true; EXPECT_EQ("dictionary", FirstResult(classifier->ClassifyText( "isotope", {0, 7}, classification_options))); } #ifdef TC3_VOCAB_ANNOTATOR_IMPL TEST_F(AnnotatorTest, ClassifyTextWithVocabModel) { std::unique_ptr classifier = Annotator::FromPath( GetModelWithVocabPath(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); ClassificationOptions 
classification_options; classification_options.detected_text_language_tags = "en"; // The FFModel model does not annotate "integrity" as "dictionary", but the // vocab annotator does. So we can use that to check if the vocab annotator is // in use. classification_options.use_vocab_annotator = true; EXPECT_EQ("dictionary", FirstResult(classifier->ClassifyText( "integrity", {0, 9}, classification_options))); classification_options.use_vocab_annotator = false; EXPECT_EQ("other", FirstResult(classifier->ClassifyText( "integrity", {0, 9}, classification_options))); } #endif // TC3_VOCAB_ANNOTATOR_IMPL TEST_F(AnnotatorTest, ClassifyTextDisabledFail) { const std::string test_model = ReadFile(GetTestModelPath()); std::unique_ptr unpacked_model = UnPackModel(test_model.c_str()); TC3_CHECK(unpacked_model != nullptr); unpacked_model->classification_model.clear(); unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT); unpacked_model->triggering_options->enabled_modes = ModeFlag_SELECTION; flatbuffers::FlatBufferBuilder builder; FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); std::unique_ptr classifier = Annotator::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get(), calendarlib_.get()); // The classification model is still needed for selection scores. 
ASSERT_FALSE(classifier); } TEST_F(AnnotatorTest, ClassifyTextDisabled) { const std::string test_model = ReadFile(GetTestModelPath()); std::unique_ptr unpacked_model = UnPackModel(test_model.c_str()); unpacked_model->enabled_modes = ModeFlag_ANNOTATION_AND_SELECTION; flatbuffers::FlatBufferBuilder builder; FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); std::unique_ptr classifier = Annotator::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); EXPECT_THAT( classifier->ClassifyText("Call me at (800) 123-456 today", {11, 24}), IsEmpty()); } TEST_F(AnnotatorTest, ClassifyTextFilteredCollections) { const std::string test_model = ReadFile(GetTestModelPath()); std::unique_ptr classifier = Annotator::FromUnownedBuffer( test_model.c_str(), test_model.size(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); EXPECT_EQ("phone", FirstResult(classifier->ClassifyText( "Call me at (800) 123-456 today", {11, 24}))); std::unique_ptr unpacked_model = UnPackModel(test_model.c_str()); unpacked_model->output_options.reset(new OutputOptionsT); // Disable phone classification unpacked_model->output_options->filtered_collections_classification.push_back( "phone"); flatbuffers::FlatBufferBuilder builder; FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); classifier = Annotator::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); EXPECT_EQ("other", FirstResult(classifier->ClassifyText( "Call me at (800) 123-456 today", {11, 24}))); // Check that the address classification still passes. 
EXPECT_EQ("address", FirstResult(classifier->ClassifyText( "350 Third Street, Cambridge", {0, 27}))); } TEST_F(AnnotatorTest, ClassifyTextRegularExpression) { const std::string test_model = ReadFile(GetTestModelPath()); std::unique_ptr unpacked_model = UnPackModel(test_model.c_str()); // Add test regex models. unpacked_model->regex_model->patterns.push_back(MakePattern( "person", "Barack Obama", /*enabled_for_classification=*/true, /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 1.0)); unpacked_model->regex_model->patterns.push_back(MakePattern( "flight", "[a-zA-Z]{2}\\d{2,4}", /*enabled_for_classification=*/true, /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 0.5)); std::unique_ptr verified_pattern = MakePattern("payment_card", "\\d{4}(?: \\d{4}){3}", /*enabled_for_classification=*/true, /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 1.0); verified_pattern->verification_options.reset(new VerificationOptionsT); verified_pattern->verification_options->verify_luhn_checksum = true; unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern)); flatbuffers::FlatBufferBuilder builder; FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); std::unique_ptr classifier = Annotator::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); EXPECT_EQ("flight", FirstResult(classifier->ClassifyText( "Your flight LX373 is delayed by 3 hours.", {12, 17}))); EXPECT_EQ("person", FirstResult(classifier->ClassifyText( "this afternoon Barack Obama gave a speech at", {15, 27}))); EXPECT_EQ("email", FirstResult(classifier->ClassifyText("you@android.com", {0, 15}))); EXPECT_EQ("email", FirstResult(classifier->ClassifyText( "Contact me at you@android.com", {14, 29}))); EXPECT_EQ("url", FirstResult(classifier->ClassifyText( "Visit www.google.com every today!", {6, 20}))); EXPECT_EQ("flight", 
FirstResult(classifier->ClassifyText("LX 37", {0, 5}))); EXPECT_EQ("flight", FirstResult(classifier->ClassifyText("flight LX 37 abcd", {7, 12}))); EXPECT_EQ("payment_card", FirstResult(classifier->ClassifyText( "cc: 4012 8888 8888 1881", {4, 23}))); EXPECT_EQ("payment_card", FirstResult(classifier->ClassifyText( "2221 0067 4735 6281", {0, 19}))); // Luhn check fails. EXPECT_EQ("other", FirstResult(classifier->ClassifyText("2221 0067 4735 6282", {0, 19}))); // More lines. EXPECT_EQ("url", FirstResult(classifier->ClassifyText( "this afternoon Barack Obama gave a speech at|Visit " "www.google.com every today!|Call me at (800) 123-456 today.", {51, 65}))); } #ifndef TC3_DISABLE_LUA TEST_F(AnnotatorTest, ClassifyTextRegularExpressionLuaVerification) { const std::string test_model = ReadFile(GetTestModelPath()); std::unique_ptr unpacked_model = UnPackModel(test_model.c_str()); // Add test regex models. std::unique_ptr verified_pattern = MakePattern("parcel_tracking", "((\\d{2})-00-\\d{6}-\\d{8})", /*enabled_for_classification=*/true, /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 1.0); verified_pattern->verification_options.reset(new VerificationOptionsT); verified_pattern->verification_options->lua_verifier = 0; unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern)); unpacked_model->regex_model->lua_verifier.push_back( "return match[2].text==\"99\""); flatbuffers::FlatBufferBuilder builder; FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); std::unique_ptr classifier = Annotator::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); // Custom rule triggers and is correctly verified. EXPECT_EQ("parcel_tracking", FirstResult(classifier->ClassifyText( "99-00-123456-12345678", {0, 21}))); // Custom verification fails. 
EXPECT_EQ("other", FirstResult(classifier->ClassifyText( "90-00-123456-12345678", {0, 21}))); } #endif // TC3_DISABLE_LUA TEST_F(AnnotatorTest, ClassifyTextRegularExpressionEntityData) { const std::string test_model = ReadFile(GetTestModelPath()); std::unique_ptr unpacked_model = UnPackModel(test_model.c_str()); // Add fake entity schema metadata. AddTestEntitySchemaData(unpacked_model.get()); AddTestRegexModel(unpacked_model.get()); flatbuffers::FlatBufferBuilder builder; FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); std::unique_ptr classifier = Annotator::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); // Check with full name. { auto classifications = classifier->ClassifyText("Barack Obama is 57 years old", {0, 28}); EXPECT_EQ(1, classifications.size()); EXPECT_EQ("person_with_age", classifications[0].collection); // Check entity data. const flatbuffers::Table* entity = flatbuffers::GetAnyRoot(reinterpret_cast( classifications[0].serialized_entity_data.data())); EXPECT_EQ( entity->GetPointer(/*field=*/4)->str(), "Barack"); EXPECT_EQ( entity->GetPointer(/*field=*/8)->str(), "Obama"); // Check `age`. EXPECT_EQ(entity->GetField(/*field=*/10, /*defaultval=*/0), 57); // Check `is_alive`. EXPECT_TRUE(entity->GetField(/*field=*/6, /*defaultval=*/false)); // Check `former_us_president`. EXPECT_TRUE(entity->GetField(/*field=*/12, /*defaultval=*/false)); } // Check only with first name. { auto classifications = classifier->ClassifyText("Barack is 57 years old", {0, 22}); EXPECT_EQ(1, classifications.size()); EXPECT_EQ("person_with_age", classifications[0].collection); // Check entity data. const flatbuffers::Table* entity = flatbuffers::GetAnyRoot(reinterpret_cast( classifications[0].serialized_entity_data.data())); EXPECT_EQ( entity->GetPointer(/*field=*/4)->str(), "Barack"); // Check `age`. 
EXPECT_EQ(entity->GetField(/*field=*/10, /*defaultval=*/0), 57); // Check `is_alive`. EXPECT_TRUE(entity->GetField(/*field=*/6, /*defaultval=*/false)); // Check `former_us_president`. EXPECT_FALSE(entity->GetField(/*field=*/12, /*defaultval=*/false)); } } TEST_F(AnnotatorTest, ClassifyTextRegularExpressionEntityDataNormalization) { const std::string test_model = ReadFile(GetTestModelPath()); std::unique_ptr unpacked_model = UnPackModel(test_model.c_str()); // Add fake entity schema metadata. AddTestEntitySchemaData(unpacked_model.get()); AddTestRegexModel(unpacked_model.get()); // Upper case last name as post-processing. RegexModel_::PatternT* pattern = unpacked_model->regex_model->patterns.back().get(); pattern->capturing_group[2]->normalization_options.reset( new NormalizationOptionsT); pattern->capturing_group[2] ->normalization_options->codepointwise_normalization = NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE; flatbuffers::FlatBufferBuilder builder; FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); std::unique_ptr classifier = Annotator::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); auto classifications = classifier->ClassifyText("Barack Obama is 57 years old", {0, 28}); EXPECT_EQ(1, classifications.size()); EXPECT_EQ("person_with_age", classifications[0].collection); // Check entity data normalization. const flatbuffers::Table* entity = flatbuffers::GetAnyRoot(reinterpret_cast( classifications[0].serialized_entity_data.data())); EXPECT_EQ(entity->GetPointer(/*field=*/8)->str(), "OBAMA"); } TEST_F(AnnotatorTest, ClassifyTextPriorityResolution) { const std::string test_model = ReadFile(GetTestModelPath()); std::unique_ptr unpacked_model = UnPackModel(test_model.c_str()); TC3_CHECK(libtextclassifier3::DecompressModel(unpacked_model.get())); // Add test regex models. 
unpacked_model->regex_model->patterns.clear(); unpacked_model->regex_model->patterns.push_back(MakePattern( "flight1", "[a-zA-Z]{2}\\d{2,4}", /*enabled_for_classification=*/true, /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, /*score=*/1.0, /*priority_score=*/1.0)); unpacked_model->regex_model->patterns.push_back(MakePattern( "flight2", "[a-zA-Z]{2}\\d{2,4}", /*enabled_for_classification=*/true, /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, /*score=*/1.0, /*priority_score=*/0.0)); { flatbuffers::FlatBufferBuilder builder; FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); std::unique_ptr classifier = Annotator::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); EXPECT_EQ("flight1", FirstResult(classifier->ClassifyText( "Your flight LX373 is delayed by 3 hours.", {12, 17}))); } unpacked_model->regex_model->patterns.back()->priority_score = 3.0; { flatbuffers::FlatBufferBuilder builder; FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); std::unique_ptr classifier = Annotator::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); EXPECT_EQ("flight2", FirstResult(classifier->ClassifyText( "Your flight LX373 is delayed by 3 hours.", {12, 17}))); } } TEST_F(AnnotatorTest, AnnotatePriorityResolution) { const std::string test_model = ReadFile(GetTestModelPath()); std::unique_ptr unpacked_model = UnPackModel(test_model.c_str()); TC3_CHECK(libtextclassifier3::DecompressModel(unpacked_model.get())); // Add test regex models. One of them has higher priority score than // the other. We'll test that always the one with higher priority score // ends up winning. 
unpacked_model->regex_model->patterns.clear(); const std::string flight_regex = "([a-zA-Z]{2}\\d{2,4})"; unpacked_model->regex_model->patterns.push_back(MakePattern( "flight", flight_regex, /*enabled_for_classification=*/true, /*enabled_for_selection=*/true, /*enabled_for_annotation=*/true, /*score=*/1.0, /*priority_score=*/1.0)); unpacked_model->regex_model->patterns.push_back(MakePattern( "flight", flight_regex, /*enabled_for_classification=*/true, /*enabled_for_selection=*/true, /*enabled_for_annotation=*/true, /*score=*/1.0, /*priority_score=*/0.0)); // "flight" that wins should have a priority score of 1.0. { flatbuffers::FlatBufferBuilder builder; FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); std::unique_ptr classifier = Annotator::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); const std::vector results = classifier->Annotate("Your flight LX373 is delayed by 3 hours."); ASSERT_THAT(results, Not(IsEmpty())); EXPECT_THAT(results[0].classification, Not(IsEmpty())); EXPECT_GE(results[0].classification[0].priority_score, 0.9); } // When we increase the priority score, the "flight" that wins should have a // priority score of 3.0. 
unpacked_model->regex_model->patterns.back()->priority_score = 3.0; { flatbuffers::FlatBufferBuilder builder; FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); std::unique_ptr classifier = Annotator::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); const std::vector results = classifier->Annotate("Your flight LX373 is delayed by 3 hours."); ASSERT_THAT(results, Not(IsEmpty())); EXPECT_THAT(results[0].classification, Not(IsEmpty())); EXPECT_GE(results[0].classification[0].priority_score, 2.9); } } TEST_F(AnnotatorTest, SuggestSelectionRegularExpression) { const std::string test_model = ReadFile(GetTestModelPath()); std::unique_ptr unpacked_model = UnPackModel(test_model.c_str()); // Add test regex models. unpacked_model->regex_model->patterns.push_back(MakePattern( "person", " (Barack Obama) ", /*enabled_for_classification=*/false, /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0)); unpacked_model->regex_model->patterns.push_back(MakePattern( "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false, /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0)); unpacked_model->regex_model->patterns.back()->priority_score = 1.1; std::unique_ptr verified_pattern = MakePattern("payment_card", "(\\d{4}(?: \\d{4}){3})", /*enabled_for_classification=*/false, /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0); verified_pattern->verification_options.reset(new VerificationOptionsT); verified_pattern->verification_options->verify_luhn_checksum = true; unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern)); flatbuffers::FlatBufferBuilder builder; FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); std::unique_ptr classifier = Annotator::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get(), 
calendarlib_.get()); ASSERT_TRUE(classifier); // Check regular expression selection. EXPECT_EQ(classifier->SuggestSelection( "Your flight MA 0123 is delayed by 3 hours.", {12, 14}), CodepointSpan(12, 19)); EXPECT_EQ(classifier->SuggestSelection( "this afternoon Barack Obama gave a speech at", {15, 21}), CodepointSpan(15, 27)); EXPECT_EQ(classifier->SuggestSelection("cc: 4012 8888 8888 1881", {9, 14}), CodepointSpan(4, 23)); } TEST_F(AnnotatorTest, SuggestSelectionRegularExpressionCustomSelectionBounds) { const std::string test_model = ReadFile(GetTestModelPath()); std::unique_ptr unpacked_model = UnPackModel(test_model.c_str()); // Add test regex models. std::unique_ptr custom_selection_bounds_pattern = MakePattern("date_range", "(?:(?:from )?(\\d{2}\\/\\d{2}\\/\\d{4}) to " "(\\d{2}\\/\\d{2}\\/\\d{4}))|(for ever)", /*enabled_for_classification=*/false, /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0); custom_selection_bounds_pattern->capturing_group.emplace_back( new CapturingGroupT); custom_selection_bounds_pattern->capturing_group.emplace_back( new CapturingGroupT); custom_selection_bounds_pattern->capturing_group.emplace_back( new CapturingGroupT); custom_selection_bounds_pattern->capturing_group.emplace_back( new CapturingGroupT); custom_selection_bounds_pattern->capturing_group[0]->extend_selection = false; custom_selection_bounds_pattern->capturing_group[1]->extend_selection = true; custom_selection_bounds_pattern->capturing_group[2]->extend_selection = true; custom_selection_bounds_pattern->capturing_group[3]->extend_selection = true; unpacked_model->regex_model->patterns.push_back( std::move(custom_selection_bounds_pattern)); flatbuffers::FlatBufferBuilder builder; FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); std::unique_ptr classifier = Annotator::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); // Check regular 
expression selection. EXPECT_EQ(classifier->SuggestSelection("it's from 04/30/1789 to 03/04/1797", {21, 23}), CodepointSpan(10, 34)); EXPECT_EQ(classifier->SuggestSelection("it takes for ever", {9, 12}), CodepointSpan(9, 17)); } TEST_F(AnnotatorTest, SuggestSelectionRegularExpressionConflictsModelWins) { const std::string test_model = ReadFile(GetTestModelPath()); std::unique_ptr unpacked_model = UnPackModel(test_model.c_str()); // Add test regex models. unpacked_model->regex_model->patterns.push_back(MakePattern( "person", " (Barack Obama) ", /*enabled_for_classification=*/false, /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0)); unpacked_model->regex_model->patterns.push_back(MakePattern( "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false, /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0)); unpacked_model->regex_model->patterns.back()->priority_score = 0.5; flatbuffers::FlatBufferBuilder builder; FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); std::unique_ptr classifier = Annotator::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); // Check conflict resolution. EXPECT_EQ( classifier->SuggestSelection( "saw Barack Obama today .. 350 Third Street, Cambridge, MA 0123", {55, 57}), CodepointSpan(26, 62)); } TEST_F(AnnotatorTest, SuggestSelectionRegularExpressionConflictsRegexWins) { const std::string test_model = ReadFile(GetTestModelPath()); std::unique_ptr unpacked_model = UnPackModel(test_model.c_str()); // Add test regex models. 
unpacked_model->regex_model->patterns.push_back(MakePattern( "person", " (Barack Obama) ", /*enabled_for_classification=*/false, /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0)); unpacked_model->regex_model->patterns.push_back(MakePattern( "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false, /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0)); unpacked_model->regex_model->patterns.back()->priority_score = 1.1; flatbuffers::FlatBufferBuilder builder; FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); std::unique_ptr classifier = Annotator::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); // Check conflict resolution. EXPECT_EQ( classifier->SuggestSelection( "saw Barack Obama today .. 350 Third Street, Cambridge, MA 0123", {55, 57}), CodepointSpan(55, 62)); } TEST_F(AnnotatorTest, AnnotateRegex) { const std::string test_model = ReadFile(GetTestModelPath()); std::unique_ptr unpacked_model = UnPackModel(test_model.c_str()); // Add test regex models. 
unpacked_model->regex_model->patterns.push_back(MakePattern( "person", " (Barack Obama) ", /*enabled_for_classification=*/false, /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 1.0)); unpacked_model->regex_model->patterns.push_back(MakePattern( "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false, /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 0.5)); std::unique_ptr verified_pattern = MakePattern("payment_card", "(\\d{4}(?: \\d{4}){3})", /*enabled_for_classification=*/false, /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 1.0); verified_pattern->verification_options.reset(new VerificationOptionsT); verified_pattern->verification_options->verify_luhn_checksum = true; unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern)); flatbuffers::FlatBufferBuilder builder; FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); std::unique_ptr classifier = Annotator::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); const std::string test_string = "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone " "number is 853 225 3556\nand my card is 4012 8888 8888 1881.\n"; EXPECT_THAT(classifier->Annotate(test_string), ElementsAreArray({IsAnnotatedSpan(6, 18, "person"), IsAnnotatedSpan(28, 55, "address"), IsAnnotatedSpan(79, 91, "phone"), IsAnnotatedSpan(107, 126, "payment_card")})); } TEST_F(AnnotatorTest, AnnotatesFlightNumbers) { std::unique_ptr classifier = Annotator::FromPath( GetTestModelPath(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); // ICAO is only used for selected airlines. // Expected: LX373, EZY1234 and U21234. 
const std::string test_string = "flights LX373, SWR373, EZY1234, U21234";
  EXPECT_THAT(classifier->Annotate(test_string),
              ElementsAreArray({IsAnnotatedSpan(8, 13, "flight"),
                                IsAnnotatedSpan(23, 30, "flight"),
                                IsAnnotatedSpan(32, 38, "flight")}));
}

#ifndef TC3_DISABLE_LUA
// Checks that a regex match gated by a Lua verifier script is annotated when
// the script accepts it. Compiled out when TC3_DISABLE_LUA is defined.
TEST_F(AnnotatorTest, AnnotateRegexLuaVerification) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Add test regex models.
  std::unique_ptr<RegexModel_::PatternT> verified_pattern =
      MakePattern("parcel_tracking", "((\\d{2})-00-\\d{6}-\\d{8})",
                  /*enabled_for_classification=*/true,
                  /*enabled_for_selection=*/true,
                  /*enabled_for_annotation=*/true, 1.0);
  verified_pattern->verification_options.reset(new VerificationOptionsT);
  // NOTE(review): presumably an index into `regex_model->lua_verifier`, i.e.
  // the single script pushed below -- confirm against the model schema.
  verified_pattern->verification_options->lua_verifier = 0;
  unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern));
  unpacked_model->regex_model->lua_verifier.push_back(
      "return match[2].text==\"99\"");

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  const std::string test_string =
      "your parcel is on the way: 99-00-123456-12345678";
  EXPECT_THAT(classifier->Annotate(test_string),
              ElementsAreArray({IsAnnotatedSpan(27, 48, "parcel_tracking")}));
}
#endif  // TC3_DISABLE_LUA

// Checks the serialized entity data attached to a regex annotation when
// `is_serialized_entity_data_enabled` is set.
TEST_F(AnnotatorTest, AnnotateTextRegularExpressionEntityData) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Add fake entity schema metadata.
AddTestEntitySchemaData(unpacked_model.get()); AddTestRegexModel(unpacked_model.get()); flatbuffers::FlatBufferBuilder builder; FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); std::unique_ptr classifier = Annotator::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); AnnotationOptions options; options.is_serialized_entity_data_enabled = true; auto annotations = classifier->Annotate("Barack Obama is 57 years old", options); EXPECT_EQ(1, annotations.size()); EXPECT_EQ(1, annotations[0].classification.size()); EXPECT_EQ("person_with_age", annotations[0].classification[0].collection); // Check entity data. const flatbuffers::Table* entity = flatbuffers::GetAnyRoot(reinterpret_cast( annotations[0].classification[0].serialized_entity_data.data())); EXPECT_EQ(entity->GetPointer(/*field=*/4)->str(), "Barack"); EXPECT_EQ(entity->GetPointer(/*field=*/8)->str(), "Obama"); // Check `age`. EXPECT_EQ(entity->GetField(/*field=*/10, /*defaultval=*/0), 57); // Check `is_alive`. EXPECT_TRUE(entity->GetField(/*field=*/6, /*defaultval=*/false)); // Check `former_us_president`. EXPECT_TRUE(entity->GetField(/*field=*/12, /*defaultval=*/false)); } TEST_F(AnnotatorTest, AnnotateTextRegularExpressionEntityDataNormalization) { const std::string test_model = ReadFile(GetTestModelPath()); std::unique_ptr unpacked_model = UnPackModel(test_model.c_str()); // Add fake entity schema metadata. AddTestEntitySchemaData(unpacked_model.get()); AddTestRegexModel(unpacked_model.get()); // Upper case last name as post-processing. 
RegexModel_::PatternT* pattern = unpacked_model->regex_model->patterns.back().get(); pattern->capturing_group[2]->normalization_options.reset( new NormalizationOptionsT); pattern->capturing_group[2] ->normalization_options->codepointwise_normalization = NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE; flatbuffers::FlatBufferBuilder builder; FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); std::unique_ptr classifier = Annotator::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); AnnotationOptions options; options.is_serialized_entity_data_enabled = true; auto annotations = classifier->Annotate("Barack Obama is 57 years old", options); EXPECT_EQ(1, annotations.size()); EXPECT_EQ(1, annotations[0].classification.size()); EXPECT_EQ("person_with_age", annotations[0].classification[0].collection); // Check normalization. const flatbuffers::Table* entity = flatbuffers::GetAnyRoot(reinterpret_cast( annotations[0].classification[0].serialized_entity_data.data())); EXPECT_EQ(entity->GetPointer(/*field=*/8)->str(), "OBAMA"); } TEST_F(AnnotatorTest, AnnotateTextRegularExpressionEntityDataDisabled) { const std::string test_model = ReadFile(GetTestModelPath()); std::unique_ptr unpacked_model = UnPackModel(test_model.c_str()); // Add fake entity schema metadata. 
AddTestEntitySchemaData(unpacked_model.get()); AddTestRegexModel(unpacked_model.get()); flatbuffers::FlatBufferBuilder builder; FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); std::unique_ptr classifier = Annotator::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); AnnotationOptions options; options.is_serialized_entity_data_enabled = false; auto annotations = classifier->Annotate("Barack Obama is 57 years old", options); EXPECT_EQ(1, annotations.size()); EXPECT_EQ(1, annotations[0].classification.size()); EXPECT_EQ("person_with_age", annotations[0].classification[0].collection); // Check entity data. EXPECT_EQ("", annotations[0].classification[0].serialized_entity_data); } TEST_F(AnnotatorTest, PhoneFiltering) { std::unique_ptr classifier = Annotator::FromPath( GetTestModelPath(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); EXPECT_EQ("phone", FirstResult(classifier->ClassifyText( "phone: (123) 456 789", {7, 20}))); EXPECT_EQ("phone", FirstResult(classifier->ClassifyText( "phone: (123) 456 789,0001112", {7, 25}))); EXPECT_EQ("other", FirstResult(classifier->ClassifyText( "phone: (123) 456 789,0001112", {7, 28}))); } TEST_F(AnnotatorTest, SuggestSelection) { std::unique_ptr classifier = Annotator::FromPath( GetTestModelPath(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); EXPECT_EQ(classifier->SuggestSelection( "this afternoon Barack Obama gave a speech at", {15, 21}), CodepointSpan(15, 21)); // Try passing whole string. // If more than 1 token is specified, we should return back what entered. EXPECT_EQ( classifier->SuggestSelection("350 Third Street, Cambridge", {0, 27}), CodepointSpan(0, 27)); // Single letter. EXPECT_EQ(classifier->SuggestSelection("a", {0, 1}), CodepointSpan(0, 1)); // Single word. 
EXPECT_EQ(classifier->SuggestSelection("asdf", {0, 4}), CodepointSpan(0, 4)); EXPECT_EQ( classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}), CodepointSpan(11, 23)); // Unpaired bracket stripping. EXPECT_EQ( classifier->SuggestSelection("call me at (857) 225 3556 today", {12, 14}), CodepointSpan(11, 25)); EXPECT_EQ(classifier->SuggestSelection("call me at (857 today", {12, 14}), CodepointSpan(12, 15)); EXPECT_EQ(classifier->SuggestSelection("call me at 3556) today", {12, 14}), CodepointSpan(11, 15)); EXPECT_EQ(classifier->SuggestSelection("call me at )857( today", {12, 14}), CodepointSpan(12, 15)); // If the resulting selection would be empty, the original span is returned. EXPECT_EQ(classifier->SuggestSelection("call me at )( today", {11, 13}), CodepointSpan(11, 13)); EXPECT_EQ(classifier->SuggestSelection("call me at ( today", {11, 12}), CodepointSpan(11, 12)); EXPECT_EQ(classifier->SuggestSelection("call me at ) today", {11, 12}), CodepointSpan(11, 12)); // If the original span is larger than the found selection, the original span // is returned. EXPECT_EQ( classifier->SuggestSelection("call me at 857 225 3556 today", {5, 24}), CodepointSpan(5, 24)); } TEST_F(AnnotatorTest, SuggestSelectionDisabledFail) { const std::string test_model = ReadFile(GetTestModelPath()); std::unique_ptr unpacked_model = UnPackModel(test_model.c_str()); // Disable the selection model. unpacked_model->selection_model.clear(); unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT); unpacked_model->triggering_options->enabled_modes = ModeFlag_ANNOTATION; flatbuffers::FlatBufferBuilder builder; FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); std::unique_ptr classifier = Annotator::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get(), calendarlib_.get()); // Selection model needs to be present for annotation. 
ASSERT_FALSE(classifier);
}

// Checks the behavior of a classification-only model: selection returns the
// input span, classification still works, and annotation yields nothing.
TEST_F(AnnotatorTest, SuggestSelectionDisabled) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Disable the selection model.
  unpacked_model->selection_model.clear();
  unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
  unpacked_model->triggering_options->enabled_modes = ModeFlag_CLASSIFICATION;
  unpacked_model->enabled_modes = ModeFlag_CLASSIFICATION;

  // Disable the number annotator. With the selection model disabled, there is
  // no feature processor, which is required for the number annotator.
  unpacked_model->number_annotator_options->enabled = false;

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  EXPECT_EQ(
      classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
      CodepointSpan(11, 14));

  EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
                         "call me at (800) 123-456 today", {11, 24})));

  EXPECT_THAT(classifier->Annotate("call me at (800) 123-456 today"),
              IsEmpty());
}

// Checks that collections listed in `filtered_collections_selection` are not
// expanded by SuggestSelection() while other collections still are.
TEST_F(AnnotatorTest, SuggestSelectionFilteredCollections) {
  const std::string test_model = ReadFile(GetTestModelPath());

  // Baseline: the unmodified model expands the phone number.
  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      test_model.c_str(), test_model.size(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  EXPECT_EQ(
      classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
      CodepointSpan(11, 23));

  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  unpacked_model->output_options.reset(new OutputOptionsT);

  // Disable phone selection
  unpacked_model->output_options->filtered_collections_selection.push_back(
      "phone");
  // We need to force this for filtering.
unpacked_model->selection_options->always_classify_suggested_selection = true; flatbuffers::FlatBufferBuilder builder; FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); classifier = Annotator::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); EXPECT_EQ( classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}), CodepointSpan(11, 14)); // Address selection should still work. EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {4, 9}), CodepointSpan(0, 27)); } TEST_F(AnnotatorTest, SuggestSelectionsAreSymmetric) { std::unique_ptr classifier = Annotator::FromPath( GetTestModelPath(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {0, 3}), CodepointSpan(0, 27)); EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {4, 9}), CodepointSpan(0, 27)); EXPECT_EQ( classifier->SuggestSelection("350 Third Street, Cambridge", {10, 16}), CodepointSpan(0, 27)); EXPECT_EQ(classifier->SuggestSelection("a\nb\nc\n350 Third Street, Cambridge", {16, 22}), CodepointSpan(6, 33)); } TEST_F(AnnotatorTest, SuggestSelectionWithNewLine) { std::unique_ptr classifier = Annotator::FromPath( GetTestModelPath(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); EXPECT_EQ(classifier->SuggestSelection("abc\n857 225 3556", {4, 7}), CodepointSpan(4, 16)); EXPECT_EQ(classifier->SuggestSelection("857 225 3556\nabc", {0, 3}), CodepointSpan(0, 12)); SelectionOptions options; EXPECT_EQ(classifier->SuggestSelection("857 225\n3556\nabc", {0, 3}, options), CodepointSpan(0, 12)); } TEST_F(AnnotatorTest, SuggestSelectionWithPunctuation) { std::unique_ptr classifier = Annotator::FromPath( GetTestModelPath(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); // From the right. 
EXPECT_EQ(classifier->SuggestSelection( "this afternoon BarackObama, gave a speech at", {15, 26}), CodepointSpan(15, 26)); // From the right multiple. EXPECT_EQ(classifier->SuggestSelection( "this afternoon BarackObama,.,.,, gave a speech at", {15, 26}), CodepointSpan(15, 26)); // From the left multiple. EXPECT_EQ(classifier->SuggestSelection( "this afternoon ,.,.,,BarackObama gave a speech at", {21, 32}), CodepointSpan(21, 32)); // From both sides. EXPECT_EQ(classifier->SuggestSelection( "this afternoon !BarackObama,- gave a speech at", {16, 27}), CodepointSpan(16, 27)); } TEST_F(AnnotatorTest, SuggestSelectionNoCrashWithJunk) { std::unique_ptr classifier = Annotator::FromPath( GetTestModelPath(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); // Try passing in bunch of invalid selections. EXPECT_EQ(classifier->SuggestSelection("", {0, 27}), CodepointSpan(0, 27)); EXPECT_EQ(classifier->SuggestSelection("", {-10, 27}), CodepointSpan(-10, 27)); EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {0, 27}), CodepointSpan(0, 27)); EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {-30, 300}), CodepointSpan(-30, 300)); EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {-10, -1}), CodepointSpan(-10, -1)); EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {100, 17}), CodepointSpan(100, 17)); // Try passing invalid utf8. 
EXPECT_EQ(classifier->SuggestSelection("\xf0\x9f\x98\x8b\x8b", {-1, -1}),
            CodepointSpan(-1, -1));
}

// Checks SuggestSelection() when the clicked span covers only whitespace.
TEST_F(AnnotatorTest, SuggestSelectionSelectSpace) {
  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  EXPECT_EQ(
      classifier->SuggestSelection("call me at 857 225 3556 today", {14, 15}),
      CodepointSpan(11, 23));
  EXPECT_EQ(
      classifier->SuggestSelection("call me at 857 225 3556 today", {10, 11}),
      CodepointSpan(10, 11));
  EXPECT_EQ(
      classifier->SuggestSelection("call me at 857 225 3556 today", {23, 24}),
      CodepointSpan(23, 24));
  EXPECT_EQ(
      classifier->SuggestSelection("call me at 857 225 3556, today", {23, 24}),
      CodepointSpan(23, 24));
  // The gap after "857" is three spaces wide: the selected span {14, 17} is
  // exactly that gap and the expected result (11, 25) covers the 14
  // codepoints of "857   225 3556". A single-space gap cannot produce these
  // offsets.
  EXPECT_EQ(classifier->SuggestSelection("call me at 857   225 3556, today",
                                         {14, 17}),
            CodepointSpan(11, 25));
  EXPECT_EQ(
      classifier->SuggestSelection("call me at 857-225 3556, today", {14, 17}),
      CodepointSpan(11, 23));
  EXPECT_EQ(
      classifier->SuggestSelection(
          "let's meet at 350 Third Street Cambridge and go there", {30, 31}),
      CodepointSpan(14, 40));
  EXPECT_EQ(classifier->SuggestSelection("call me today", {4, 5}),
            CodepointSpan(4, 5));
  EXPECT_EQ(classifier->SuggestSelection("call me today", {7, 8}),
            CodepointSpan(7, 8));

  // With a punctuation around the selected whitespace.
  EXPECT_EQ(
      classifier->SuggestSelection(
          "let's meet at 350 Third Street, Cambridge and go there", {31, 32}),
      CodepointSpan(14, 41));

  // When all's whitespace, should return the original indices.
EXPECT_EQ(classifier->SuggestSelection(" ", {0, 1}), CodepointSpan(0, 1)); EXPECT_EQ(classifier->SuggestSelection(" ", {0, 3}), CodepointSpan(0, 3)); EXPECT_EQ(classifier->SuggestSelection(" ", {2, 3}), CodepointSpan(2, 3)); EXPECT_EQ(classifier->SuggestSelection(" ", {5, 6}), CodepointSpan(5, 6)); } TEST_F(AnnotatorTest, SnapLeftIfWhitespaceSelection) { UnicodeText text; text = UTF8ToUnicodeText("abcd efgh", /*do_copy=*/false); EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, *unilib_), CodepointSpan(3, 4)); text = UTF8ToUnicodeText("abcd ", /*do_copy=*/false); EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, *unilib_), CodepointSpan(3, 4)); // Nothing on the left. text = UTF8ToUnicodeText(" efgh", /*do_copy=*/false); EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, *unilib_), CodepointSpan(4, 5)); text = UTF8ToUnicodeText(" efgh", /*do_copy=*/false); EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({0, 1}, text, *unilib_), CodepointSpan(0, 1)); // Whitespace only. text = UTF8ToUnicodeText(" ", /*do_copy=*/false); EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({2, 3}, text, *unilib_), CodepointSpan(2, 3)); text = UTF8ToUnicodeText(" ", /*do_copy=*/false); EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, *unilib_), CodepointSpan(4, 5)); text = UTF8ToUnicodeText(" ", /*do_copy=*/false); EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({0, 1}, text, *unilib_), CodepointSpan(0, 1)); } TEST_F(AnnotatorTest, Annotate) { std::unique_ptr classifier = Annotator::FromPath( GetTestModelPath(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); const std::string test_string = "& saw Barack Obama today .. 
350 Third Street, Cambridge\nand my phone " "number is 853 225 3556"; EXPECT_THAT(classifier->Annotate(test_string), ElementsAreArray({ IsAnnotatedSpan(28, 55, "address"), IsAnnotatedSpan(79, 91, "phone"), })); AnnotationOptions options; EXPECT_THAT(classifier->Annotate("853 225 3556", options), ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")})); EXPECT_THAT(classifier->Annotate("853 225\n3556", options), ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")})); // Try passing invalid utf8. EXPECT_TRUE( classifier->Annotate("853 225 3556\n\xf0\x9f\x98\x8b\x8b", options) .empty()); } TEST_F(AnnotatorTest, AnnotatesWithBracketStripping) { std::unique_ptr classifier = Annotator::FromPath( GetTestModelPath(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); EXPECT_THAT(classifier->Annotate("call me at (0845) 100 1000 today"), ElementsAreArray({ IsAnnotatedSpan(11, 26, "phone"), })); // Unpaired bracket stripping. EXPECT_THAT(classifier->Annotate("call me at (07038201818 today"), ElementsAreArray({ IsAnnotatedSpan(12, 23, "phone"), })); EXPECT_THAT(classifier->Annotate("call me at 07038201818) today"), ElementsAreArray({ IsAnnotatedSpan(11, 22, "phone"), })); EXPECT_THAT(classifier->Annotate("call me at )07038201818( today"), ElementsAreArray({ IsAnnotatedSpan(12, 23, "phone"), })); } TEST_F(AnnotatorTest, AnnotatesWithBracketStrippingOptimized) { std::unique_ptr classifier = Annotator::FromPath( GetTestModelPath(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); AnnotationOptions options; options.enable_optimization = true; EXPECT_THAT(classifier->Annotate("call me at (0845) 100 1000 today", options), ElementsAreArray({ IsAnnotatedSpan(11, 26, "phone"), })); // Unpaired bracket stripping. 
EXPECT_THAT(classifier->Annotate("call me at (07038201818 today", options), ElementsAreArray({ IsAnnotatedSpan(12, 23, "phone"), })); EXPECT_THAT(classifier->Annotate("call me at 07038201818) today", options), ElementsAreArray({ IsAnnotatedSpan(11, 22, "phone"), })); EXPECT_THAT(classifier->Annotate("call me at )07038201818( today", options), ElementsAreArray({ IsAnnotatedSpan(12, 23, "phone"), })); } TEST_F(AnnotatorTest, AnnotatesOverlappingNumbers) { std::unique_ptr classifier = Annotator::FromPath( GetTestModelPath(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); AnnotationOptions options; options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW; // Number, float number and percentage annotator. EXPECT_THAT( classifier->Annotate("853 225 3556 and then turn it up 99%, 99 " "number, 12345.12345 float number", options), UnorderedElementsAreArray( {IsAnnotatedSpan(0, 12, "phone"), IsAnnotatedSpan(0, 3, "number"), IsAnnotatedSpan(4, 7, "number"), IsAnnotatedSpan(8, 12, "number"), IsAnnotatedSpan(33, 35, "number"), IsAnnotatedSpan(33, 36, "percentage"), IsAnnotatedSpan(38, 40, "number"), IsAnnotatedSpan(49, 60, "number"), IsAnnotatedSpan(49, 60, "phone")})); } TEST_F(AnnotatorTest, DoesNotAnnotateNumbersInSmartUsecase) { std::unique_ptr classifier = Annotator::FromPath( GetTestModelPath(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); AnnotationOptions options; options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART; EXPECT_THAT(classifier->Annotate( "853 225 3556 and then turn it up 99%, 99 number", options), ElementsAreArray({IsAnnotatedSpan(0, 12, "phone"), IsAnnotatedSpan(33, 36, "percentage")})); } void VerifyAnnotatesDurationsInRawMode(const Annotator* classifier) { ASSERT_TRUE(classifier); AnnotationOptions options; options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW; // Duration annotator. 
EXPECT_THAT(classifier->Annotate( "it took 9 minutes and 7 seconds to get there", options), Contains(IsDurationSpan( /*start=*/8, /*end=*/31, /*duration_ms=*/9 * 60 * 1000 + 7 * 1000))); } TEST_F(AnnotatorTest, AnnotatesDurationsInRawMode) { std::unique_ptr classifier = Annotator::FromPath( GetTestModelPath(), unilib_.get(), calendarlib_.get()); VerifyAnnotatesDurationsInRawMode(classifier.get()); } void VerifyDurationAndRelativeTimeCanOverlapInRawMode( const Annotator* classifier) { ASSERT_TRUE(classifier); AnnotationOptions options; options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW; options.locales = "en"; const std::vector annotations = classifier->Annotate("let's meet in 3 hours", options); EXPECT_THAT(annotations, Contains(IsDatetimeSpan(/*start=*/11, /*end=*/21, /*time_ms_utc=*/10800000L, DatetimeGranularity::GRANULARITY_HOUR))); EXPECT_THAT(annotations, Contains(IsDurationSpan(/*start=*/14, /*end=*/21, /*duration_ms=*/3 * 60 * 60 * 1000))); } TEST_F(AnnotatorTest, DurationAndRelativeTimeCanOverlapInRawMode) { std::unique_ptr classifier = Annotator::FromPath( GetTestModelPath(), unilib_.get(), calendarlib_.get()); VerifyDurationAndRelativeTimeCanOverlapInRawMode(classifier.get()); } TEST_F(AnnotatorTest, DurationAndRelativeTimeCanOverlapInRawModeWithDatetimeRegEx) { std::string model_buffer = GetTestModelWithDatetimeRegEx(); std::unique_ptr classifier = Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(), unilib_.get(), calendarlib_.get()); VerifyDurationAndRelativeTimeCanOverlapInRawMode(classifier.get()); } TEST_F(AnnotatorTest, AnnotateSplitLines) { std::string model_buffer = ReadFile(GetTestModelPath()); model_buffer = ModifyAnnotatorModel(model_buffer, [](ModelT* model) { model->selection_feature_options->only_use_line_with_click = true; }); std::unique_ptr classifier = Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); const std::string 
str1 = "hey, sorry, just finished up. i didn't hear back from you in time."; const std::string str2 = "2000 Main Avenue, Apt #201, San Mateo"; const int kAnnotationLength = 26; EXPECT_THAT(classifier->Annotate(str1), IsEmpty()); EXPECT_THAT( classifier->Annotate(str2), ElementsAreArray({IsAnnotatedSpan(0, kAnnotationLength, "address")})); const std::string str3 = str1 + "\n" + str2; EXPECT_THAT( classifier->Annotate(str3), ElementsAreArray({IsAnnotatedSpan( str1.size() + 1, str1.size() + 1 + kAnnotationLength, "address")})); } TEST_F(AnnotatorTest, UsePipeAsNewLineCharacterShouldAnnotateSplitLines) { std::string model_buffer = ReadFile(GetTestModelPath()); model_buffer = ModifyAnnotatorModel(model_buffer, [](ModelT* model) { model->selection_feature_options->only_use_line_with_click = true; model->selection_feature_options->use_pipe_character_for_newline = true; }); std::unique_ptr classifier = Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); const std::string str1 = "hey, this is my phone number 853 225 3556"; const std::string str2 = "2000 Main Avenue, Apt #201, San Mateo"; const std::string str3 = str1 + "|" + str2; const int kAnnotationLengthPhone = 12; const int kAnnotationLengthAddress = 26; // Splitting the lines on `str3` should have the same behavior (e.g. find the // phone and address spans) as if we would annotate `str1` and `str2` // individually. 
const std::vector& annotated_spans = classifier->Annotate(str3); EXPECT_THAT(annotated_spans, ElementsAreArray( {IsAnnotatedSpan(29, 29 + kAnnotationLengthPhone, "phone"), IsAnnotatedSpan(static_cast(str1.size()) + 1, static_cast(str1.size() + 1 + kAnnotationLengthAddress), "address")})); } TEST_F(AnnotatorTest, NotUsingPipeAsNewLineCharacterShouldNotAnnotateSplitLines) { std::string model_buffer = ReadFile(GetTestModelPath()); model_buffer = ModifyAnnotatorModel(model_buffer, [](ModelT* model) { model->selection_feature_options->only_use_line_with_click = true; model->selection_feature_options->use_pipe_character_for_newline = false; }); std::unique_ptr classifier = Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); const std::string str1 = "hey, this is my phone number 853 225 3556"; const std::string str2 = "2000 Main Avenue, Apt #201, San Mateo"; const std::string str3 = str1 + "|" + str2; const std::vector& annotated_spans = classifier->Annotate(str3); // Note: We only check that we get a single annotated span here when the '|' // character is not used to split lines. The reason behind this is that the // model is not precise for such example and the resulted annotated span might // change when the model changes. EXPECT_THAT(annotated_spans.size(), 1); } TEST_F(AnnotatorTest, AnnotateSmallBatches) { const std::string test_model = ReadFile(GetTestModelPath()); std::unique_ptr unpacked_model = UnPackModel(test_model.c_str()); // Set the batch size. unpacked_model->selection_options->batch_size = 4; flatbuffers::FlatBufferBuilder builder; FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); std::unique_ptr classifier = Annotator::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); const std::string test_string = "& saw Barack Obama today .. 
350 Third Street, Cambridge\nand my phone " "number is 853 225 3556"; EXPECT_THAT(classifier->Annotate(test_string), ElementsAreArray({ IsAnnotatedSpan(28, 55, "address"), IsAnnotatedSpan(79, 91, "phone"), })); AnnotationOptions options; EXPECT_THAT(classifier->Annotate("853 225 3556", options), ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")})); EXPECT_THAT(classifier->Annotate("853 225\n3556", options), ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")})); } TEST_F(AnnotatorTest, AnnotateFilteringDiscardAll) { const std::string test_model = ReadFile(GetTestModelPath()); std::unique_ptr unpacked_model = UnPackModel(test_model.c_str()); unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT); // Add test threshold. unpacked_model->triggering_options->min_annotate_confidence = 2.f; // Discards all results. flatbuffers::FlatBufferBuilder builder; FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); std::unique_ptr classifier = Annotator::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); const std::string test_string = "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone " "number is 853 225 3556"; EXPECT_EQ(classifier->Annotate(test_string).size(), 0); } TEST_F(AnnotatorTest, AnnotateFilteringKeepAll) { const std::string test_model = ReadFile(GetTestModelPath()); std::unique_ptr unpacked_model = UnPackModel(test_model.c_str()); // Add test thresholds. unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT); unpacked_model->triggering_options->min_annotate_confidence = 0.f; // Keeps all results. 
unpacked_model->triggering_options->enabled_modes = ModeFlag_ALL; flatbuffers::FlatBufferBuilder builder; FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); std::unique_ptr classifier = Annotator::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); const std::string test_string = "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone " "number is 853 225 3556"; EXPECT_EQ(classifier->Annotate(test_string).size(), 2); } TEST_F(AnnotatorTest, AnnotateDisabled) { const std::string test_model = ReadFile(GetTestModelPath()); std::unique_ptr unpacked_model = UnPackModel(test_model.c_str()); // Disable the model for annotation. unpacked_model->enabled_modes = ModeFlag_CLASSIFICATION_AND_SELECTION; flatbuffers::FlatBufferBuilder builder; FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); std::unique_ptr classifier = Annotator::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); const std::string test_string = "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone " "number is 853 225 3556"; EXPECT_THAT(classifier->Annotate(test_string), IsEmpty()); } TEST_F(AnnotatorTest, AnnotateFilteredCollections) { const std::string test_model = ReadFile(GetTestModelPath()); std::unique_ptr classifier = Annotator::FromUnownedBuffer( test_model.c_str(), test_model.size(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); const std::string test_string = "& saw Barack Obama today .. 
350 Third Street, Cambridge\nand my phone " "number is 853 225 3556"; EXPECT_THAT(classifier->Annotate(test_string), ElementsAreArray({ IsAnnotatedSpan(28, 55, "address"), IsAnnotatedSpan(79, 91, "phone"), })); std::unique_ptr unpacked_model = UnPackModel(test_model.c_str()); unpacked_model->output_options.reset(new OutputOptionsT); // Disable phone annotation unpacked_model->output_options->filtered_collections_annotation.push_back( "phone"); flatbuffers::FlatBufferBuilder builder; FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); classifier = Annotator::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); EXPECT_THAT(classifier->Annotate(test_string), ElementsAreArray({ IsAnnotatedSpan(28, 55, "address"), })); } TEST_F(AnnotatorTest, AnnotateFilteredCollectionsSuppress) { const std::string test_model = ReadFile(GetTestModelPath()); std::unique_ptr classifier = Annotator::FromUnownedBuffer( test_model.c_str(), test_model.size(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); const std::string test_string = "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone " "number is 853 225 3556"; EXPECT_THAT(classifier->Annotate(test_string), ElementsAreArray({ IsAnnotatedSpan(28, 55, "address"), IsAnnotatedSpan(79, 91, "phone"), })); std::unique_ptr unpacked_model = UnPackModel(test_model.c_str()); unpacked_model->output_options.reset(new OutputOptionsT); // We add a custom annotator that wins against the phone classification // below and that we subsequently suppress. 
unpacked_model->output_options->filtered_collections_annotation.push_back( "suppress"); unpacked_model->regex_model->patterns.push_back(MakePattern( "suppress", "(\\d{3} ?\\d{4})", /*enabled_for_classification=*/false, /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 2.0)); flatbuffers::FlatBufferBuilder builder; FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); classifier = Annotator::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); EXPECT_THAT(classifier->Annotate(test_string), ElementsAreArray({ IsAnnotatedSpan(28, 55, "address"), })); } void VerifyClassifyTextDateInZurichTimezone(const Annotator* classifier) { EXPECT_TRUE(classifier); ClassificationOptions options; options.reference_timezone = "Europe/Zurich"; options.locales = "en"; std::vector result = classifier->ClassifyText("january 1, 2017", {0, 15}, options); EXPECT_THAT(result, ElementsAre(IsDateResult(1483225200000, DatetimeGranularity::GRANULARITY_DAY))); } TEST_F(AnnotatorTest, ClassifyTextDateInZurichTimezone) { std::unique_ptr classifier = Annotator::FromPath( GetTestModelPath(), unilib_.get(), calendarlib_.get()); VerifyClassifyTextDateInZurichTimezone(classifier.get()); } TEST_F(AnnotatorTest, ClassifyTextDateInZurichTimezoneWithDatetimeRegEx) { std::string model_buffer = GetTestModelWithDatetimeRegEx(); std::unique_ptr classifier = Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(), unilib_.get(), calendarlib_.get()); VerifyClassifyTextDateInZurichTimezone(classifier.get()); } void VerifyClassifyTextDateInLATimezone(const Annotator* classifier) { EXPECT_TRUE(classifier); ClassificationOptions options; options.reference_timezone = "America/Los_Angeles"; options.locales = "en"; std::vector result = classifier->ClassifyText("march 1, 2017", {0, 13}, options); EXPECT_THAT(result, ElementsAre(IsDateResult(1488355200000, 
DatetimeGranularity::GRANULARITY_DAY))); } TEST_F(AnnotatorTest, ClassifyTextDateInLATimezoneWithDatetimeRegEx) { std::string model_buffer = GetTestModelWithDatetimeRegEx(); std::unique_ptr classifier = Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(), unilib_.get(), calendarlib_.get()); VerifyClassifyTextDateInLATimezone(classifier.get()); } TEST_F(AnnotatorTest, ClassifyTextDateInLATimezone) { std::unique_ptr classifier = Annotator::FromPath( GetTestModelPath(), unilib_.get(), calendarlib_.get()); VerifyClassifyTextDateInLATimezone(classifier.get()); } void VerifyClassifyTextDateOnAotherLine(const Annotator* classifier) { EXPECT_TRUE(classifier); ClassificationOptions options; options.reference_timezone = "Europe/Zurich"; options.locales = "en"; std::vector result = classifier->ClassifyText( "hello world this is the first line\n" "january 1, 2017", {35, 50}, options); EXPECT_THAT(result, ElementsAre(IsDateResult(1483225200000, DatetimeGranularity::GRANULARITY_DAY))); } TEST_F(AnnotatorTest, ClassifyTextDateOnAotherLineWithDatetimeRegEx) { std::string model_buffer = GetTestModelWithDatetimeRegEx(); std::unique_ptr classifier = Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(), unilib_.get(), calendarlib_.get()); VerifyClassifyTextDateOnAotherLine(classifier.get()); } TEST_F(AnnotatorTest, ClassifyTextDateOnAotherLine) { std::unique_ptr classifier = Annotator::FromPath( GetTestModelPath(), unilib_.get(), calendarlib_.get()); VerifyClassifyTextDateOnAotherLine(classifier.get()); } void VerifyClassifyTextWhenLocaleUSParsesDateAsMonthDay( const Annotator* classifier) { EXPECT_TRUE(classifier); std::vector result; ClassificationOptions options; options.reference_timezone = "Europe/Zurich"; options.locales = "en-US"; result = classifier->ClassifyText("03.05.1970 00:00am", {0, 18}, options); // In US, the date should be interpreted as .. 
EXPECT_THAT(result, ElementsAre(IsDatetimeResult( 5439600000, DatetimeGranularity::GRANULARITY_MINUTE))); } TEST_F(AnnotatorTest, ClassifyTextWhenLocaleUSParsesDateAsMonthDay) { std::unique_ptr classifier = Annotator::FromPath( GetTestModelPath(), unilib_.get(), calendarlib_.get()); VerifyClassifyTextWhenLocaleUSParsesDateAsMonthDay(classifier.get()); } TEST_F(AnnotatorTest, ClassifyTextWhenLocaleUSParsesDateAsMonthDayWithDatetimeRegEx) { std::string model_buffer = GetTestModelWithDatetimeRegEx(); std::unique_ptr classifier = Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(), unilib_.get(), calendarlib_.get()); VerifyClassifyTextWhenLocaleUSParsesDateAsMonthDay(classifier.get()); } TEST_F(AnnotatorTest, ClassifyTextWhenLocaleGermanyParsesDateAsMonthDay) { std::string model_buffer = GetTestModelWithDatetimeRegEx(); std::unique_ptr classifier = Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(), unilib_.get(), calendarlib_.get()); EXPECT_TRUE(classifier); std::vector result; ClassificationOptions options; options.reference_timezone = "Europe/Zurich"; options.locales = "de"; result = classifier->ClassifyText("03.05.1970 00:00vorm", {0, 20}, options); // In Germany, the date should be interpreted as .. 
EXPECT_THAT(result, ElementsAre(IsDatetimeResult( 10537200000, DatetimeGranularity::GRANULARITY_MINUTE))); } TEST_F(AnnotatorTest, ClassifyTextAmbiguousDatetime) { std::unique_ptr classifier = Annotator::FromPath( GetTestModelPath(), unilib_.get(), calendarlib_.get()); EXPECT_TRUE(classifier); ClassificationOptions options; options.reference_timezone = "Europe/Zurich"; options.locales = "en-US"; const std::vector result = classifier->ClassifyText("set an alarm for 10:30", {17, 22}, options); EXPECT_THAT( result, ElementsAre( IsDatetimeResult(34200000, DatetimeGranularity::GRANULARITY_MINUTE), IsDatetimeResult(77400000, DatetimeGranularity::GRANULARITY_MINUTE))); } TEST_F(AnnotatorTest, AnnotateAmbiguousDatetime) { std::unique_ptr classifier = Annotator::FromPath( GetTestModelPath(), unilib_.get(), calendarlib_.get()); EXPECT_TRUE(classifier); AnnotationOptions options; options.reference_timezone = "Europe/Zurich"; options.locales = "en-US"; const std::vector spans = classifier->Annotate("set an alarm for 10:30", options); ASSERT_EQ(spans.size(), 1); const std::vector result = spans[0].classification; EXPECT_THAT( result, ElementsAre( IsDatetimeResult(34200000, DatetimeGranularity::GRANULARITY_MINUTE), IsDatetimeResult(77400000, DatetimeGranularity::GRANULARITY_MINUTE))); } TEST_F(AnnotatorTest, SuggestTextDateDisabled) { std::string test_model = GetTestModelWithDatetimeRegEx(); std::unique_ptr unpacked_model = UnPackModel(test_model.c_str()); // Disable the patterns for selection. 
for (int i = 0; i < unpacked_model->datetime_model->patterns.size(); i++) { unpacked_model->datetime_model->patterns[i]->enabled_modes = ModeFlag_ANNOTATION_AND_CLASSIFICATION; } flatbuffers::FlatBufferBuilder builder; FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); std::unique_ptr classifier = Annotator::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); EXPECT_EQ("date", FirstResult(classifier->ClassifyText("january 1, 2017", {0, 15}))); EXPECT_EQ(classifier->SuggestSelection("january 1, 2017", {0, 7}), CodepointSpan(0, 7)); EXPECT_THAT(classifier->Annotate("january 1, 2017"), ElementsAreArray({IsAnnotatedSpan(0, 15, "date")})); } TEST_F(AnnotatorTest, AnnotatesWithGrammarModel) { const std::string test_model = ReadFile(GetTestModelPath()); std::unique_ptr unpacked_model = UnPackModel(test_model.c_str()); // Add test grammar model. unpacked_model->grammar_model.reset(new GrammarModelT); GrammarModelT* grammar_model = unpacked_model->grammar_model.get(); grammar_model->tokenizer_options.reset(new GrammarTokenizerOptionsT); grammar_model->tokenizer_options->tokenization_type = TokenizationType_ICU; grammar_model->tokenizer_options->icu_preserve_whitespace_tokens = false; grammar_model->tokenizer_options->tokenize_on_script_change = true; // Add test rules. grammar_model->rules.reset(new grammar::RulesSetT); grammar::LocaleShardMap locale_shard_map = grammar::LocaleShardMap::CreateLocaleShardMap({""}); grammar::Rules rules(locale_shard_map); rules.Add("", {"jessica", "fletcher"}); rules.Add("", {"columbo"}); rules.Add("", {"magnum"}); rules.Add( "", {""}, /*callback=*/ static_cast(grammar::DefaultCallback::kRootRule), /*callback_param=*/0 /* rule classification result */); // Set result. 
grammar_model->rule_classification_result.emplace_back( new GrammarModel_::RuleClassificationResultT); GrammarModel_::RuleClassificationResultT* result = grammar_model->rule_classification_result.back().get(); result->collection_name = "famous person"; result->enabled_modes = ModeFlag_ALL; rules.Finalize().Serialize(/*include_debug_information=*/false, grammar_model->rules.get()); flatbuffers::FlatBufferBuilder builder; FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); std::unique_ptr classifier = Annotator::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); const std::string test_string = "Did you see the Novel Connection episode where Jessica Fletcher helps " "Magnum solve the case? I thought that was with Columbo ..."; EXPECT_THAT(classifier->Annotate(test_string), ElementsAre(IsAnnotatedSpan(47, 63, "famous person"), IsAnnotatedSpan(70, 76, "famous person"), IsAnnotatedSpan(117, 124, "famous person"))); EXPECT_THAT(FirstResult(classifier->ClassifyText("Jessica Fletcher", CodepointSpan{0, 16})), Eq("famous person")); EXPECT_THAT(classifier->SuggestSelection("Jessica Fletcher", {0, 7}), Eq(CodepointSpan{0, 16})); } TEST_F(AnnotatorTest, ResolveConflictsTrivial) { TestingAnnotator classifier(unilib_.get(), calendarlib_.get()); std::vector candidates{ {MakeAnnotatedSpan({0, 1}, "phone", 1.0)}}; std::vector locales = {Locale::FromBCP47("en")}; BaseOptions options; options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART; std::vector chosen; classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{}, locales, options, /*interpreter_manager=*/nullptr, &chosen); EXPECT_THAT(chosen, ElementsAreArray({0})); } TEST_F(AnnotatorTest, ResolveConflictsSequence) { TestingAnnotator classifier(unilib_.get(), calendarlib_.get()); std::vector candidates{{ MakeAnnotatedSpan({0, 1}, "phone", 1.0), MakeAnnotatedSpan({1, 2}, "phone", 
1.0), MakeAnnotatedSpan({2, 3}, "phone", 1.0), MakeAnnotatedSpan({3, 4}, "phone", 1.0), MakeAnnotatedSpan({4, 5}, "phone", 1.0), }}; std::vector locales = {Locale::FromBCP47("en")}; BaseOptions options; options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART; std::vector chosen; classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{}, locales, options, /*interpreter_manager=*/nullptr, &chosen); EXPECT_THAT(chosen, ElementsAreArray({0, 1, 2, 3, 4})); } TEST_F(AnnotatorTest, ResolveConflictsThreeSpans) { TestingAnnotator classifier(unilib_.get(), calendarlib_.get()); std::vector candidates{{ MakeAnnotatedSpan({0, 3}, "phone", 1.0), MakeAnnotatedSpan({1, 5}, "phone", 0.5), // Looser! MakeAnnotatedSpan({3, 7}, "phone", 1.0), }}; std::vector locales = {Locale::FromBCP47("en")}; BaseOptions options; options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART; std::vector chosen; classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{}, locales, options, /*interpreter_manager=*/nullptr, &chosen); EXPECT_THAT(chosen, ElementsAreArray({0, 2})); } TEST_F(AnnotatorTest, ResolveConflictsThreeSpansReversed) { TestingAnnotator classifier(unilib_.get(), calendarlib_.get()); std::vector candidates{{ MakeAnnotatedSpan({0, 3}, "phone", 0.5), // Looser! MakeAnnotatedSpan({1, 5}, "phone", 1.0), MakeAnnotatedSpan({3, 7}, "phone", 0.6), // Looser! 
}}; std::vector locales = {Locale::FromBCP47("en")}; BaseOptions options; options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART; std::vector chosen; classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{}, locales, options, /*interpreter_manager=*/nullptr, &chosen); EXPECT_THAT(chosen, ElementsAreArray({1})); } TEST_F(AnnotatorTest, DoesNotPrioritizeLongerSpanWhenDoingConflictResolution) { TestingAnnotator classifier(unilib_.get(), calendarlib_.get()); std::vector candidates{{ MakeAnnotatedSpan({3, 7}, "unit", 1), MakeAnnotatedSpan({5, 13}, "unit", 1), // Looser! MakeAnnotatedSpan({5, 30}, "url", 1), // Looser! MakeAnnotatedSpan({14, 20}, "email", 1), }}; std::vector locales = {Locale::FromBCP47("en")}; BaseOptions options; options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART; std::vector chosen; classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{}, locales, options, /*interpreter_manager=*/nullptr, &chosen); // Picks the first and the last annotations because they do not overlap. EXPECT_THAT(chosen, ElementsAreArray({0, 3})); } TEST_F(AnnotatorTest, PrioritizeLongerSpanWhenDoingConflictResolution) { const std::string test_model = ReadFile(GetTestModelPath()); std::unique_ptr unpacked_model = UnPackModel(test_model.c_str()); TC3_CHECK(libtextclassifier3::DecompressModel(unpacked_model.get())); unpacked_model->conflict_resolution_options.reset( new Model_::ConflictResolutionOptionsT); unpacked_model->conflict_resolution_options->prioritize_longest_annotation = true; flatbuffers::FlatBufferBuilder builder; FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); std::unique_ptr classifier = TestingAnnotator::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get(), calendarlib_.get()); TC3_CHECK(classifier != nullptr); std::vector candidates{{ MakeAnnotatedSpan({3, 7}, "unit", 1), // Looser! 
MakeAnnotatedSpan({5, 13}, "unit", 1), // Looser! MakeAnnotatedSpan({5, 30}, "url", 1), // Pick longest match. MakeAnnotatedSpan({14, 20}, "email", 1), // Looser! }}; std::vector locales = {Locale::FromBCP47("en")}; BaseOptions options; options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART; std::vector chosen; classifier->ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{}, locales, options, /*interpreter_manager=*/nullptr, &chosen); EXPECT_THAT(chosen, ElementsAreArray({2})); } TEST_F(AnnotatorTest, ResolveConflictsFiveSpans) { TestingAnnotator classifier(unilib_.get(), calendarlib_.get()); std::vector candidates{{ MakeAnnotatedSpan({0, 3}, "phone", 0.5), MakeAnnotatedSpan({1, 5}, "other", 1.0), // Looser! MakeAnnotatedSpan({3, 7}, "phone", 0.6), MakeAnnotatedSpan({8, 12}, "phone", 0.6), // Looser! MakeAnnotatedSpan({11, 15}, "phone", 0.9), }}; std::vector locales = {Locale::FromBCP47("en")}; BaseOptions options; options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART; std::vector chosen; classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{}, locales, options, /*interpreter_manager=*/nullptr, &chosen); EXPECT_THAT(chosen, ElementsAreArray({0, 2, 4})); } TEST_F(AnnotatorTest, ResolveConflictsRawModeOverlapsAllowedKnowledgeFirst) { TestingAnnotator classifier(unilib_.get(), calendarlib_.get()); std::vector candidates{{ MakeAnnotatedSpan({0, 15}, "entity", 0.7, AnnotatedSpan::Source::KNOWLEDGE), MakeAnnotatedSpan({5, 10}, "address", 0.6), }}; std::vector locales = {Locale::FromBCP47("en")}; BaseOptions options; options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW; std::vector chosen; classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{}, locales, options, /*interpreter_manager=*/nullptr, &chosen); EXPECT_THAT(chosen, ElementsAreArray({0, 1})); } TEST_F(AnnotatorTest, ResolveConflictsRawModeOverlapsAllowedKnowledgeSecond) { TestingAnnotator 
classifier(unilib_.get(), calendarlib_.get()); std::vector candidates{{ MakeAnnotatedSpan({0, 15}, "address", 0.7), MakeAnnotatedSpan({5, 10}, "entity", 0.6, AnnotatedSpan::Source::KNOWLEDGE), }}; std::vector locales = {Locale::FromBCP47("en")}; BaseOptions options; options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW; std::vector chosen; classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{}, locales, options, /*interpreter_manager=*/nullptr, &chosen); EXPECT_THAT(chosen, ElementsAreArray({0, 1})); } TEST_F(AnnotatorTest, ResolveConflictsRawModeOverlapsAllowedBothKnowledge) { TestingAnnotator classifier(unilib_.get(), calendarlib_.get()); std::vector candidates{{ MakeAnnotatedSpan({0, 15}, "entity", 0.7, AnnotatedSpan::Source::KNOWLEDGE), MakeAnnotatedSpan({5, 10}, "entity", 0.6, AnnotatedSpan::Source::KNOWLEDGE), }}; std::vector locales = {Locale::FromBCP47("en")}; BaseOptions options; options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW; std::vector chosen; classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{}, locales, options, /*interpreter_manager=*/nullptr, &chosen); EXPECT_THAT(chosen, ElementsAreArray({0, 1})); } TEST_F(AnnotatorTest, ResolveConflictsRawModeOverlapsNotAllowed) { TestingAnnotator classifier(unilib_.get(), calendarlib_.get()); std::vector candidates{{ MakeAnnotatedSpan({0, 15}, "address", 0.7), MakeAnnotatedSpan({5, 10}, "date", 0.6), }}; std::vector locales = {Locale::FromBCP47("en")}; BaseOptions options; options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW; std::vector chosen; classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{}, locales, options, /*interpreter_manager=*/nullptr, &chosen); EXPECT_THAT(chosen, ElementsAreArray({0})); } TEST_F(AnnotatorTest, ResolveConflictsRawModeGeneralOverlapsAllowed) { TestingAnnotator classifier( unilib_.get(), calendarlib_.get(), [](ModelT* model) { 
model->conflict_resolution_options.reset( new Model_::ConflictResolutionOptionsT); model->conflict_resolution_options->do_conflict_resolution_in_raw_mode = false; }); std::vector candidates{{ MakeAnnotatedSpan({0, 15}, "address", 0.7), MakeAnnotatedSpan({5, 10}, "date", 0.6), }}; std::vector locales = {Locale::FromBCP47("en")}; BaseOptions options; options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW; std::vector chosen; classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{}, locales, options, /*interpreter_manager=*/nullptr, &chosen); EXPECT_THAT(chosen, ElementsAreArray({0, 1})); } void VerifyLongInput(const Annotator* classifier) { ASSERT_TRUE(classifier); for (const auto& type_value_pair : std::vector>{ {"address", "350 Third Street, Cambridge"}, {"phone", "123 456-7890"}, {"url", "www.google.com"}, {"email", "someone@gmail.com"}, {"flight", "LX 38"}, {"date", "September 1, 2018"}}) { const std::string input_100k = std::string(50000, ' ') + type_value_pair.second + std::string(50000, ' '); const int value_length = type_value_pair.second.size(); AnnotationOptions annotation_options; annotation_options.locales = "en"; EXPECT_THAT(classifier->Annotate(input_100k, annotation_options), ElementsAreArray({IsAnnotatedSpan(50000, 50000 + value_length, type_value_pair.first)})); SelectionOptions selection_options; selection_options.locales = "en"; EXPECT_EQ(classifier->SuggestSelection(input_100k, {50000, 50001}, selection_options), CodepointSpan(50000, 50000 + value_length)); ClassificationOptions classification_options; classification_options.locales = "en"; EXPECT_EQ(type_value_pair.first, FirstResult(classifier->ClassifyText( input_100k, {50000, 50000 + value_length}, classification_options))); } } TEST_F(AnnotatorTest, LongInput) { std::unique_ptr classifier = Annotator::FromPath( GetTestModelPath(), unilib_.get(), calendarlib_.get()); VerifyLongInput(classifier.get()); } TEST_F(AnnotatorTest, LongInputWithRegExDatetime) { 
std::string model_buffer = GetTestModelWithDatetimeRegEx(); std::unique_ptr classifier = Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(), unilib_.get(), calendarlib_.get()); VerifyLongInput(classifier.get()); } // These coarse tests are there only to make sure the execution happens in // reasonable amount of time. TEST_F(AnnotatorTest, LongInputNoResultCheck) { std::unique_ptr classifier = Annotator::FromPath( GetTestModelPath(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); for (const std::string& value : std::vector{"http://www.aaaaaaaaaaaaaaaaaaaa.com "}) { const std::string input_100k = std::string(50000, ' ') + value + std::string(50000, ' '); const int value_length = value.size(); classifier->Annotate(input_100k); classifier->SuggestSelection(input_100k, {50000, 50001}); classifier->ClassifyText(input_100k, {50000, 50000 + value_length}); } } TEST_F(AnnotatorTest, MaxTokenLength) { const std::string test_model = ReadFile(GetTestModelPath()); std::unique_ptr unpacked_model = UnPackModel(test_model.c_str()); std::unique_ptr classifier; // With unrestricted number of tokens should behave normally. unpacked_model->classification_options->max_num_tokens = -1; flatbuffers::FlatBufferBuilder builder; FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); classifier = Annotator::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); EXPECT_EQ(FirstResult(classifier->ClassifyText( "I live at 350 Third Street, Cambridge.", {10, 37})), "address"); // Raise the maximum number of tokens to suppress the classification. 
unpacked_model->classification_options->max_num_tokens = 3; flatbuffers::FlatBufferBuilder builder2; FinishModelBuffer(builder2, Model::Pack(builder2, unpacked_model.get())); classifier = Annotator::FromUnownedBuffer( reinterpret_cast(builder2.GetBufferPointer()), builder2.GetSize(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); EXPECT_EQ(FirstResult(classifier->ClassifyText( "I live at 350 Third Street, Cambridge.", {10, 37})), "other"); } TEST_F(AnnotatorTest, MinAddressTokenLength) { const std::string test_model = ReadFile(GetTestModelPath()); std::unique_ptr unpacked_model = UnPackModel(test_model.c_str()); std::unique_ptr classifier; // With unrestricted number of address tokens should behave normally. unpacked_model->classification_options->address_min_num_tokens = 0; flatbuffers::FlatBufferBuilder builder; FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); classifier = Annotator::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); EXPECT_EQ(FirstResult(classifier->ClassifyText( "I live at 350 Third Street, Cambridge.", {10, 37})), "address"); // Raise number of address tokens to suppress the address classification. 
unpacked_model->classification_options->address_min_num_tokens = 5; flatbuffers::FlatBufferBuilder builder2; FinishModelBuffer(builder2, Model::Pack(builder2, unpacked_model.get())); classifier = Annotator::FromUnownedBuffer( reinterpret_cast(builder2.GetBufferPointer()), builder2.GetSize(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); EXPECT_EQ(FirstResult(classifier->ClassifyText( "I live at 350 Third Street, Cambridge.", {10, 37})), "other"); } TEST_F(AnnotatorTest, WhenOtherCollectionPriorityHighOtherIsPreferredToFlight) { const std::string test_model = ReadFile(GetTestModelPath()); std::unique_ptr unpacked_model = UnPackModel(test_model.c_str()); unpacked_model->triggering_options->other_collection_priority_score = 1.0; flatbuffers::FlatBufferBuilder builder; FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); std::unique_ptr classifier = Annotator::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); EXPECT_EQ(FirstResult(classifier->ClassifyText("LX37", {0, 4})), "other"); } TEST_F(AnnotatorTest, WhenOtherCollectionPriorityHighFlightIsPreferredToOther) { const std::string test_model = ReadFile(GetTestModelPath()); std::unique_ptr unpacked_model = UnPackModel(test_model.c_str()); unpacked_model->triggering_options->other_collection_priority_score = -100.0; flatbuffers::FlatBufferBuilder builder; FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); std::unique_ptr classifier = Annotator::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); EXPECT_EQ(FirstResult(classifier->ClassifyText("LX37", {0, 4})), "flight"); } TEST_F(AnnotatorTest, VisitAnnotatorModel) { EXPECT_TRUE( VisitAnnotatorModel(GetTestModelPath(), [](const Model* model) { if (model == nullptr) { return false; } return true; })); 
EXPECT_FALSE(VisitAnnotatorModel( GetModelPath() + "non_existing_model.fb", [](const Model* model) { if (model == nullptr) { return false; } return true; })); } TEST_F(AnnotatorTest, TriggersWhenNoLanguageDetected) { std::string model_buffer = ReadFile(GetTestModelPath()); model_buffer = ModifyAnnotatorModel( model_buffer, [](ModelT* model) { model->triggering_locales = "en,cs"; }); std::unique_ptr classifier = Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); EXPECT_THAT(classifier->Annotate("(555) 225-3556"), ElementsAreArray({IsAnnotatedSpan(0, 14, "phone")})); EXPECT_EQ("phone", FirstResult(classifier->ClassifyText("(555) 225-3556", {0, 14}))); EXPECT_EQ(classifier->SuggestSelection("(555) 225-3556", {6, 9}), CodepointSpan(0, 14)); } TEST_F(AnnotatorTest, AnnotateTriggersWhenSupportedLanguageDetected) { std::string model_buffer = ReadFile(GetTestModelPath()); model_buffer = ModifyAnnotatorModel( model_buffer, [](ModelT* model) { model->triggering_locales = "en,cs"; }); std::unique_ptr classifier = Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); AnnotationOptions options; options.detected_text_language_tags = "cs"; EXPECT_THAT(classifier->Annotate("(555) 225-3556", options), ElementsAreArray({IsAnnotatedSpan(0, 14, "phone")})); } TEST_F(AnnotatorTest, AnnotateDoesntTriggerWhenUnsupportedLanguageDetected) { std::string model_buffer = ReadFile(GetTestModelPath()); model_buffer = ModifyAnnotatorModel( model_buffer, [](ModelT* model) { model->triggering_locales = "en,cs"; }); std::unique_ptr classifier = Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); AnnotationOptions options; options.detected_text_language_tags = "de"; EXPECT_THAT(classifier->Annotate("(555) 225-3556", options), IsEmpty()); } TEST_F(AnnotatorTest, 
ClassifyTextTriggersWhenSupportedLanguageDetected) { std::string model_buffer = ReadFile(GetTestModelPath()); model_buffer = ModifyAnnotatorModel( model_buffer, [](ModelT* model) { model->triggering_locales = "en,cs"; }); std::unique_ptr classifier = Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); ClassificationOptions options; options.detected_text_language_tags = "cs"; EXPECT_EQ("phone", FirstResult(classifier->ClassifyText("(555) 225-3556", {0, 14}, options))); } TEST_F(AnnotatorTest, ClassifyTextDoesntTriggerWhenUnsupportedLanguageDetected) { std::string model_buffer = ReadFile(GetTestModelPath()); model_buffer = ModifyAnnotatorModel( model_buffer, [](ModelT* model) { model->triggering_locales = "en,cs"; }); std::unique_ptr classifier = Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); ClassificationOptions options; options.detected_text_language_tags = "de"; EXPECT_THAT(classifier->ClassifyText("(555) 225-3556", {0, 14}, options), IsEmpty()); } TEST_F(AnnotatorTest, SuggestSelectionTriggersWhenSupportedLanguageDetected) { std::string model_buffer = ReadFile(GetTestModelPath()); model_buffer = ModifyAnnotatorModel( model_buffer, [](ModelT* model) { model->triggering_locales = "en,cs"; }); std::unique_ptr classifier = Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); SelectionOptions options; options.detected_text_language_tags = "cs"; EXPECT_EQ(classifier->SuggestSelection("(555) 225-3556", {6, 9}, options), CodepointSpan(0, 14)); } TEST_F(AnnotatorTest, SuggestSelectionDoesntTriggerWhenUnsupportedLanguageDetected) { std::string model_buffer = ReadFile(GetTestModelPath()); model_buffer = ModifyAnnotatorModel( model_buffer, [](ModelT* model) { model->triggering_locales = "en,cs"; }); std::unique_ptr classifier = 
Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); SelectionOptions options; options.detected_text_language_tags = "de"; EXPECT_EQ(classifier->SuggestSelection("(555) 225-3556", {6, 9}, options), CodepointSpan(6, 9)); } TEST_F(AnnotatorTest, MlModelTriggersWhenNoLanguageDetected) { std::string model_buffer = ReadFile(GetTestModelPath()); model_buffer = ModifyAnnotatorModel(model_buffer, [](ModelT* model) { model->triggering_locales = "en,cs"; model->triggering_options->locales = "en,cs"; }); std::unique_ptr classifier = Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); EXPECT_THAT(classifier->Annotate("350 Third Street, Cambridge"), ElementsAreArray({IsAnnotatedSpan(0, 27, "address")})); EXPECT_EQ("address", FirstResult(classifier->ClassifyText( "350 Third Street, Cambridge", {0, 27}))); EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {4, 9}), CodepointSpan(0, 27)); } TEST_F(AnnotatorTest, MlModelAnnotateTriggersWhenSupportedLanguageDetected) { std::string model_buffer = ReadFile(GetTestModelPath()); model_buffer = ModifyAnnotatorModel(model_buffer, [](ModelT* model) { model->triggering_locales = "en,cs"; model->triggering_options->locales = "en,cs"; }); std::unique_ptr classifier = Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); AnnotationOptions options; options.detected_text_language_tags = "cs"; EXPECT_THAT(classifier->Annotate("350 Third Street, Cambridge", options), ElementsAreArray({IsAnnotatedSpan(0, 27, "address")})); } TEST_F(AnnotatorTest, MlModelAnnotateDoesntTriggerWhenUnsupportedLanguageDetected) { std::string model_buffer = ReadFile(GetTestModelPath()); model_buffer = ModifyAnnotatorModel(model_buffer, [](ModelT* model) { model->triggering_locales = "en,cs"; 
model->triggering_options->locales = "en,cs"; }); std::unique_ptr classifier = Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); AnnotationOptions options; options.detected_text_language_tags = "de"; EXPECT_THAT(classifier->Annotate("350 Third Street, Cambridge", options), IsEmpty()); } TEST_F(AnnotatorTest, MlModelClassifyTextTriggersWhenSupportedLanguageDetected) { std::string model_buffer = ReadFile(GetTestModelPath()); model_buffer = ModifyAnnotatorModel(model_buffer, [](ModelT* model) { model->triggering_locales = "en,cs"; model->triggering_options->locales = "en,cs"; }); std::unique_ptr classifier = Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); ClassificationOptions options; options.detected_text_language_tags = "cs"; EXPECT_EQ("address", FirstResult(classifier->ClassifyText( "350 Third Street, Cambridge", {0, 27}, options))); } TEST_F(AnnotatorTest, MlModelClassifyTextDoesntTriggerWhenUnsupportedLanguageDetected) { std::string model_buffer = ReadFile(GetTestModelPath()); model_buffer = ModifyAnnotatorModel(model_buffer, [](ModelT* model) { model->triggering_locales = "en,cs"; model->triggering_options->locales = "en,cs"; }); std::unique_ptr classifier = Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); ClassificationOptions options; options.detected_text_language_tags = "de"; EXPECT_THAT( classifier->ClassifyText("350 Third Street, Cambridge", {0, 27}, options), IsEmpty()); } TEST_F(AnnotatorTest, MlModelSuggestSelectionTriggersWhenSupportedLanguageDetected) { std::string model_buffer = ReadFile(GetTestModelPath()); model_buffer = ModifyAnnotatorModel(model_buffer, [](ModelT* model) { model->triggering_locales = "en,cs"; model->triggering_options->locales = "en,cs"; }); std::unique_ptr classifier = 
Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); SelectionOptions options; options.detected_text_language_tags = "cs"; EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {4, 9}, options), CodepointSpan(0, 27)); } TEST_F(AnnotatorTest, MlModelSuggestSelectionDoesntTriggerWhenUnsupportedLanguageDetected) { std::string model_buffer = ReadFile(GetTestModelPath()); model_buffer = ModifyAnnotatorModel(model_buffer, [](ModelT* model) { model->triggering_locales = "en,cs"; model->triggering_options->locales = "en,cs"; }); std::unique_ptr classifier = Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); SelectionOptions options; options.detected_text_language_tags = "de"; EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {4, 9}, options), CodepointSpan(4, 9)); } void VerifyClassifyTextOutputsDatetimeEntityData(const Annotator* classifier) { EXPECT_TRUE(classifier); std::vector result; ClassificationOptions options; options.locales = "en-US"; result = classifier->ClassifyText("03.05.1970 00:00am", {0, 18}, options); ASSERT_GE(result.size(), 0); const EntityData* entity_data = GetEntityData(result[0].serialized_entity_data.data()); ASSERT_NE(entity_data, nullptr); ASSERT_NE(entity_data->datetime(), nullptr); EXPECT_EQ(entity_data->datetime()->time_ms_utc(), 5443200000L); EXPECT_EQ(entity_data->datetime()->granularity(), EntityData_::Datetime_::Granularity_GRANULARITY_MINUTE); EXPECT_EQ(entity_data->datetime()->datetime_component()->size(), 6); auto* meridiem = entity_data->datetime()->datetime_component()->Get(0); EXPECT_EQ(meridiem->component_type(), EntityData_::Datetime_::DatetimeComponent_::ComponentType_MERIDIEM); EXPECT_EQ(meridiem->absolute_value(), 0); EXPECT_EQ(meridiem->relative_count(), 0); EXPECT_EQ(meridiem->relation_type(), 
EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE); auto* minute = entity_data->datetime()->datetime_component()->Get(1); EXPECT_EQ(minute->component_type(), EntityData_::Datetime_::DatetimeComponent_::ComponentType_MINUTE); EXPECT_EQ(minute->absolute_value(), 0); EXPECT_EQ(minute->relative_count(), 0); EXPECT_EQ(minute->relation_type(), EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE); auto* hour = entity_data->datetime()->datetime_component()->Get(2); EXPECT_EQ(hour->component_type(), EntityData_::Datetime_::DatetimeComponent_::ComponentType_HOUR); EXPECT_EQ(hour->absolute_value(), 0); EXPECT_EQ(hour->relative_count(), 0); EXPECT_EQ(hour->relation_type(), EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE); auto* day = entity_data->datetime()->datetime_component()->Get(3); EXPECT_EQ( day->component_type(), EntityData_::Datetime_::DatetimeComponent_::ComponentType_DAY_OF_MONTH); EXPECT_EQ(day->absolute_value(), 5); EXPECT_EQ(day->relative_count(), 0); EXPECT_EQ(day->relation_type(), EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE); auto* month = entity_data->datetime()->datetime_component()->Get(4); EXPECT_EQ(month->component_type(), EntityData_::Datetime_::DatetimeComponent_::ComponentType_MONTH); EXPECT_EQ(month->absolute_value(), 3); EXPECT_EQ(month->relative_count(), 0); EXPECT_EQ(month->relation_type(), EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE); auto* year = entity_data->datetime()->datetime_component()->Get(5); EXPECT_EQ(year->component_type(), EntityData_::Datetime_::DatetimeComponent_::ComponentType_YEAR); EXPECT_EQ(year->absolute_value(), 1970); EXPECT_EQ(year->relative_count(), 0); EXPECT_EQ(year->relation_type(), EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE); } TEST_F(AnnotatorTest, ClassifyTextOutputsDatetimeEntityData) { std::unique_ptr classifier = Annotator::FromPath( GetTestModelPath(), unilib_.get(), calendarlib_.get()); 
VerifyClassifyTextOutputsDatetimeEntityData(classifier.get()); } TEST_F(AnnotatorTest, ClassifyTextOutputsDatetimeEntityDataWithDatetimeRegEx) { std::string model_buffer = GetTestModelWithDatetimeRegEx(); std::unique_ptr classifier = Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(), unilib_.get(), calendarlib_.get()); VerifyClassifyTextOutputsDatetimeEntityData(classifier.get()); } void VerifyAnnotateOutputsDatetimeEntityData(const Annotator* classifier) { EXPECT_TRUE(classifier); std::vector result; AnnotationOptions options; options.is_serialized_entity_data_enabled = true; options.locales = "en"; result = classifier->Annotate("September 1, 2019", options); ASSERT_GE(result.size(), 0); ASSERT_GE(result[0].classification.size(), 0); ASSERT_EQ(result[0].classification[0].collection, "date"); const EntityData* entity_data = GetEntityData(result[0].classification[0].serialized_entity_data.data()); ASSERT_NE(entity_data, nullptr); ASSERT_NE(entity_data->datetime(), nullptr); EXPECT_EQ(entity_data->datetime()->time_ms_utc(), 1567296000000L); EXPECT_EQ(entity_data->datetime()->granularity(), EntityData_::Datetime_::Granularity_GRANULARITY_DAY); EXPECT_EQ(entity_data->datetime()->datetime_component()->size(), 3); auto* day = entity_data->datetime()->datetime_component()->Get(0); EXPECT_EQ( day->component_type(), EntityData_::Datetime_::DatetimeComponent_::ComponentType_DAY_OF_MONTH); EXPECT_EQ(day->absolute_value(), 1); EXPECT_EQ(day->relative_count(), 0); EXPECT_EQ(day->relation_type(), EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE); auto* month = entity_data->datetime()->datetime_component()->Get(1); EXPECT_EQ(month->component_type(), EntityData_::Datetime_::DatetimeComponent_::ComponentType_MONTH); EXPECT_EQ(month->absolute_value(), 9); EXPECT_EQ(month->relative_count(), 0); EXPECT_EQ(month->relation_type(), EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE); auto* year = 
entity_data->datetime()->datetime_component()->Get(2); EXPECT_EQ(year->component_type(), EntityData_::Datetime_::DatetimeComponent_::ComponentType_YEAR); EXPECT_EQ(year->absolute_value(), 2019); EXPECT_EQ(year->relative_count(), 0); EXPECT_EQ(year->relation_type(), EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE); } TEST_F(AnnotatorTest, AnnotateOutputsDatetimeEntityData) { std::unique_ptr classifier = Annotator::FromPath( GetTestModelPath(), unilib_.get(), calendarlib_.get()); VerifyAnnotateOutputsDatetimeEntityData(classifier.get()); } TEST_F(AnnotatorTest, AnnotateOutputsDatetimeEntityDataWithDatatimeRegEx) { std::string model_buffer = GetTestModelWithDatetimeRegEx(); std::unique_ptr classifier = Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(), unilib_.get(), calendarlib_.get()); VerifyAnnotateOutputsDatetimeEntityData(classifier.get()); } TEST_F(AnnotatorTest, AnnotateOutputsMoneyEntityData) { // std::string model_buffer = GetTestModelWithDatetimeRegEx(); // std::unique_ptr classifier = // Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(), // unilib_.get(), calendarlib_.get()); std::unique_ptr classifier = Annotator::FromPath( GetTestModelPath(), unilib_.get(), calendarlib_.get()); EXPECT_TRUE(classifier); AnnotationOptions options; options.is_serialized_entity_data_enabled = true; ExpectFirstEntityIsMoney(classifier->Annotate("3.5 CHF", options), "CHF", /*amount=*/"3.5", /*whole_part=*/3, /*decimal_part=*/5, /*nanos=*/500000000); ExpectFirstEntityIsMoney(classifier->Annotate("CHF 3.5", options), "CHF", /*amount=*/"3.5", /*whole_part=*/3, /*decimal_part=*/5, /*nanos=*/500000000); ExpectFirstEntityIsMoney( classifier->Annotate("For online purchase of CHF 23.00 enter", options), "CHF", /*amount=*/"23.00", /*whole_part=*/23, /*decimal_part=*/0, /*nanos=*/0); ExpectFirstEntityIsMoney( classifier->Annotate("For online purchase of 23.00 CHF enter", options), "CHF", /*amount=*/"23.00", /*whole_part=*/23, 
/*decimal_part=*/0, /*nanos=*/0); ExpectFirstEntityIsMoney(classifier->Annotate("4.8198£", options), "£", /*amount=*/"4.8198", /*whole_part=*/4, /*decimal_part=*/8198, /*nanos=*/819800000); ExpectFirstEntityIsMoney(classifier->Annotate("£4.8198", options), "£", /*amount=*/"4.8198", /*whole_part=*/4, /*decimal_part=*/8198, /*nanos=*/819800000); ExpectFirstEntityIsMoney(classifier->Annotate("$0.0255", options), "$", /*amount=*/"0.0255", /*whole_part=*/0, /*decimal_part=*/255, /*nanos=*/25500000); ExpectFirstEntityIsMoney(classifier->Annotate("$0.0255", options), "$", /*amount=*/"0.0255", /*whole_part=*/0, /*decimal_part=*/255, /*nanos=*/25500000); ExpectFirstEntityIsMoney( classifier->Annotate("for txn of INR 000.00 at RAZOR-PAY ZOMATO ONLINE " "OR on card ending 0000.", options), "INR", /*amount=*/"000.00", /*whole_part=*/0, /*decimal_part=*/0, /*nanos=*/0); ExpectFirstEntityIsMoney( classifier->Annotate("for txn of 000.00 INR at RAZOR-PAY ZOMATO ONLINE " "OR on card ending 0000.", options), "INR", /*amount=*/"000.00", /*whole_part=*/0, /*decimal_part=*/0, /*nanos=*/0); ExpectFirstEntityIsMoney(classifier->Annotate("35 CHF", options), "CHF", /*amount=*/"35", /*whole_part=*/35, /*decimal_part=*/0, /*nanos=*/0); ExpectFirstEntityIsMoney(classifier->Annotate("CHF 35", options), "CHF", /*amount=*/"35", /*whole_part=*/35, /*decimal_part=*/0, /*nanos=*/0); ExpectFirstEntityIsMoney( classifier->Annotate("and win back up to CHF 150 - with digitec", options), "CHF", /*amount=*/"150", /*whole_part=*/150, /*decimal_part=*/0, /*nanos=*/0); ExpectFirstEntityIsMoney( classifier->Annotate("and win back up to 150 CHF - with digitec", options), "CHF", /*amount=*/"150", /*whole_part=*/150, /*decimal_part=*/0, /*nanos=*/0); ExpectFirstEntityIsMoney(classifier->Annotate("3.555.333 CHF", options), "CHF", /*amount=*/"3.555.333", /*whole_part=*/3555333, /*decimal_part=*/0, /*nanos=*/0); ExpectFirstEntityIsMoney(classifier->Annotate("CHF 3.555.333", options), "CHF", /*amount=*/"3.555.333", 
/*whole_part=*/3555333, /*decimal_part=*/0, /*nanos=*/0); ExpectFirstEntityIsMoney(classifier->Annotate("10,000 CHF", options), "CHF", /*amount=*/"10,000", /*whole_part=*/10000, /*decimal_part=*/0, /*nanos=*/0); ExpectFirstEntityIsMoney(classifier->Annotate("CHF 10,000", options), "CHF", /*amount=*/"10,000", /*whole_part=*/10000, /*decimal_part=*/0, /*nanos=*/0); ExpectFirstEntityIsMoney(classifier->Annotate("3,555.33 CHF", options), "CHF", /*amount=*/"3,555.33", /*whole_part=*/3555, /*decimal_part=*/33, /*nanos=*/330000000); ExpectFirstEntityIsMoney(classifier->Annotate("CHF 3,555.33", options), "CHF", /*amount=*/"3,555.33", /*whole_part=*/3555, /*decimal_part=*/33, /*nanos=*/330000000); ExpectFirstEntityIsMoney(classifier->Annotate("$3,000.00", options), "$", /*amount=*/"3,000.00", /*whole_part=*/3000, /*decimal_part=*/0, /*nanos=*/0); ExpectFirstEntityIsMoney(classifier->Annotate("3,000.00$", options), "$", /*amount=*/"3,000.00", /*whole_part=*/3000, /*decimal_part=*/0, /*nanos=*/0); ExpectFirstEntityIsMoney(classifier->Annotate("1.2 CHF", options), "CHF", /*amount=*/"1.2", /*whole_part=*/1, /*decimal_part=*/2, /*nanos=*/200000000); ExpectFirstEntityIsMoney(classifier->Annotate("CHF1.2", options), "CHF", /*amount=*/"1.2", /*whole_part=*/1, /*decimal_part=*/2, /*nanos=*/200000000); ExpectFirstEntityIsMoney(classifier->Annotate("$1.123456789", options), "$", /*amount=*/"1.123456789", /*whole_part=*/1, /*decimal_part=*/123456789, /*nanos=*/123456789); ExpectFirstEntityIsMoney(classifier->Annotate("10.01 CHF", options), "CHF", /*amount=*/"10.01", /*whole_part=*/10, /*decimal_part=*/1, /*nanos=*/10000000); ExpectFirstEntityIsMoney(classifier->Annotate("$59 Million", options), "$", /*amount=*/"59 million", /*whole_part=*/59000000, /*decimal_part=*/0, /*nanos=*/0); ExpectFirstEntityIsMoney(classifier->Annotate("7.05k €", options), "€", /*amount=*/"7.05 k", /*whole_part=*/7050, /*decimal_part=*/5, /*nanos=*/0); ExpectFirstEntityIsMoney(classifier->Annotate("7.123456789m 
€", options), "€", /*amount=*/"7.123456789 m", /*whole_part=*/7123456, /*decimal_part=*/123456789, /*nanos=*/789000000); ExpectFirstEntityIsMoney(classifier->Annotate("7.000056789k €", options), "€", /*amount=*/"7.000056789 k", /*whole_part=*/7000, /*decimal_part=*/56789, /*nanos=*/56789000); ExpectFirstEntityIsMoney(classifier->Annotate("$59.3 Billion", options), "$", /*amount=*/"59.3 billion", /*whole_part=*/59, /*decimal_part=*/3, /*nanos=*/300000000); ExpectFirstEntityIsMoney(classifier->Annotate("$1.5 Billion", options), "$", /*amount=*/"1.5 billion", /*whole_part=*/1500000000, /*decimal_part=*/5, /*nanos=*/0); } TEST_F(AnnotatorTest, TranslateAction) { std::unique_ptr classifier = Annotator::FromPath( GetTestModelPath(), unilib_.get(), calendarlib_.get()); std::unique_ptr langid_model = libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFile(GetModelPath() + "lang_id.smfb"); classifier->SetLangId(langid_model.get()); ClassificationOptions options; options.user_familiar_language_tags = "de"; std::vector classifications = classifier->ClassifyText("hello, how are you doing?", {11, 14}, options); EXPECT_EQ(classifications.size(), 1); EXPECT_EQ(classifications[0].collection, "translate"); } TEST_F(AnnotatorTest, AnnotateStructuredInputCallsMultipleAnnotators) { std::unique_ptr classifier = Annotator::FromPath( GetTestModelPath(), unilib_.get(), calendarlib_.get()); std::vector string_fragments = { {.text = "He owes me 3.5 CHF."}, {.text = "...was born on 13/12/1989."}, }; AnnotationOptions annotation_options; annotation_options.locales = "en"; StatusOr annotations_status = classifier->AnnotateStructuredInput(string_fragments, annotation_options); ASSERT_TRUE(annotations_status.ok()); Annotations annotations = annotations_status.ValueOrDie(); ASSERT_EQ(annotations.annotated_spans.size(), 2); EXPECT_THAT(annotations.annotated_spans[0], ElementsAreArray({IsAnnotatedSpan(11, 18, "money")})); EXPECT_THAT(annotations.annotated_spans[1], 
ElementsAreArray({IsAnnotatedSpan(15, 25, "date")})); } void VerifyInputFragmentTimestampOverridesAnnotationOptions( const Annotator* classifier) { AnnotationOptions annotation_options; annotation_options.locales = "en"; annotation_options.reference_time_ms_utc = 1554465190000; // 04/05/2019 11:53 am int64 fragment_reference_time = 946727580000; // 01/01/2000 11:53 am std::vector string_fragments = { {.text = "New event at 17:20"}, { .text = "New event at 17:20", .datetime_options = Optional( {.reference_time_ms_utc = fragment_reference_time}), }}; StatusOr annotations_status = classifier->AnnotateStructuredInput(string_fragments, annotation_options); ASSERT_TRUE(annotations_status.ok()); Annotations annotations = annotations_status.ValueOrDie(); ASSERT_EQ(annotations.annotated_spans.size(), 2); EXPECT_THAT(annotations.annotated_spans[0], ElementsAreArray({IsDatetimeSpan( /*start=*/13, /*end=*/18, /*time_ms_utc=*/1554484800000, DatetimeGranularity::GRANULARITY_MINUTE)})); EXPECT_THAT(annotations.annotated_spans[1], ElementsAreArray({IsDatetimeSpan( /*start=*/13, /*end=*/18, /*time_ms_utc=*/946747200000, DatetimeGranularity::GRANULARITY_MINUTE)})); } TEST_F(AnnotatorTest, InputFragmentTimestampOverridesAnnotationOptionsWithDatetimeRegEx) { std::string model_buffer = GetTestModelWithDatetimeRegEx(); std::unique_ptr classifier = Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(), unilib_.get(), calendarlib_.get()); VerifyInputFragmentTimestampOverridesAnnotationOptions(classifier.get()); } TEST_F(AnnotatorTest, InputFragmentTimestampOverridesAnnotationOptions) { std::unique_ptr classifier = Annotator::FromPath( GetTestModelPath(), unilib_.get(), calendarlib_.get()); VerifyInputFragmentTimestampOverridesAnnotationOptions(classifier.get()); } void VerifyInputFragmentTimezoneOverridesAnnotationOptions( const Annotator* classifier) { std::vector string_fragments = { {.text = "11/12/2020 17:20"}, { .text = "11/12/2020 17:20", .datetime_options = 
Optional( {.reference_timezone = "Europe/Zurich"}), }}; AnnotationOptions annotation_options; annotation_options.locales = "en-US"; StatusOr annotations_status = classifier->AnnotateStructuredInput(string_fragments, annotation_options); ASSERT_TRUE(annotations_status.ok()); Annotations annotations = annotations_status.ValueOrDie(); ASSERT_EQ(annotations.annotated_spans.size(), 2); EXPECT_THAT(annotations.annotated_spans[0], ElementsAreArray({IsDatetimeSpan( /*start=*/0, /*end=*/16, /*time_ms_utc=*/1605201600000, DatetimeGranularity::GRANULARITY_MINUTE)})); EXPECT_THAT(annotations.annotated_spans[1], ElementsAreArray({IsDatetimeSpan( /*start=*/0, /*end=*/16, /*time_ms_utc=*/1605198000000, DatetimeGranularity::GRANULARITY_MINUTE)})); } TEST_F(AnnotatorTest, InputFragmentTimezoneOverridesAnnotationOptions) { std::unique_ptr classifier = Annotator::FromPath( GetTestModelPath(), unilib_.get(), calendarlib_.get()); VerifyInputFragmentTimezoneOverridesAnnotationOptions(classifier.get()); } TEST_F(AnnotatorTest, InputFragmentTimezoneOverridesAnnotationOptionsWithDatetimeRegEx) { std::string model_buffer = GetTestModelWithDatetimeRegEx(); std::unique_ptr classifier = Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(), unilib_.get(), calendarlib_.get()); VerifyInputFragmentTimezoneOverridesAnnotationOptions(classifier.get()); } namespace { void AddDummyRegexDatetimeModel(ModelT* unpacked_model) { unpacked_model->datetime_model.reset(new DatetimeModelT); // This needs to be false otherwise we'd have to define some extractor. When // this is false, the 0-th capturing group (whole match) from the pattern is // used to come up with the indices. 
unpacked_model->datetime_model->use_extractors_for_locating = false; unpacked_model->datetime_model->locales.push_back("en-US"); unpacked_model->datetime_model->default_locales.push_back(0); // en-US unpacked_model->datetime_model->patterns.push_back( std::unique_ptr(new DatetimeModelPatternT)); unpacked_model->datetime_model->patterns.back()->locales.push_back( 0); // en-US unpacked_model->datetime_model->patterns.back()->regexes.push_back( std::unique_ptr( new DatetimeModelPattern_::RegexT)); unpacked_model->datetime_model->patterns.back()->regexes.back()->pattern = "THIS_MATCHES_IN_REGEX_MODEL"; unpacked_model->datetime_model->patterns.back() ->regexes.back() ->groups.push_back(DatetimeGroupType_GROUP_UNUSED); } } // namespace TEST_F(AnnotatorTest, AnnotateFiltersOutExactDuplicates) { std::unique_ptr classifier = Annotator::FromPath( GetTestModelPath(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); // This test assumes that both ML model and Regex model trigger on the // following text and output "phone" annotation for it. const std::string test_string = "1000000000"; AnnotationOptions options; options.annotation_usecase = ANNOTATION_USECASE_RAW; int num_phones = 0; for (const AnnotatedSpan& span : classifier->Annotate(test_string, options)) { if (span.classification[0].collection == "phone") { num_phones++; } } EXPECT_EQ(num_phones, 1); } // This test tests the optimizations in Annotator, which make some of the // annotators not run in the RAW mode when not requested. We test here that the // results indeed don't contain such annotations. However, this is a bick hacky, // since one could also add post-filtering, in which case these tests would // trivially pass. 
TEST_F(AnnotatorTest, RawModeOptimizationWorks) { std::unique_ptr classifier = Annotator::FromPath( GetTestModelPath(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); AnnotationOptions options; options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW; // Requesting a non-existing type to avoid overlap with existing types. options.entity_types.insert("some_unknown_entity_type"); // Normally, the following command would produce the following annotations: // Span(19, 24, date, 1.000000), // Span(53, 56, number, 1.000000), // Span(53, 80, address, 1.000000), // Span(128, 142, phone, 1.000000), // Span(129, 132, number, 1.000000), // Span(192, 200, phone, 1.000000), // Span(192, 206, datetime, 1.000000), // Span(246, 253, number, 1.000000), // Span(246, 253, phone, 1.000000), // Span(292, 293, number, 1.000000), // Span(292, 301, duration, 1.000000) } // But because of the optimizations, it doesn't produce anything, since // we didn't request any of these entities. EXPECT_THAT(classifier->Annotate(R"--(I saw Barack Obama today 350 Third Street, Cambridge my phone number is (853) 225-3556 this is when we met: 1.9.2021 13:00 my number: 1234567 duration: 3 minutes )--", options), IsEmpty()); } void VerifyAnnotateSupportsPointwiseCollectionFilteringInRawMode( const Annotator* classifier) { ASSERT_TRUE(classifier); struct Example { std::string collection; std::string text; }; // These examples contain one example per annotator, to check that each of // the annotators can work in the RAW mode on its own. // // WARNING: This list doesn't contain yet entries for the app, contact, and // person annotators. Hopefully this won't be needed once b/155214735 is // fixed and the piping shared across annotators. std::vector examples{ // ML Model. {.collection = Collections::Address(), .text = "... 350 Third Street, Cambridge ..."}, // Datetime annotator. {.collection = Collections::DateTime(), .text = "... 1.9.2020 10:00 ..."}, // Duration annotator. 
{.collection = Collections::Duration(), .text = "... 3 hours and 9 seconds ..."}, // Regex annotator. {.collection = Collections::Email(), .text = "... platypus@theanimal.org ..."}, // Number annotator. {.collection = Collections::Number(), .text = "... 100 ..."}, }; for (const Example& example : examples) { AnnotationOptions options; options.locales = "en"; options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW; options.entity_types.insert(example.collection); EXPECT_THAT(classifier->Annotate(example.text, options), Contains(IsAnnotationWithType(example.collection))) << " text: '" << example.text << "', collection: " << example.collection; } } TEST_F(AnnotatorTest, AnnotateSupportsPointwiseCollectionFilteringInRawMode) { std::unique_ptr classifier = Annotator::FromPath( GetTestModelPath(), unilib_.get(), calendarlib_.get()); VerifyAnnotateSupportsPointwiseCollectionFilteringInRawMode(classifier.get()); } TEST_F(AnnotatorTest, AnnotateSupportsPointwiseCollectionFilteringInRawModeWithDatetimeRegEx) { std::string model_buffer = GetTestModelWithDatetimeRegEx(); std::unique_ptr classifier = Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(), unilib_.get(), calendarlib_.get()); VerifyAnnotateSupportsPointwiseCollectionFilteringInRawMode(classifier.get()); } TEST_F(AnnotatorTest, InitializeFromString) { const std::string test_model = ReadFile(GetTestModelPath()); std::unique_ptr classifier = Annotator::FromString(test_model, unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); EXPECT_THAT(classifier->Annotate("(857) 225-3556"), Not(IsEmpty())); } // Regression test for cl/338280366. Enabling only_use_line_with_click had // the effect, that some annotators in the previous code releases would // receive only the last line of the input text. This test has the entity on the // first line (duration). 
TEST_F(AnnotatorTest, RegressionTestOnlyUseLineWithClickLastLine) { const std::string test_model = ReadFile(GetTestModelPath()); std::unique_ptr unpacked_model = UnPackModel(test_model.c_str()); std::unique_ptr classifier; // With unrestricted number of tokens should behave normally. unpacked_model->selection_feature_options->only_use_line_with_click = true; flatbuffers::FlatBufferBuilder builder; FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); classifier = Annotator::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); AnnotationOptions options; options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW; const std::vector annotations = classifier->Annotate("let's meet in 3 hours\nbut not now", options); EXPECT_THAT(annotations, Contains(IsDurationSpan( /*start=*/14, /*end=*/21, /*duration_ms=*/3 * 60 * 60 * 1000))); } TEST_F(AnnotatorTest, DoesntProcessInvalidUtf8) { const std::string test_model = ReadFile(GetTestModelPath()); const std::string invalid_utf8_text_with_phone_number = "(857) 225-3556 \xed\xa0\x80\xed\xa0\x80\xed\xa0\x80\xed\xa0\x80"; std::unique_ptr classifier = Annotator::FromString(test_model, unilib_.get(), calendarlib_.get()); ASSERT_TRUE(classifier); EXPECT_THAT(classifier->Annotate(invalid_utf8_text_with_phone_number), IsEmpty()); EXPECT_THAT( classifier->SuggestSelection(invalid_utf8_text_with_phone_number, {1, 4}), Eq(CodepointSpan{1, 4})); EXPECT_THAT( classifier->ClassifyText(invalid_utf8_text_with_phone_number, {0, 14}), IsEmpty()); } } // namespace test_internal } // namespace libtextclassifier3