/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "actions/grammar-actions.h"

#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include "actions/actions_model_generated.h"
#include "actions/test-utils.h"
#include "actions/types.h"
#include "utils/flatbuffers/flatbuffers.h"
#include "utils/flatbuffers/mutable.h"
#include "utils/grammar/rules_generated.h"
#include "utils/grammar/types.h"
#include "utils/grammar/utils/rules.h"
#include "utils/jvm-test-utils.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"

namespace libtextclassifier3 {
namespace {

using ::testing::ElementsAre;
using ::testing::IsEmpty;
using ::libtextclassifier3::grammar::LocaleShardMap;

// Reads the string field stored at vtable offset `field` of `table`.
// Returns an empty string when the field is absent, so a missing field shows
// up as an assertion failure rather than a null-pointer dereference.
std::string GetStringField(const flatbuffers::Table* table,
                           const flatbuffers::voffset_t field) {
  const flatbuffers::String* value =
      table->GetPointer<const flatbuffers::String*>(field);
  return value == nullptr ? std::string() : value->str();
}

// GrammarActions configured with "text_reply" as the smart reply action type.
class TestGrammarActions : public GrammarActions {
 public:
  explicit TestGrammarActions(
      const UniLib* unilib, const RulesModel_::GrammarRules* grammar_rules,
      const MutableFlatbufferBuilder* entity_data_builder = nullptr)
      : GrammarActions(unilib, grammar_rules, entity_data_builder,
                       /*smart_reply_action_type=*/"text_reply") {}
};

class GrammarActionsTest : public testing::Test {
 protected:
  // Describes one capturing group to attach to an action spec.
  struct AnnotationSpec {
    int group_id = 0;
    std::string annotation_name = "";
    bool use_annotation_match = false;
  };

  GrammarActionsTest()
      : unilib_(CreateUniLibForTesting()),
        serialized_entity_data_schema_(TestEntityDataSchema()),
        entity_data_builder_(new MutableFlatbufferBuilder(
            flatbuffers::GetRoot<reflection::Schema>(
                serialized_entity_data_schema_.data()))) {}

  // Configures ICU tokenization that does not keep whitespace tokens.
  void SetTokenizerOptions(
      RulesModel_::GrammarRulesT* action_grammar_rules) const {
    action_grammar_rules->tokenizer_options.reset(new ActionsTokenizerOptionsT);
    action_grammar_rules->tokenizer_options->type = TokenizationType_ICU;
    action_grammar_rules->tokenizer_options->icu_preserve_whitespace_tokens =
        false;
  }

  // Appends an action spec of the given `type` and `response_text` (score and
  // priority score 1.0) with one capturing group per entry of `annotations`.
  // Returns the index of the new spec in `action_grammar_rules->actions`.
  int AddActionSpec(const std::string& type, const std::string& response_text,
                    const std::vector<AnnotationSpec>& annotations,
                    RulesModel_::GrammarRulesT* action_grammar_rules) const {
    const int action_id = action_grammar_rules->actions.size();
    action_grammar_rules->actions.emplace_back(
        new RulesModel_::RuleActionSpecT);
    RulesModel_::RuleActionSpecT* actions_spec =
        action_grammar_rules->actions.back().get();
    actions_spec->action.reset(new ActionSuggestionSpecT);
    actions_spec->action->response_text = response_text;
    actions_spec->action->priority_score = 1.0;
    actions_spec->action->score = 1.0;
    actions_spec->action->type = type;

    // Create annotations for specified capturing groups.
    for (const AnnotationSpec& annotation : annotations) {
      actions_spec->capturing_group.emplace_back(
          new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
      actions_spec->capturing_group.back()->group_id = annotation.group_id;
      actions_spec->capturing_group.back()->annotation_name =
          annotation.annotation_name;
      // The annotation type is intentionally the same as its name here.
      actions_spec->capturing_group.back()->annotation_type =
          annotation.annotation_name;
      actions_spec->capturing_group.back()->use_annotation_match =
          annotation.use_annotation_match;
    }

    return action_id;
  }

  // Appends a "text_reply" action spec with a fixed `response_text`.
  // Returns the index of the new spec.
  int AddSmartReplySpec(
      const std::string& response_text,
      RulesModel_::GrammarRulesT* action_grammar_rules) const {
    return AddActionSpec("text_reply", response_text, {}, action_grammar_rules);
  }

  // Appends an action spec whose single capturing group `match_id` produces a
  // text reply. The ProducesSmartRepliesFromCapturingMatches test expects the
  // reply text to be the captured text. Returns the index of the new spec.
  int AddCapturingMatchSmartReplySpec(
      const int match_id,
      RulesModel_::GrammarRulesT* action_grammar_rules) const {
    const int action_id = action_grammar_rules->actions.size();
    action_grammar_rules->actions.emplace_back(
        new RulesModel_::RuleActionSpecT);
    RulesModel_::RuleActionSpecT* actions_spec =
        action_grammar_rules->actions.back().get();
    actions_spec->capturing_group.emplace_back(
        new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
    actions_spec->capturing_group.back()->group_id = match_id;
    actions_spec->capturing_group.back()->text_reply.reset(
        new ActionSuggestionSpecT);
    actions_spec->capturing_group.back()->text_reply->priority_score = 1.0;
    actions_spec->capturing_group.back()->text_reply->score = 1.0;
    return action_id;
  }

  // Appends a rule match that triggers the given action specs. Returns its
  // index, which the tests pass as the rule's `callback_param`.
  int AddRuleMatch(const std::vector<int>& action_ids,
                   RulesModel_::GrammarRulesT* action_grammar_rules) const {
    const int rule_match_id = action_grammar_rules->rule_match.size();
    action_grammar_rules->rule_match.emplace_back(
        new RulesModel_::GrammarRules_::RuleMatchT);
    action_grammar_rules->rule_match.back()->action_id.insert(
        action_grammar_rules->rule_match.back()->action_id.end(),
        action_ids.begin(), action_ids.end());
    return rule_match_id;
  }

  std::unique_ptr<UniLib> unilib_;
  const std::string serialized_entity_data_schema_;
  std::unique_ptr<MutableFlatbufferBuilder> entity_data_builder_;
};

TEST_F(GrammarActionsTest, ProducesSmartReplies) {
  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);

  // Create test rules.
  // Rule: ^knock knock.?$ -> "Who's there?", "Yes?"
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);
  rules.Add(
      "<knock>", {"<^>", "knock", "knock", ".?", "<$>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/
      AddRuleMatch({AddSmartReplySpec("Who's there?", &action_grammar_rules),
                    AddSmartReplySpec("Yes?", &action_grammar_rules)},
                   &action_grammar_rules));
  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules.rules.get());
  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get());

  std::vector<ActionSuggestion> result;
  EXPECT_TRUE(grammar_actions.SuggestActions(
      {/*messages=*/{{/*user_id=*/0, /*text=*/"Knock knock"}}}, &result));

  EXPECT_THAT(result,
              ElementsAre(IsSmartReply("Who's there?"), IsSmartReply("Yes?")));
}

TEST_F(GrammarActionsTest, ProducesSmartRepliesFromCapturingMatches) {
  // Create test rules.
  // Rule: ^Text <reply> to <command>
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);
  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);

  rules.Add(
      "<scripted_reply>",
      {"<^>", "text", "<captured_reply>", "to", "<command>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/
      AddRuleMatch({AddCapturingMatchSmartReplySpec(
                       /*match_id=*/0, &action_grammar_rules)},
                   &action_grammar_rules));

  // <command> ::= unsubscribe | cancel | confirm | receive
  rules.Add("<command>", {"unsubscribe"});
  rules.Add("<command>", {"cancel"});
  rules.Add("<command>", {"confirm"});
  rules.Add("<command>", {"receive"});

  // <reply> ::= help | stop | cancel | yes
  rules.Add("<reply>", {"help"});
  rules.Add("<reply>", {"stop"});
  rules.Add("<reply>", {"cancel"});
  rules.Add("<reply>", {"yes"});
  rules.AddValueMapping("<captured_reply>", {"<reply>"},
                        /*value=*/0);

  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules.rules.get());
  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get());

  {
    std::vector<ActionSuggestion> result;
    EXPECT_TRUE(grammar_actions.SuggestActions(
        {/*messages=*/{{/*user_id=*/0,
                        /*text=*/"Text YES to confirm your subscription"}}},
        &result));
    EXPECT_THAT(result, ElementsAre(IsSmartReply("YES")));
  }

  {
    std::vector<ActionSuggestion> result;
    EXPECT_TRUE(grammar_actions.SuggestActions(
        {/*messages=*/{{/*user_id=*/0,
                        /*text=*/"text Stop to cancel your order"}}},
        &result));
    EXPECT_THAT(result, ElementsAre(IsSmartReply("Stop")));
  }
}

TEST_F(GrammarActionsTest, ProducesAnnotationsForActions) {
  // Create test rules.
  // Rule: please dial <phone>
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);
  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);

  rules.Add(
      "<call_phone>", {"please", "dial", "<phone>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/
      AddRuleMatch({AddActionSpec("call_phone", /*response_text=*/"",
                                  /*annotations=*/{{0 /*value*/, "phone"}},
                                  &action_grammar_rules)},
                   &action_grammar_rules));
  // <phone> ::= +00 00 000 00 00
  rules.AddValueMapping("<phone>",
                        {"+", "<2_digits>", "<2_digits>", "<3_digits>",
                         "<2_digits>", "<2_digits>"},
                        /*value=*/0);
  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules.rules.get());
  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get());

  std::vector<ActionSuggestion> result;
  EXPECT_TRUE(grammar_actions.SuggestActions(
      {/*messages=*/{{/*user_id=*/0,
                      /*text=*/"Please dial +41 79 123 45 67"}}},
      &result));

  // Fatal check: `result.front()` below must not be evaluated on an empty
  // vector.
  ASSERT_THAT(result, ElementsAre(IsActionOfType("call_phone")));
  EXPECT_THAT(result.front().annotations,
              ElementsAre(IsActionSuggestionAnnotation(
                  "phone", "+41 79 123 45 67", CodepointSpan{12, 28})));
}

TEST_F(GrammarActionsTest, HandlesLocales) {
  // Create test rules.
  // Rule: ^knock knock.?$ -> "Who's there?"
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);
  LocaleShardMap locale_shard_map =
      LocaleShardMap::CreateLocaleShardMap({"", "fr-CH"});
  grammar::Rules rules(locale_shard_map);
  rules.Add(
      "<knock>", {"<^>", "knock", "knock", ".?", "<$>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/
      AddRuleMatch({AddSmartReplySpec("Who's there?", &action_grammar_rules)},
                   &action_grammar_rules));
  // Same trigger, but only in shard 1 (the "fr-CH" shard).
  rules.Add(
      "<toc>", {"<knock>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/
      AddRuleMatch({AddSmartReplySpec("Qui est là?", &action_grammar_rules)},
                   &action_grammar_rules),
      /*max_whitespace_gap=*/-1,
      /*case_sensitive=*/false,
      /*shard=*/1);
  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules.rules.get());
  // Set locales for rules.
  action_grammar_rules.rules->rules.back()->locale.emplace_back(
      new LanguageTagT);
  action_grammar_rules.rules->rules.back()->locale.back()->language = "fr";

  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get());

  // Check default.
  {
    std::vector<ActionSuggestion> result;
    EXPECT_TRUE(grammar_actions.SuggestActions(
        {/*messages=*/{{/*user_id=*/0, /*text=*/"knock knock",
                        /*reference_time_ms_utc=*/0,
                        /*reference_timezone=*/"UTC", /*annotations=*/{},
                        /*detected_text_language_tags=*/"en"}}},
        &result));
    EXPECT_THAT(result, ElementsAre(IsSmartReply("Who's there?")));
  }

  // Check fr.
  {
    std::vector<ActionSuggestion> result;
    EXPECT_TRUE(grammar_actions.SuggestActions(
        {/*messages=*/{{/*user_id=*/0, /*text=*/"knock knock",
                        /*reference_time_ms_utc=*/0,
                        /*reference_timezone=*/"UTC", /*annotations=*/{},
                        /*detected_text_language_tags=*/"fr-CH"}}},
        &result));
    EXPECT_THAT(result, ElementsAre(IsSmartReply("Who's there?"),
                                    IsSmartReply("Qui est là?")));
  }
}

TEST_F(GrammarActionsTest, HandlesAssertions) {
  // Create test rules.
  // Rule: <flight> -> Track flight.
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);
  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);
  rules.Add("<carrier>", {"lx"});
  rules.Add("<carrier>", {"aa"});
  rules.Add("<flight_code>", {"<2_digits>"});
  rules.Add("<flight_code>", {"<3_digits>"});
  rules.Add("<flight_code>", {"<4_digits>"});

  // Capture flight code.
  rules.AddValueMapping("<flight>", {"<carrier>", "<flight_code>"},
                        /*value=*/0);

  // Flight: carrier + flight code and check right context.
  rules.Add(
      "<track_flight>", {"<flight>", "<context_assertion>?"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/
      AddRuleMatch({AddActionSpec("track_flight", /*response_text=*/"",
                                  /*annotations=*/{{0 /*value*/, "flight"}},
                                  &action_grammar_rules)},
                   &action_grammar_rules));

  // Exclude matches like: LX 38.00 etc.
  rules.AddAssertion("<context_assertion>", {".?", "<digits>"},
                     /*negative=*/true);

  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules.rules.get());

  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get());

  std::vector<ActionSuggestion> result;
  EXPECT_TRUE(grammar_actions.SuggestActions(
      {/*messages=*/{{/*user_id=*/0, /*text=*/"LX38 aa 44 LX 38.38"}}},
      &result));

  // Fatal check: `result[0]` and `result[1]` are indexed below.
  ASSERT_THAT(result, ElementsAre(IsActionOfType("track_flight"),
                                  IsActionOfType("track_flight")));

  EXPECT_THAT(result[0].annotations,
              ElementsAre(IsActionSuggestionAnnotation("flight", "LX38",
                                                       CodepointSpan{0, 4})));
  EXPECT_THAT(result[1].annotations,
              ElementsAre(IsActionSuggestionAnnotation("flight", "aa 44",
                                                       CodepointSpan{5, 10})));
}

TEST_F(GrammarActionsTest, SetsFixedEntityData) {
  // Create test rules.
  // Rule: ^hello there$
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);
  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);

  // Create smart reply and static entity data.
  const int spec_id =
      AddSmartReplySpec("General Kenobi!", &action_grammar_rules);
  std::unique_ptr<MutableFlatbuffer> entity_data =
      entity_data_builder_->NewRoot();
  entity_data->Set("person", "Kenobi");
  action_grammar_rules.actions[spec_id]->action->serialized_entity_data =
      entity_data->Serialize();
  action_grammar_rules.actions[spec_id]->action->entity_data.reset(
      new ActionsEntityDataT);
  action_grammar_rules.actions[spec_id]->action->entity_data->text =
      "I have the high ground.";

  rules.Add(
      "<greeting>", {"<^>", "hello", "there", "<$>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/
      AddRuleMatch({spec_id}, &action_grammar_rules));
  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules.rules.get());
  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get(),
                                     entity_data_builder_.get());

  std::vector<ActionSuggestion> result;
  EXPECT_TRUE(grammar_actions.SuggestActions(
      {/*messages=*/{{/*user_id=*/0, /*text=*/"Hello there"}}}, &result));

  // Check the produced smart replies. Fatal: `result[0]` is indexed below.
  ASSERT_THAT(result, ElementsAre(IsSmartReply("General Kenobi!")));

  // Check entity data.
  ASSERT_FALSE(result[0].serialized_entity_data.empty());
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          result[0].serialized_entity_data.data()));
  EXPECT_EQ(GetStringField(entity, /*field=*/4), "I have the high ground.");
  EXPECT_EQ(GetStringField(entity, /*field=*/8), "Kenobi");
}

TEST_F(GrammarActionsTest, SetsEntityDataFromCapturingMatches) {
  // Create test rules.
  // Rule: ^hello there$
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);
  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);

  // Create smart reply and static entity data.
  const int spec_id =
      AddSmartReplySpec("General Kenobi!", &action_grammar_rules);
  std::unique_ptr<MutableFlatbuffer> entity_data =
      entity_data_builder_->NewRoot();
  entity_data->Set("person", "Kenobi");
  action_grammar_rules.actions[spec_id]->action->serialized_entity_data =
      entity_data->Serialize();

  // Specify results for capturing matches.
  const int greeting_match_id = 0;
  const int location_match_id = 1;
  {
    action_grammar_rules.actions[spec_id]->capturing_group.emplace_back(
        new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
    RulesModel_::RuleActionSpec_::RuleCapturingGroupT* group =
        action_grammar_rules.actions[spec_id]->capturing_group.back().get();
    group->group_id = greeting_match_id;
    group->entity_field.reset(new FlatbufferFieldPathT);
    group->entity_field->field.emplace_back(new FlatbufferFieldT);
    group->entity_field->field.back()->field_name = "greeting";
  }
  {
    action_grammar_rules.actions[spec_id]->capturing_group.emplace_back(
        new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
    RulesModel_::RuleActionSpec_::RuleCapturingGroupT* group =
        action_grammar_rules.actions[spec_id]->capturing_group.back().get();
    group->group_id = location_match_id;
    group->entity_field.reset(new FlatbufferFieldPathT);
    group->entity_field->field.emplace_back(new FlatbufferFieldT);
    group->entity_field->field.back()->field_name = "location";
  }

  rules.Add("<location>", {"there"});
  rules.Add("<location>", {"here"});
  rules.AddValueMapping("<captured_location>", {"<location>"},
                        /*value=*/location_match_id);
  rules.AddValueMapping("<greeting>", {"hello", "<captured_location>"},
                        /*value=*/greeting_match_id);
  rules.Add(
      "<test>", {"<^>", "<greeting>", "<$>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/
      AddRuleMatch({spec_id}, &action_grammar_rules));
  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules.rules.get());
  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get(),
                                     entity_data_builder_.get());

  std::vector<ActionSuggestion> result;
  EXPECT_TRUE(grammar_actions.SuggestActions(
      {/*messages=*/{{/*user_id=*/0, /*text=*/"Hello there"}}}, &result));

  // Check the produced smart replies. Fatal: `result[0]` is indexed below.
  ASSERT_THAT(result, ElementsAre(IsSmartReply("General Kenobi!")));

  // Check entity data.
  ASSERT_FALSE(result[0].serialized_entity_data.empty());
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          result[0].serialized_entity_data.data()));
  EXPECT_EQ(GetStringField(entity, /*field=*/4), "Hello there");
  EXPECT_EQ(GetStringField(entity, /*field=*/6), "there");
  EXPECT_EQ(GetStringField(entity, /*field=*/8), "Kenobi");
}

TEST_F(GrammarActionsTest, SetsFixedEntityDataFromCapturingGroups) {
  // Create test rules.
  // Rule: ^hello there$
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);
  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);

  // Create smart reply.
  const int spec_id =
      AddSmartReplySpec("General Kenobi!", &action_grammar_rules);
  action_grammar_rules.actions[spec_id]->capturing_group.emplace_back(
      new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
  RulesModel_::RuleActionSpec_::RuleCapturingGroupT* group =
      action_grammar_rules.actions[spec_id]->capturing_group.back().get();
  group->group_id = 0;
  group->entity_data.reset(new ActionsEntityDataT);
  group->entity_data->text = "You are a bold one.";

  rules.AddValueMapping("<greeting>", {"<^>", "hello", "there", "<$>"},
                        /*value=*/0);
  rules.Add(
      "<test>", {"<greeting>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/
      AddRuleMatch({spec_id}, &action_grammar_rules));
  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules.rules.get());
  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get(),
                                     entity_data_builder_.get());

  std::vector<ActionSuggestion> result;
  EXPECT_TRUE(grammar_actions.SuggestActions(
      {/*messages=*/{{/*user_id=*/0, /*text=*/"Hello there"}}}, &result));

  // Check the produced smart replies. Fatal: `result[0]` is indexed below.
  ASSERT_THAT(result, ElementsAre(IsSmartReply("General Kenobi!")));

  // Check entity data.
  ASSERT_FALSE(result[0].serialized_entity_data.empty());
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          result[0].serialized_entity_data.data()));
  EXPECT_EQ(GetStringField(entity, /*field=*/4), "You are a bold one.");
}

TEST_F(GrammarActionsTest, ProducesActionsWithAnnotations) {
  // Create test rules.
  // Rule: please dial <phone>
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);
  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);
  rules.Add(
      "<call_phone>", {"please", "dial", "<phone>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/
      AddRuleMatch({AddActionSpec("call_phone", /*response_text=*/"",
                                  /*annotations=*/
                                  {{0 /*value*/, "phone",
                                    /*use_annotation_match=*/true}},
                                  &action_grammar_rules)},
                   &action_grammar_rules));
  rules.AddValueMapping("<phone>", {"<phone_annotation>"},
                        /*value=*/0);

  grammar::Ir ir = rules.Finalize(
      /*predefined_nonterminals=*/{"<phone_annotation>"});
  ir.Serialize(/*include_debug_information=*/false,
               action_grammar_rules.rules.get());

  // Map "phone" annotation to "<phone_annotation>" nonterminal.
  action_grammar_rules.rules->nonterminals->annotation_nt.emplace_back(
      new grammar::RulesSet_::Nonterminals_::AnnotationNtEntryT);
  action_grammar_rules.rules->nonterminals->annotation_nt.back()->key = "phone";
  action_grammar_rules.rules->nonterminals->annotation_nt.back()->value =
      ir.GetNonterminalForName("<phone_annotation>");

  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get());

  // Sanity check that no results are produced when no annotations are
  // provided.
  {
    std::vector<ActionSuggestion> result;
    EXPECT_TRUE(grammar_actions.SuggestActions(
        {/*messages=*/{{/*user_id=*/0,
                        /*text=*/"Please dial +41 79 123 45 67"}}},
        &result));
    EXPECT_THAT(result, IsEmpty());
  }

  // With a "phone" annotation on the number, the action is produced. A fresh
  // vector keeps this check independent of the sanity check above.
  {
    std::vector<ActionSuggestion> result;
    EXPECT_TRUE(grammar_actions.SuggestActions(
        {/*messages=*/{
            {/*user_id=*/0,
             /*text=*/"Please dial +41 79 123 45 67",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"UTC",
             /*annotations=*/
             {{CodepointSpan{12, 28}, {ClassificationResult{"phone", 1.0}}}}}}},
        &result));
    // Fatal check: `result.front()` is evaluated below.
    ASSERT_THAT(result, ElementsAre(IsActionOfType("call_phone")));
    EXPECT_THAT(result.front().annotations,
                ElementsAre(IsActionSuggestionAnnotation(
                    "phone", "+41 79 123 45 67", CodepointSpan{12, 28})));
  }
}

TEST_F(GrammarActionsTest, HandlesExclusions) {
  // Create test rules.
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);

  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);
  rules.Add("<excluded>", {"be", "safe"});
  // Any two tokens, unless they form the excluded phrase "be safe".
  rules.AddWithExclusion("<tokens_but_not_excluded>", {"<token>", "<token>"},
                         /*excluded_nonterminal=*/"<excluded>");

  rules.Add(
      "<set_reminder>",
      {"do", "not", "forget", "to", "<tokens_but_not_excluded>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/
      AddRuleMatch({AddActionSpec("set_reminder", /*response_text=*/"",
                                  /*annotations=*/
                                  {}, &action_grammar_rules)},
                   &action_grammar_rules));

  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules.rules.get());
  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get(),
                                     entity_data_builder_.get());

  {
    std::vector<ActionSuggestion> result;
    EXPECT_TRUE(grammar_actions.SuggestActions(
        {/*messages=*/{
            {/*user_id=*/0, /*text=*/"do not forget to bring milk"}}},
        &result));
    EXPECT_THAT(result, ElementsAre(IsActionOfType("set_reminder")));
  }

  {
    std::vector<ActionSuggestion> result;
    EXPECT_TRUE(grammar_actions.SuggestActions(
        {/*messages=*/{{/*user_id=*/0, /*text=*/"do not forget to be there!"}}},
        &result));
    EXPECT_THAT(result, ElementsAre(IsActionOfType("set_reminder")));
  }

  {
    std::vector<ActionSuggestion> result;
    EXPECT_TRUE(grammar_actions.SuggestActions(
        {/*messages=*/{
            {/*user_id=*/0, /*text=*/"do not forget to buy safe or vault!"}}},
        &result));
    EXPECT_THAT(result, ElementsAre(IsActionOfType("set_reminder")));
  }

  {
    std::vector<ActionSuggestion> result;
    EXPECT_TRUE(grammar_actions.SuggestActions(
        {/*messages=*/{{/*user_id=*/0, /*text=*/"do not forget to be safe!"}}},
        &result));
    EXPECT_THAT(result, IsEmpty());
  }
}

}  // namespace
}  // namespace libtextclassifier3