/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "actions/actions-suggestions.h"

#include <fstream>
#include <functional>
#include <iterator>
#include <memory>
#include <string>

#include "actions/actions_model_generated.h"
#include "actions/test-utils.h"
#include "actions/zlib-utils.h"
#include "annotator/collections.h"
#include "annotator/types.h"
#include "utils/flatbuffers/flatbuffers.h"
#include "utils/flatbuffers/flatbuffers_generated.h"
#include "utils/flatbuffers/mutable.h"
#include "utils/grammar/utils/locale-shard-map.h"
#include "utils/grammar/utils/rules.h"
#include "utils/hash/farmhash.h"
#include "utils/jvm-test-utils.h"
#include "utils/test-data-test-utils.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "flatbuffers/flatbuffers.h"
#include "flatbuffers/reflection.h"

namespace libtextclassifier3 {
namespace {

using ::testing::ElementsAre;
using ::testing::FloatEq;
using ::testing::IsEmpty;
using ::testing::NotNull;
using ::testing::SizeIs;

// File names of the test models, relative to the directory returned by
// GetModelPath().
constexpr char kModelFileName[] = "actions_suggestions_test.model";
constexpr char kModelGrammarFileName[] =
    "actions_suggestions_grammar_test.model";
constexpr char kMultiTaskTF2TestModelFileName[] =
    "actions_suggestions_test.multi_task_tf2_test.model";
constexpr char kMultiTaskModelFileName[] =
    "actions_suggestions_test.multi_task_9heads.model";
constexpr char kHashGramModelFileName[] =
    "actions_suggestions_test.hashgram.model";
constexpr char kMultiTaskSrP13nModelFileName[] =
    "actions_suggestions_test.multi_task_sr_p13n.model";
constexpr char kMultiTaskSrEmojiModelFileName[] =
    "actions_suggestions_test.multi_task_sr_emoji.model";
constexpr char kSensitiveTFliteModelFileName[] =
    "actions_suggestions_test.sensitive_tflite.model";

// Returns the full contents of `file_name` as a string. No error is reported
// if the file cannot be opened; the result is then empty.
std::string ReadFile(const std::string& file_name) {
  std::ifstream file_stream(file_name);
  return std::string(std::istreambuf_iterator<char>(file_stream), {});
}

// Returns the directory that holds the actions test models.
std::string GetModelPath() { return GetTestDataPath("actions/test_data/"); }

// Test fixture that owns the UniLib instance handed to every
// ActionsSuggestions object created by the tests.
class ActionsSuggestionsTest : public testing::Test {
 protected:
  explicit ActionsSuggestionsTest() : unilib_(CreateUniLibForTesting()) {}

  // Loads the model `model_file_name` from the test data directory.
  // Taken by const reference to avoid copying the string on every call.
  std::unique_ptr<ActionsSuggestions> LoadTestModel(
      const std::string& model_file_name) {
    return ActionsSuggestions::FromPath(GetModelPath() + model_file_name,
                                        unilib_.get());
  }

  // Loads the hashgram test model.
  std::unique_ptr<ActionsSuggestions> LoadHashGramTestModel() {
    return ActionsSuggestions::FromPath(GetModelPath() + kHashGramModelFileName,
                                        unilib_.get());
  }

  // Loads the multi-task (9 heads) test model.
  std::unique_ptr<ActionsSuggestions> LoadMultiTaskTestModel() {
    return ActionsSuggestions::FromPath(
        GetModelPath() + kMultiTaskModelFileName, unilib_.get());
  }

  // Loads the multi-task smart reply p13n test model.
  std::unique_ptr<ActionsSuggestions> LoadMultiTaskSrP13nTestModel() {
    return ActionsSuggestions::FromPath(
        GetModelPath() + kMultiTaskSrP13nModelFileName, unilib_.get());
  }

  std::unique_ptr<UniLib> unilib_;
};

// The default test model can be loaded.
TEST_F(ActionsSuggestionsTest, InstantiateActionSuggestions) {
  EXPECT_THAT(LoadTestModel(kModelFileName), NotNull());
}

// A message ending in a truncated multi-byte sequence yields no actions.
TEST_F(ActionsSuggestionsTest, ProducesEmptyResponseOnInvalidInput) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Where are you?\xf0\x9f",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  EXPECT_THAT(response.actions, IsEmpty());
}

// A message containing invalid UTF-8 byte sequences yields no actions.
TEST_F(ActionsSuggestionsTest, ProducesEmptyResponseOnInvalidUtf8) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1,
             "(857) 225-3556 \xed\xa0\x80\xed\xa0\x80\xed\xa0\x80\xed\xa0\x80",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  EXPECT_THAT(response.actions, IsEmpty());
}

// A plain question produces the expected number of suggestions.
TEST_F(ActionsSuggestionsTest, SuggestsActions) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  EXPECT_EQ(response.actions.size(), 3 /* share_location + 2 smart replies*/);
}

// The same message tagged with locale "zz" produces no suggestions.
TEST_F(ActionsSuggestionsTest, SuggestsNoActionsForUnknownLocale) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"zz"}}});
  EXPECT_THAT(response.actions, testing::IsEmpty());
}

// An "address" annotation on the message results in a "view_map" action.
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromAnnotations) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {ClassificationResult("address", 1.0)};
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions.front().type, "view_map");
  EXPECT_EQ(response.actions.front().score, 1.0);
}

// A custom annotation mapping writes the annotated text into the `location`
// entity data field of the produced action.
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromAnnotationsWithEntityData) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  SetTestEntityDataSchema(actions_model.get());

  // Set custom actions from annotations config.
  actions_model->annotation_actions_spec->annotation_mapping.clear();
  actions_model->annotation_actions_spec->annotation_mapping.emplace_back(
      new AnnotationActionsSpec_::AnnotationMappingT);
  AnnotationActionsSpec_::AnnotationMappingT* mapping =
      actions_model->annotation_actions_spec->annotation_mapping.back().get();
  mapping->annotation_collection = "address";
  mapping->action.reset(new ActionSuggestionSpecT);
  mapping->action->type = "save_location";
  mapping->action->score = 1.0;
  mapping->action->priority_score = 2.0;
  mapping->entity_field.reset(new FlatbufferFieldPathT);
  mapping->entity_field->field.emplace_back(new FlatbufferFieldT);
  mapping->entity_field->field.back()->field_name = "location";

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get());

  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {ClassificationResult("address", 1.0)};
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions.front().type, "save_location");
  EXPECT_EQ(response.actions.front().score, 1.0);

  // Check that the `location` entity field holds the text from the address
  // annotation.
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          response.actions.front().serialized_entity_data.data()));
  EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
            "home");
}

// Same as above, but with uppercase normalization configured on the mapping,
// so the entity field holds the uppercased annotation text.
TEST_F(ActionsSuggestionsTest,
       SuggestsActionsFromAnnotationsWithNormalization) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  SetTestEntityDataSchema(actions_model.get());

  // Set custom actions from annotations config.
  actions_model->annotation_actions_spec->annotation_mapping.clear();
  actions_model->annotation_actions_spec->annotation_mapping.emplace_back(
      new AnnotationActionsSpec_::AnnotationMappingT);
  AnnotationActionsSpec_::AnnotationMappingT* mapping =
      actions_model->annotation_actions_spec->annotation_mapping.back().get();
  mapping->annotation_collection = "address";
  mapping->action.reset(new ActionSuggestionSpecT);
  mapping->action->type = "save_location";
  mapping->action->score = 1.0;
  mapping->action->priority_score = 2.0;
  mapping->entity_field.reset(new FlatbufferFieldPathT);
  mapping->entity_field->field.emplace_back(new FlatbufferFieldT);
  mapping->entity_field->field.back()->field_name = "location";
  mapping->normalization_options.reset(new NormalizationOptionsT);
  mapping->normalization_options->codepointwise_normalization =
      NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE;

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get());

  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {ClassificationResult("address", 1.0)};
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions.front().type, "save_location");
  EXPECT_EQ(response.actions.front().score, 1.0);

  // Check that the `location` entity field holds the normalized text of the
  // annotation.
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          response.actions.front().serialized_entity_data.data()));
  EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
            "HOME");
}

// Two "flight" annotations over the same text produce a single
// "track_flight" action carrying the higher score.
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromDuplicatedAnnotations) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  AnnotatedSpan flight_annotation;
  flight_annotation.span = {11, 15};
  flight_annotation.classification = {ClassificationResult("flight", 2.5)};
  AnnotatedSpan flight_annotation2;
  flight_annotation2.span = {35, 39};
  flight_annotation2.classification = {ClassificationResult("flight", 3.0)};
  AnnotatedSpan email_annotation;
  email_annotation.span = {43, 56};
  email_annotation.classification = {ClassificationResult("email", 2.0)};

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1,
             "call me at LX38 or send message to LX38 or test@test.com.",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/
             {flight_annotation, flight_annotation2, email_annotation},
             /*locales=*/"en"}}});

  ASSERT_GE(response.actions.size(), 2);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[0].score, 3.0);
  EXPECT_EQ(response.actions[1].type, "send_email");
  EXPECT_EQ(response.actions[1].score, 2.0);
}

// With deduplication disabled in the model, both "flight" annotations
// produce their own "track_flight" action.
TEST_F(ActionsSuggestionsTest, SuggestsActionsAnnotationsWithNoDeduplication) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  // Disable deduplication.
  actions_model->annotation_actions_spec->deduplicate_annotations = false;
  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get());
  AnnotatedSpan flight_annotation;
  flight_annotation.span = {11, 15};
  flight_annotation.classification = {ClassificationResult("flight", 2.5)};
  AnnotatedSpan flight_annotation2;
  flight_annotation2.span = {35, 39};
  flight_annotation2.classification = {ClassificationResult("flight", 3.0)};
  AnnotatedSpan email_annotation;
  email_annotation.span = {43, 56};
  email_annotation.classification = {ClassificationResult("email", 2.0)};

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1,
             "call me at LX38 or send message to LX38 or test@test.com.",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/
             {flight_annotation, flight_annotation2, email_annotation},
             /*locales=*/"en"}}});

  ASSERT_GE(response.actions.size(), 3);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[0].score, 3.0);
  EXPECT_EQ(response.actions[1].type, "track_flight");
  EXPECT_EQ(response.actions[1].score, 2.5);
  EXPECT_EQ(response.actions[2].type, "send_email");
  EXPECT_EQ(response.actions[2].score, 2.0);
}

// Loads the default test model, lets `set_config_fn` modify the unpacked
// model, raises the smart reply triggering threshold to 1.0, and returns the
// suggestions for a fixed four-message conversation: three messages with an
// "email" annotation (from the local user, user 2 and user 1) followed by one
// message from user 1 with a "flight" annotation. `unilib` is forwarded to
// ActionsSuggestions::FromUnownedBuffer.
ActionsSuggestionsResponse TestSuggestActionsFromAnnotations(
    const std::function<void(ActionsModelT*)>& set_config_fn,
    const UniLib* unilib = nullptr) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());

  // Set custom config.
  set_config_fn(actions_model.get());

  // Disable smart reply for easier testing.
  actions_model->preconditions->min_smart_reply_triggering_score = 1.0;

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib);

  AnnotatedSpan flight_annotation;
  flight_annotation.span = {15, 19};
  flight_annotation.classification = {ClassificationResult("flight", 2.0)};
  AnnotatedSpan email_annotation;
  email_annotation.span = {0, 16};
  email_annotation.classification = {ClassificationResult("email", 1.0)};

  return actions_suggestions->SuggestActions(
      {{{/*user_id=*/ActionsSuggestions::kLocalUserId,
         "hehe@android.com",
         /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/
         {email_annotation},
         /*locales=*/"en"},
        {/*user_id=*/2,
         "yoyo@android.com",
         /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/
         {email_annotation},
         /*locales=*/"en"},
        {/*user_id=*/1,
         "test@android.com",
         /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/
         {email_annotation},
         /*locales=*/"en"},
        {/*user_id=*/1,
         "I am on flight LX38.",
         /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/
         {flight_annotation},
         /*locales=*/"en"}}});
}

// History limits of 1/1: only the last message contributes an action.
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithAnnotationsOnlyLastMessage) {
  const ActionsSuggestionsResponse response =
      TestSuggestActionsFromAnnotations(
          [](ActionsModelT* actions_model) {
            actions_model->annotation_actions_spec
                ->include_local_user_messages = false;
            actions_model->annotation_actions_spec->only_until_last_sent = true;
            actions_model->annotation_actions_spec
                ->max_history_from_any_person = 1;
            actions_model->annotation_actions_spec
                ->max_history_from_last_person = 1;
          },
          unilib_.get());
  // Fatal size check: the element access below must not run on a short vector.
  ASSERT_THAT(response.actions, SizeIs(1));
  EXPECT_EQ(response.actions[0].type, "track_flight");
}

// max_history_from_last_person = 3: two actions are produced.
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithAnnotationsOnlyLastPerson) {
  const ActionsSuggestionsResponse response =
      TestSuggestActionsFromAnnotations(
          [](ActionsModelT* actions_model) {
            actions_model->annotation_actions_spec
                ->include_local_user_messages = false;
            actions_model->annotation_actions_spec->only_until_last_sent = true;
            actions_model->annotation_actions_spec
                ->max_history_from_any_person = 1;
            actions_model->annotation_actions_spec
                ->max_history_from_last_person = 3;
          },
          unilib_.get());
  // Fatal size check: the element accesses below must not run on a short
  // vector.
  ASSERT_THAT(response.actions, SizeIs(2));
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
}

// max_history_from_any_person = 2: two actions are produced.
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithAnnotationsFromAny) {
  const ActionsSuggestionsResponse response =
      TestSuggestActionsFromAnnotations(
          [](ActionsModelT* actions_model) {
            actions_model->annotation_actions_spec
                ->include_local_user_messages = false;
            actions_model->annotation_actions_spec->only_until_last_sent = true;
            actions_model->annotation_actions_spec
                ->max_history_from_any_person = 2;
            actions_model->annotation_actions_spec
                ->max_history_from_last_person = 1;
          },
          unilib_.get());
  // Fatal size check: the element accesses below must not run on a short
  // vector.
  ASSERT_THAT(response.actions, SizeIs(2));
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
}

// max_history_from_any_person = 3: three actions are produced.
TEST_F(ActionsSuggestionsTest,
       SuggestsActionsWithAnnotationsFromAnyManyMessages) {
  const ActionsSuggestionsResponse response =
      TestSuggestActionsFromAnnotations(
          [](ActionsModelT* actions_model) {
            actions_model->annotation_actions_spec
                ->include_local_user_messages = false;
            actions_model->annotation_actions_spec->only_until_last_sent = true;
            actions_model->annotation_actions_spec
                ->max_history_from_any_person = 3;
            actions_model->annotation_actions_spec
                ->max_history_from_last_person = 1;
          },
          unilib_.get());
  // Fatal size check: the element accesses below must not run on a short
  // vector.
  ASSERT_THAT(response.actions, SizeIs(3));
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
  EXPECT_EQ(response.actions[2].type, "send_email");
}
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithAnnotationsFromAnyManyMessagesButNotLocalUser) { const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations( [](ActionsModelT* actions_model) { actions_model->annotation_actions_spec->include_local_user_messages = false; actions_model->annotation_actions_spec->only_until_last_sent = true; actions_model->annotation_actions_spec->max_history_from_any_person = 5; actions_model->annotation_actions_spec->max_history_from_last_person = 1; }, unilib_.get()); EXPECT_THAT(response.actions, SizeIs(3)); EXPECT_EQ(response.actions[0].type, "track_flight"); EXPECT_EQ(response.actions[1].type, "send_email"); EXPECT_EQ(response.actions[2].type, "send_email"); } TEST_F(ActionsSuggestionsTest, SuggestsActionsWithAnnotationsFromAnyManyMessagesAlsoFromLocalUser) { const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations( [](ActionsModelT* actions_model) { actions_model->annotation_actions_spec->include_local_user_messages = true; actions_model->annotation_actions_spec->only_until_last_sent = false; actions_model->annotation_actions_spec->max_history_from_any_person = 5; actions_model->annotation_actions_spec->max_history_from_last_person = 1; }, unilib_.get()); EXPECT_THAT(response.actions, SizeIs(4)); EXPECT_EQ(response.actions[0].type, "track_flight"); EXPECT_EQ(response.actions[1].type, "send_email"); EXPECT_EQ(response.actions[2].type, "send_email"); EXPECT_EQ(response.actions[3].type, "send_email"); } void TestSuggestActionsWithThreshold( const std::function& set_value_fn, const UniLib* unilib = nullptr, const int expected_size = 0, const std::string& preconditions_overwrite = "") { const std::string actions_model_string = ReadFile(GetModelPath() + kModelFileName); std::unique_ptr actions_model = UnPackActionsModel(actions_model_string.c_str()); set_value_fn(actions_model.get()); flatbuffers::FlatBufferBuilder builder; FinishActionsModelBuffer(builder, ActionsModel::Pack(builder, 
actions_model.get())); std::unique_ptr actions_suggestions = ActionsSuggestions::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib, preconditions_overwrite); ASSERT_TRUE(actions_suggestions); const ActionsSuggestionsResponse response = actions_suggestions->SuggestActions( {{{/*user_id=*/1, "I have the low-ground. Where are you?", /*reference_time_ms_utc=*/0, /*reference_timezone=*/"Europe/Zurich", /*annotations=*/{}, /*locales=*/"en"}}}); EXPECT_LE(response.actions.size(), expected_size); } TEST_F(ActionsSuggestionsTest, SuggestsActionsWithTriggeringScore) { TestSuggestActionsWithThreshold( [](ActionsModelT* actions_model) { actions_model->preconditions->min_smart_reply_triggering_score = 1.0; }, unilib_.get(), /*expected_size=*/1 /*no smart reply, only actions*/ ); } TEST_F(ActionsSuggestionsTest, SuggestsActionsWithMinReplyScore) { TestSuggestActionsWithThreshold( [](ActionsModelT* actions_model) { actions_model->preconditions->min_reply_score_threshold = 1.0; }, unilib_.get(), /*expected_size=*/1 /*no smart reply, only actions*/ ); } TEST_F(ActionsSuggestionsTest, SuggestsActionsWithSensitiveTopicScore) { TestSuggestActionsWithThreshold( [](ActionsModelT* actions_model) { actions_model->preconditions->max_sensitive_topic_score = 0.0; }, unilib_.get(), /*expected_size=*/4 /* no sensitive prediction in test model*/); } TEST_F(ActionsSuggestionsTest, SuggestsActionsWithMaxInputLength) { TestSuggestActionsWithThreshold( [](ActionsModelT* actions_model) { actions_model->preconditions->max_input_length = 0; }, unilib_.get()); } TEST_F(ActionsSuggestionsTest, SuggestsActionsWithMinInputLength) { TestSuggestActionsWithThreshold( [](ActionsModelT* actions_model) { actions_model->preconditions->min_input_length = 100; }, unilib_.get()); } TEST_F(ActionsSuggestionsTest, SuggestsActionsWithPreconditionsOverwrite) { TriggeringPreconditionsT preconditions_overwrite; preconditions_overwrite.max_input_length = 0; 
flatbuffers::FlatBufferBuilder builder; builder.Finish( TriggeringPreconditions::Pack(builder, &preconditions_overwrite)); TestSuggestActionsWithThreshold( // Keep model untouched. [](ActionsModelT* actions_model) {}, unilib_.get(), /*expected_size=*/0, std::string(reinterpret_cast(builder.GetBufferPointer()), builder.GetSize())); } #ifdef TC3_UNILIB_ICU TEST_F(ActionsSuggestionsTest, SuggestsActionsLowConfidence) { TestSuggestActionsWithThreshold( [](ActionsModelT* actions_model) { actions_model->preconditions->suppress_on_low_confidence_input = true; actions_model->low_confidence_rules.reset(new RulesModelT); actions_model->low_confidence_rules->regex_rule.emplace_back( new RulesModel_::RegexRuleT); actions_model->low_confidence_rules->regex_rule.back()->pattern = "low-ground"; }, unilib_.get()); } TEST_F(ActionsSuggestionsTest, SuggestsActionsLowConfidenceInputOutput) { const std::string actions_model_string = ReadFile(GetModelPath() + kModelFileName); std::unique_ptr actions_model = UnPackActionsModel(actions_model_string.c_str()); // Add custom triggering rule. 
actions_model->rules.reset(new RulesModelT()); actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT); RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get(); rule->pattern = "^(?i:hello\\s(there))$"; { std::unique_ptr rule_action( new RulesModel_::RuleActionSpecT); rule_action->action.reset(new ActionSuggestionSpecT); rule_action->action->type = "text_reply"; rule_action->action->response_text = "General Desaster!"; rule_action->action->score = 1.0f; rule_action->action->priority_score = 1.0f; rule->actions.push_back(std::move(rule_action)); } { std::unique_ptr rule_action( new RulesModel_::RuleActionSpecT); rule_action->action.reset(new ActionSuggestionSpecT); rule_action->action->type = "text_reply"; rule_action->action->response_text = "General Kenobi!"; rule_action->action->score = 1.0f; rule_action->action->priority_score = 1.0f; rule->actions.push_back(std::move(rule_action)); } // Add input-output low confidence rule. actions_model->preconditions->suppress_on_low_confidence_input = true; actions_model->low_confidence_rules.reset(new RulesModelT); actions_model->low_confidence_rules->regex_rule.emplace_back( new RulesModel_::RegexRuleT); actions_model->low_confidence_rules->regex_rule.back()->pattern = "hello"; actions_model->low_confidence_rules->regex_rule.back()->output_pattern = "(?i:desaster)"; flatbuffers::FlatBufferBuilder builder; FinishActionsModelBuffer(builder, ActionsModel::Pack(builder, actions_model.get())); std::unique_ptr actions_suggestions = ActionsSuggestions::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get()); ASSERT_TRUE(actions_suggestions); const ActionsSuggestionsResponse response = actions_suggestions->SuggestActions( {{{/*user_id=*/1, "hello there", /*reference_time_ms_utc=*/0, /*reference_timezone=*/"Europe/Zurich", /*annotations=*/{}, /*locales=*/"en"}}}); ASSERT_GE(response.actions.size(), 1); EXPECT_EQ(response.actions[0].response_text, 
"General Kenobi!"); } TEST_F(ActionsSuggestionsTest, SuggestsActionsLowConfidenceInputOutputOverwrite) { const std::string actions_model_string = ReadFile(GetModelPath() + kModelFileName); std::unique_ptr actions_model = UnPackActionsModel(actions_model_string.c_str()); actions_model->low_confidence_rules.reset(); // Add custom triggering rule. actions_model->rules.reset(new RulesModelT()); actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT); RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get(); rule->pattern = "^(?i:hello\\s(there))$"; { std::unique_ptr rule_action( new RulesModel_::RuleActionSpecT); rule_action->action.reset(new ActionSuggestionSpecT); rule_action->action->type = "text_reply"; rule_action->action->response_text = "General Desaster!"; rule_action->action->score = 1.0f; rule_action->action->priority_score = 1.0f; rule->actions.push_back(std::move(rule_action)); } { std::unique_ptr rule_action( new RulesModel_::RuleActionSpecT); rule_action->action.reset(new ActionSuggestionSpecT); rule_action->action->type = "text_reply"; rule_action->action->response_text = "General Kenobi!"; rule_action->action->score = 1.0f; rule_action->action->priority_score = 1.0f; rule->actions.push_back(std::move(rule_action)); } // Add custom triggering rule via overwrite. 
actions_model->preconditions->low_confidence_rules.reset(); TriggeringPreconditionsT preconditions; preconditions.suppress_on_low_confidence_input = true; preconditions.low_confidence_rules.reset(new RulesModelT); preconditions.low_confidence_rules->regex_rule.emplace_back( new RulesModel_::RegexRuleT); preconditions.low_confidence_rules->regex_rule.back()->pattern = "hello"; preconditions.low_confidence_rules->regex_rule.back()->output_pattern = "(?i:desaster)"; flatbuffers::FlatBufferBuilder preconditions_builder; preconditions_builder.Finish( TriggeringPreconditions::Pack(preconditions_builder, &preconditions)); std::string serialize_preconditions = std::string( reinterpret_cast(preconditions_builder.GetBufferPointer()), preconditions_builder.GetSize()); flatbuffers::FlatBufferBuilder builder; FinishActionsModelBuffer(builder, ActionsModel::Pack(builder, actions_model.get())); std::unique_ptr actions_suggestions = ActionsSuggestions::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get(), serialize_preconditions); ASSERT_TRUE(actions_suggestions); const ActionsSuggestionsResponse response = actions_suggestions->SuggestActions( {{{/*user_id=*/1, "hello there", /*reference_time_ms_utc=*/0, /*reference_timezone=*/"Europe/Zurich", /*annotations=*/{}, /*locales=*/"en"}}}); ASSERT_GE(response.actions.size(), 1); EXPECT_EQ(response.actions[0].response_text, "General Kenobi!"); } #endif TEST_F(ActionsSuggestionsTest, SuppressActionsFromAnnotationsOnSensitiveTopic) { const std::string actions_model_string = ReadFile(GetModelPath() + kModelFileName); std::unique_ptr actions_model = UnPackActionsModel(actions_model_string.c_str()); // Don't test if no sensitivity score is produced if (actions_model->tflite_model_spec->output_sensitive_topic_score < 0) { return; } actions_model->preconditions->max_sensitive_topic_score = 0.0; actions_model->preconditions->suppress_on_sensitive_topic = true; flatbuffers::FlatBufferBuilder builder; 
FinishActionsModelBuffer(builder, ActionsModel::Pack(builder, actions_model.get())); std::unique_ptr actions_suggestions = ActionsSuggestions::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get()); AnnotatedSpan annotation; annotation.span = {11, 15}; annotation.classification = { ClassificationResult(Collections::Address(), 1.0)}; const ActionsSuggestionsResponse response = actions_suggestions->SuggestActions( {{{/*user_id=*/1, "are you at home?", /*reference_time_ms_utc=*/0, /*reference_timezone=*/"Europe/Zurich", /*annotations=*/{annotation}, /*locales=*/"en"}}}); EXPECT_THAT(response.actions, testing::IsEmpty()); } TEST_F(ActionsSuggestionsTest, SuggestsActionsWithLongerConversation) { const std::string actions_model_string = ReadFile(GetModelPath() + kModelFileName); std::unique_ptr actions_model = UnPackActionsModel(actions_model_string.c_str()); // Allow a larger conversation context. actions_model->max_conversation_history_length = 10; flatbuffers::FlatBufferBuilder builder; FinishActionsModelBuffer(builder, ActionsModel::Pack(builder, actions_model.get())); std::unique_ptr actions_suggestions = ActionsSuggestions::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get()); AnnotatedSpan annotation; annotation.span = {11, 15}; annotation.classification = { ClassificationResult(Collections::Address(), 1.0)}; const ActionsSuggestionsResponse response = actions_suggestions->SuggestActions( {{{/*user_id=*/ActionsSuggestions::kLocalUserId, "hi, how are you?", /*reference_time_ms_utc=*/10000, /*reference_timezone=*/"Europe/Zurich", /*annotations=*/{}, /*locales=*/"en"}, {/*user_id=*/1, "good! 
are you at home?", /*reference_time_ms_utc=*/15000, /*reference_timezone=*/"Europe/Zurich", /*annotations=*/{annotation}, /*locales=*/"en"}}}); ASSERT_GE(response.actions.size(), 1); EXPECT_EQ(response.actions[0].type, "view_map"); EXPECT_EQ(response.actions[0].score, 1.0); } TEST_F(ActionsSuggestionsTest, SuggestsActionsFromTF2MultiTaskModel) { std::unique_ptr actions_suggestions = LoadTestModel(kMultiTaskTF2TestModelFileName); const ActionsSuggestionsResponse response = actions_suggestions->SuggestActions( {{{/*user_id=*/1, "Hello how are you", /*reference_time_ms_utc=*/0, /*reference_timezone=*/"Europe/Zurich", /*annotations=*/{}, /*locales=*/"en"}}}); EXPECT_EQ(response.actions.size(), 4); EXPECT_EQ(response.actions[0].response_text, "Okay"); EXPECT_EQ(response.actions[0].type, "REPLY_SUGGESTION"); EXPECT_EQ(response.actions[3].type, "TEST_CLASSIFIER_INTENT"); } TEST_F(ActionsSuggestionsTest, SuggestsActionsFromPhoneGrammarAnnotations) { std::unique_ptr actions_suggestions = LoadTestModel(kModelGrammarFileName); AnnotatedSpan annotation; annotation.span = {11, 15}; annotation.classification = {ClassificationResult("phone", 0.0)}; const ActionsSuggestionsResponse response = actions_suggestions->SuggestActions( {{{/*user_id=*/1, "Contact us at: *1234", /*reference_time_ms_utc=*/0, /*reference_timezone=*/"Europe/Zurich", /*annotations=*/{annotation}, /*locales=*/"en"}}}); ASSERT_GE(response.actions.size(), 1); EXPECT_EQ(response.actions.front().type, "call_phone"); EXPECT_EQ(response.actions.front().score, 0.0); EXPECT_EQ(response.actions.front().priority_score, 0.0); EXPECT_EQ(response.actions.front().annotations.size(), 1); EXPECT_EQ(response.actions.front().annotations.front().span.span.first, 15); EXPECT_EQ(response.actions.front().annotations.front().span.span.second, 20); } TEST_F(ActionsSuggestionsTest, CreateActionsFromClassificationResult) { std::unique_ptr actions_suggestions = LoadTestModel(kModelFileName); AnnotatedSpan annotation; annotation.span = 
{8, 12}; annotation.classification = { ClassificationResult(Collections::Flight(), 1.0)}; const ActionsSuggestionsResponse response = actions_suggestions->SuggestActions( {{{/*user_id=*/1, "I'm on LX38?", /*reference_time_ms_utc=*/0, /*reference_timezone=*/"Europe/Zurich", /*annotations=*/{annotation}, /*locales=*/"en"}}}); ASSERT_GE(response.actions.size(), 2); EXPECT_EQ(response.actions[0].type, "track_flight"); EXPECT_EQ(response.actions[0].score, 1.0); EXPECT_THAT(response.actions[0].annotations, SizeIs(1)); EXPECT_EQ(response.actions[0].annotations[0].span.message_index, 0); EXPECT_EQ(response.actions[0].annotations[0].span.span, annotation.span); } #ifdef TC3_UNILIB_ICU TEST_F(ActionsSuggestionsTest, CreateActionsFromRules) { const std::string actions_model_string = ReadFile(GetModelPath() + kModelFileName); std::unique_ptr actions_model = UnPackActionsModel(actions_model_string.c_str()); ASSERT_TRUE(DecompressActionsModel(actions_model.get())); actions_model->rules.reset(new RulesModelT()); actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT); RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get(); rule->pattern = "^(?i:hello\\s(there))$"; rule->actions.emplace_back(new RulesModel_::RuleActionSpecT); rule->actions.back()->action.reset(new ActionSuggestionSpecT); ActionSuggestionSpecT* action = rule->actions.back()->action.get(); action->type = "text_reply"; action->response_text = "General Kenobi!"; action->score = 1.0f; action->priority_score = 1.0f; // Set capturing groups for entity data. 
rule->actions.back()->capturing_group.emplace_back( new RulesModel_::RuleActionSpec_::RuleCapturingGroupT); RulesModel_::RuleActionSpec_::RuleCapturingGroupT* greeting_group = rule->actions.back()->capturing_group.back().get(); greeting_group->group_id = 0; greeting_group->entity_field.reset(new FlatbufferFieldPathT); greeting_group->entity_field->field.emplace_back(new FlatbufferFieldT); greeting_group->entity_field->field.back()->field_name = "greeting"; rule->actions.back()->capturing_group.emplace_back( new RulesModel_::RuleActionSpec_::RuleCapturingGroupT); RulesModel_::RuleActionSpec_::RuleCapturingGroupT* location_group = rule->actions.back()->capturing_group.back().get(); location_group->group_id = 1; location_group->entity_field.reset(new FlatbufferFieldPathT); location_group->entity_field->field.emplace_back(new FlatbufferFieldT); location_group->entity_field->field.back()->field_name = "location"; // Set test entity data schema. SetTestEntityDataSchema(actions_model.get()); // Use meta data to generate custom serialized entity data. MutableFlatbufferBuilder entity_data_builder( flatbuffers::GetRoot( actions_model->actions_entity_data_schema.data())); std::unique_ptr entity_data = entity_data_builder.NewRoot(); entity_data->Set("person", "Kenobi"); action->serialized_entity_data = entity_data->Serialize(); flatbuffers::FlatBufferBuilder builder; FinishActionsModelBuffer(builder, ActionsModel::Pack(builder, actions_model.get())); std::unique_ptr actions_suggestions = ActionsSuggestions::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get()); const ActionsSuggestionsResponse response = actions_suggestions->SuggestActions( {{{/*user_id=*/1, "hello there", /*reference_time_ms_utc=*/0, /*reference_timezone=*/"Europe/Zurich", /*annotations=*/{}, /*locales=*/"en"}}}); EXPECT_GE(response.actions.size(), 1); EXPECT_EQ(response.actions[0].response_text, "General Kenobi!"); // Check entity data. 
const flatbuffers::Table* entity = flatbuffers::GetAnyRoot(reinterpret_cast( response.actions[0].serialized_entity_data.data())); EXPECT_EQ(entity->GetPointer(/*field=*/4)->str(), "hello there"); EXPECT_EQ(entity->GetPointer(/*field=*/6)->str(), "there"); EXPECT_EQ(entity->GetPointer(/*field=*/8)->str(), "Kenobi"); } TEST_F(ActionsSuggestionsTest, CreateActionsFromRulesWithNormalization) { const std::string actions_model_string = ReadFile(GetModelPath() + kModelFileName); std::unique_ptr actions_model = UnPackActionsModel(actions_model_string.c_str()); ASSERT_TRUE(DecompressActionsModel(actions_model.get())); actions_model->rules.reset(new RulesModelT()); actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT); RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get(); rule->pattern = "^(?i:hello\\sthere)$"; rule->actions.emplace_back(new RulesModel_::RuleActionSpecT); rule->actions.back()->action.reset(new ActionSuggestionSpecT); ActionSuggestionSpecT* action = rule->actions.back()->action.get(); action->type = "text_reply"; action->response_text = "General Kenobi!"; action->score = 1.0f; action->priority_score = 1.0f; // Set capturing groups for entity data. rule->actions.back()->capturing_group.emplace_back( new RulesModel_::RuleActionSpec_::RuleCapturingGroupT); RulesModel_::RuleActionSpec_::RuleCapturingGroupT* greeting_group = rule->actions.back()->capturing_group.back().get(); greeting_group->group_id = 0; greeting_group->entity_field.reset(new FlatbufferFieldPathT); greeting_group->entity_field->field.emplace_back(new FlatbufferFieldT); greeting_group->entity_field->field.back()->field_name = "greeting"; greeting_group->normalization_options.reset(new NormalizationOptionsT); greeting_group->normalization_options->codepointwise_normalization = NormalizationOptions_::CodepointwiseNormalizationOp_DROP_WHITESPACE | NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE; // Set test entity data schema. 
SetTestEntityDataSchema(actions_model.get()); flatbuffers::FlatBufferBuilder builder; FinishActionsModelBuffer(builder, ActionsModel::Pack(builder, actions_model.get())); std::unique_ptr actions_suggestions = ActionsSuggestions::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get()); const ActionsSuggestionsResponse response = actions_suggestions->SuggestActions( {{{/*user_id=*/1, "hello there", /*reference_time_ms_utc=*/0, /*reference_timezone=*/"Europe/Zurich", /*annotations=*/{}, /*locales=*/"en"}}}); EXPECT_GE(response.actions.size(), 1); EXPECT_EQ(response.actions[0].response_text, "General Kenobi!"); // Check entity data. const flatbuffers::Table* entity = flatbuffers::GetAnyRoot(reinterpret_cast( response.actions[0].serialized_entity_data.data())); EXPECT_EQ(entity->GetPointer(/*field=*/4)->str(), "HELLOTHERE"); } TEST_F(ActionsSuggestionsTest, CreatesTextRepliesFromRules) { const std::string actions_model_string = ReadFile(GetModelPath() + kModelFileName); std::unique_ptr actions_model = UnPackActionsModel(actions_model_string.c_str()); ASSERT_TRUE(DecompressActionsModel(actions_model.get())); actions_model->rules.reset(new RulesModelT()); actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT); RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get(); rule->pattern = "(?i:reply (stop|quit|end) (?:to|for) )"; rule->actions.emplace_back(new RulesModel_::RuleActionSpecT); // Set capturing groups for entity data. 
rule->actions.back()->capturing_group.emplace_back( new RulesModel_::RuleActionSpec_::RuleCapturingGroupT); RulesModel_::RuleActionSpec_::RuleCapturingGroupT* code_group = rule->actions.back()->capturing_group.back().get(); code_group->group_id = 1; code_group->text_reply.reset(new ActionSuggestionSpecT); code_group->text_reply->score = 1.0f; code_group->text_reply->priority_score = 1.0f; code_group->normalization_options.reset(new NormalizationOptionsT); code_group->normalization_options->codepointwise_normalization = NormalizationOptions_::CodepointwiseNormalizationOp_LOWERCASE; flatbuffers::FlatBufferBuilder builder; FinishActionsModelBuffer(builder, ActionsModel::Pack(builder, actions_model.get())); std::unique_ptr actions_suggestions = ActionsSuggestions::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get()); const ActionsSuggestionsResponse response = actions_suggestions->SuggestActions( {{{/*user_id=*/1, "visit test.com or reply STOP to cancel your subscription", /*reference_time_ms_utc=*/0, /*reference_timezone=*/"Europe/Zurich", /*annotations=*/{}, /*locales=*/"en"}}}); EXPECT_GE(response.actions.size(), 1); EXPECT_EQ(response.actions[0].response_text, "stop"); } TEST_F(ActionsSuggestionsTest, CreatesActionsFromGrammarRules) { const std::string actions_model_string = ReadFile(GetModelPath() + kModelFileName); std::unique_ptr actions_model = UnPackActionsModel(actions_model_string.c_str()); ASSERT_TRUE(DecompressActionsModel(actions_model.get())); actions_model->rules->grammar_rules.reset(new RulesModel_::GrammarRulesT); // Set tokenizer options. RulesModel_::GrammarRulesT* action_grammar_rules = actions_model->rules->grammar_rules.get(); action_grammar_rules->tokenizer_options.reset(new ActionsTokenizerOptionsT); action_grammar_rules->tokenizer_options->type = TokenizationType_ICU; action_grammar_rules->tokenizer_options->icu_preserve_whitespace_tokens = false; // Setup test rules. 
action_grammar_rules->rules.reset(new grammar::RulesSetT); grammar::LocaleShardMap locale_shard_map = grammar::LocaleShardMap::CreateLocaleShardMap({""}); grammar::Rules rules(locale_shard_map); rules.Add( "", {"<^>", "ventura", "!?", "<$>"}, /*callback=*/ static_cast(grammar::DefaultCallback::kRootRule), /*callback_param=*/0); rules.Finalize().Serialize(/*include_debug_information=*/false, action_grammar_rules->rules.get()); action_grammar_rules->actions.emplace_back(new RulesModel_::RuleActionSpecT); RulesModel_::RuleActionSpecT* actions_spec = action_grammar_rules->actions.back().get(); actions_spec->action.reset(new ActionSuggestionSpecT); actions_spec->action->response_text = "Yes, Satan?"; actions_spec->action->priority_score = 1.0; actions_spec->action->score = 1.0; actions_spec->action->type = "text_reply"; action_grammar_rules->rule_match.emplace_back( new RulesModel_::GrammarRules_::RuleMatchT); action_grammar_rules->rule_match.back()->action_id.push_back(0); flatbuffers::FlatBufferBuilder builder; FinishActionsModelBuffer(builder, ActionsModel::Pack(builder, actions_model.get())); std::unique_ptr actions_suggestions = ActionsSuggestions::FromUnownedBuffer( reinterpret_cast(builder.GetBufferPointer()), builder.GetSize(), unilib_.get()); const ActionsSuggestionsResponse response = actions_suggestions->SuggestActions( {{{/*user_id=*/1, "Ventura!", /*reference_time_ms_utc=*/0, /*reference_timezone=*/"Europe/Zurich", /*annotations=*/{}, /*locales=*/"en"}}}); EXPECT_THAT(response.actions, ElementsAre(IsSmartReply("Yes, Satan?"))); } #if defined(TC3_UNILIB_ICU) && !defined(TEST_NO_DATETIME) TEST_F(ActionsSuggestionsTest, CreatesActionsWithAnnotationsFromGrammarRules) { std::unique_ptr annotator = Annotator::FromPath(GetModelPath() + "en.fb", unilib_.get()); const std::string actions_model_string = ReadFile(GetModelPath() + kModelFileName); std::unique_ptr actions_model = UnPackActionsModel(actions_model_string.c_str()); 
ASSERT_TRUE(DecompressActionsModel(actions_model.get())); actions_model->rules->grammar_rules.reset(new RulesModel_::GrammarRulesT); // Set tokenizer options. RulesModel_::GrammarRulesT* action_grammar_rules = actions_model->rules->grammar_rules.get(); action_grammar_rules->tokenizer_options.reset(new ActionsTokenizerOptionsT); action_grammar_rules->tokenizer_options->type = TokenizationType_ICU; action_grammar_rules->tokenizer_options->icu_preserve_whitespace_tokens = false; // Setup test rules. action_grammar_rules->rules.reset(new grammar::RulesSetT); grammar::LocaleShardMap locale_shard_map = grammar::LocaleShardMap::CreateLocaleShardMap({""}); grammar::Rules rules(locale_shard_map); rules.Add( "", {"it", "is", "at", "