You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
709 lines
29 KiB
709 lines
29 KiB
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
|
|
#include "actions/grammar-actions.h"
|
|
|
|
#include <iostream>
|
|
#include <memory>
|
|
|
|
#include "actions/actions_model_generated.h"
|
|
#include "actions/test-utils.h"
|
|
#include "actions/types.h"
|
|
#include "utils/flatbuffers/flatbuffers.h"
|
|
#include "utils/flatbuffers/mutable.h"
|
|
#include "utils/grammar/rules_generated.h"
|
|
#include "utils/grammar/types.h"
|
|
#include "utils/grammar/utils/rules.h"
|
|
#include "utils/jvm-test-utils.h"
|
|
#include "gmock/gmock.h"
|
|
#include "gtest/gtest.h"
|
|
|
|
namespace libtextclassifier3 {
|
|
namespace {
|
|
|
|
using ::testing::ElementsAre;
|
|
using ::testing::IsEmpty;
|
|
|
|
using ::libtextclassifier3::grammar::LocaleShardMap;
|
|
|
|
class TestGrammarActions : public GrammarActions {
|
|
public:
|
|
explicit TestGrammarActions(
|
|
const UniLib* unilib, const RulesModel_::GrammarRules* grammar_rules,
|
|
const MutableFlatbufferBuilder* entity_data_builder = nullptr)
|
|
: GrammarActions(unilib, grammar_rules, entity_data_builder,
|
|
|
|
/*smart_reply_action_type=*/"text_reply") {}
|
|
};
|
|
|
|
class GrammarActionsTest : public testing::Test {
|
|
protected:
|
|
struct AnnotationSpec {
|
|
int group_id = 0;
|
|
std::string annotation_name = "";
|
|
bool use_annotation_match = false;
|
|
};
|
|
|
|
GrammarActionsTest()
|
|
: unilib_(CreateUniLibForTesting()),
|
|
serialized_entity_data_schema_(TestEntityDataSchema()),
|
|
entity_data_builder_(new MutableFlatbufferBuilder(
|
|
flatbuffers::GetRoot<reflection::Schema>(
|
|
serialized_entity_data_schema_.data()))) {}
|
|
|
|
void SetTokenizerOptions(
|
|
RulesModel_::GrammarRulesT* action_grammar_rules) const {
|
|
action_grammar_rules->tokenizer_options.reset(new ActionsTokenizerOptionsT);
|
|
action_grammar_rules->tokenizer_options->type = TokenizationType_ICU;
|
|
action_grammar_rules->tokenizer_options->icu_preserve_whitespace_tokens =
|
|
false;
|
|
}
|
|
|
|
int AddActionSpec(const std::string& type, const std::string& response_text,
|
|
const std::vector<AnnotationSpec>& annotations,
|
|
RulesModel_::GrammarRulesT* action_grammar_rules) const {
|
|
const int action_id = action_grammar_rules->actions.size();
|
|
action_grammar_rules->actions.emplace_back(
|
|
new RulesModel_::RuleActionSpecT);
|
|
RulesModel_::RuleActionSpecT* actions_spec =
|
|
action_grammar_rules->actions.back().get();
|
|
actions_spec->action.reset(new ActionSuggestionSpecT);
|
|
actions_spec->action->response_text = response_text;
|
|
actions_spec->action->priority_score = 1.0;
|
|
actions_spec->action->score = 1.0;
|
|
actions_spec->action->type = type;
|
|
// Create annotations for specified capturing groups.
|
|
for (const AnnotationSpec& annotation : annotations) {
|
|
actions_spec->capturing_group.emplace_back(
|
|
new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
|
|
actions_spec->capturing_group.back()->group_id = annotation.group_id;
|
|
actions_spec->capturing_group.back()->annotation_name =
|
|
annotation.annotation_name;
|
|
actions_spec->capturing_group.back()->annotation_type =
|
|
annotation.annotation_name;
|
|
actions_spec->capturing_group.back()->use_annotation_match =
|
|
annotation.use_annotation_match;
|
|
}
|
|
|
|
return action_id;
|
|
}
|
|
|
|
int AddSmartReplySpec(
|
|
const std::string& response_text,
|
|
RulesModel_::GrammarRulesT* action_grammar_rules) const {
|
|
return AddActionSpec("text_reply", response_text, {}, action_grammar_rules);
|
|
}
|
|
|
|
int AddCapturingMatchSmartReplySpec(
|
|
const int match_id,
|
|
RulesModel_::GrammarRulesT* action_grammar_rules) const {
|
|
const int action_id = action_grammar_rules->actions.size();
|
|
action_grammar_rules->actions.emplace_back(
|
|
new RulesModel_::RuleActionSpecT);
|
|
RulesModel_::RuleActionSpecT* actions_spec =
|
|
action_grammar_rules->actions.back().get();
|
|
actions_spec->capturing_group.emplace_back(
|
|
new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
|
|
actions_spec->capturing_group.back()->group_id = match_id;
|
|
actions_spec->capturing_group.back()->text_reply.reset(
|
|
new ActionSuggestionSpecT);
|
|
actions_spec->capturing_group.back()->text_reply->priority_score = 1.0;
|
|
actions_spec->capturing_group.back()->text_reply->score = 1.0;
|
|
return action_id;
|
|
}
|
|
|
|
int AddRuleMatch(const std::vector<int>& action_ids,
|
|
RulesModel_::GrammarRulesT* action_grammar_rules) const {
|
|
const int rule_match_id = action_grammar_rules->rule_match.size();
|
|
action_grammar_rules->rule_match.emplace_back(
|
|
new RulesModel_::GrammarRules_::RuleMatchT);
|
|
action_grammar_rules->rule_match.back()->action_id.insert(
|
|
action_grammar_rules->rule_match.back()->action_id.end(),
|
|
action_ids.begin(), action_ids.end());
|
|
return rule_match_id;
|
|
}
|
|
|
|
std::unique_ptr<UniLib> unilib_;
|
|
const std::string serialized_entity_data_schema_;
|
|
std::unique_ptr<MutableFlatbufferBuilder> entity_data_builder_;
|
|
};
|
|
|
|
// A root rule bound to two smart reply specs is expected to yield both
// replies, "Who's there?" followed by "Yes?", for a matching message.
TEST_F(GrammarActionsTest, ProducesSmartReplies) {
  RulesModel_::GrammarRulesT rules_model;
  SetTokenizerOptions(&rules_model);
  rules_model.rules.reset(new grammar::RulesSetT);

  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);

  // Register both replies first, then tie them to a single rule match.
  const int who_reply_id = AddSmartReplySpec("Who's there?", &rules_model);
  const int yes_reply_id = AddSmartReplySpec("Yes?", &rules_model);
  const int rule_match_id =
      AddRuleMatch({who_reply_id, yes_reply_id}, &rules_model);

  // Rule: ^knock knock.?$ -> "Who's there?", "Yes?"
  const grammar::CallbackId root_rule_callback =
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule);
  rules.Add("<knock>", {"<^>", "knock", "knock", ".?", "<$>"},
            /*callback=*/root_rule_callback,
            /*callback_param=*/rule_match_id);
  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             rules_model.rules.get());

  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> packed_model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&rules_model));
  TestGrammarActions grammar_actions(unilib_.get(), packed_model.get());

  std::vector<ActionSuggestion> suggestions;
  EXPECT_TRUE(grammar_actions.SuggestActions(
      {/*messages=*/{{/*user_id=*/0, /*text=*/"Knock knock"}}},
      &suggestions));

  EXPECT_THAT(suggestions,
              ElementsAre(IsSmartReply("Who's there?"), IsSmartReply("Yes?")));
}
|
|
|
|
// A capturing group that carries a text reply spec is expected to produce a
// smart reply whose text is the captured span as written in the message
// ("YES", "Stop").
TEST_F(GrammarActionsTest, ProducesSmartRepliesFromCapturingMatches) {
  RulesModel_::GrammarRulesT rules_model;
  SetTokenizerOptions(&rules_model);
  rules_model.rules.reset(new grammar::RulesSetT);
  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);

  // The reply is taken from capturing match 0.
  const int reply_spec_id =
      AddCapturingMatchSmartReplySpec(/*match_id=*/0, &rules_model);
  const int rule_match_id = AddRuleMatch({reply_spec_id}, &rules_model);

  // Rule: ^Text <reply> to <command>
  const grammar::CallbackId root_rule_callback =
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule);
  rules.Add("<scripted_reply>",
            {"<^>", "text", "<captured_reply>", "to", "<command>"},
            /*callback=*/root_rule_callback,
            /*callback_param=*/rule_match_id);

  // <command> ::= unsubscribe | cancel | confirm | receive
  rules.Add("<command>", {"unsubscribe"});
  rules.Add("<command>", {"cancel"});
  rules.Add("<command>", {"confirm"});
  rules.Add("<command>", {"receive"});

  // <reply> ::= help | stop | cancel | yes
  rules.Add("<reply>", {"help"});
  rules.Add("<reply>", {"stop"});
  rules.Add("<reply>", {"cancel"});
  rules.Add("<reply>", {"yes"});

  // <captured_reply> wraps <reply> as capturing match 0.
  rules.AddValueMapping("<captured_reply>", {"<reply>"},
                        /*value=*/0);

  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             rules_model.rules.get());
  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> packed_model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&rules_model));
  TestGrammarActions grammar_actions(unilib_.get(), packed_model.get());

  // Upper-case reply, lower-case command.
  {
    std::vector<ActionSuggestion> suggestions;
    EXPECT_TRUE(grammar_actions.SuggestActions(
        {/*messages=*/{{/*user_id=*/0,
                        /*text=*/"Text YES to confirm your subscription"}}},
        &suggestions));
    EXPECT_THAT(suggestions, ElementsAre(IsSmartReply("YES")));
  }

  // Lower-case leading "text", capitalized reply.
  {
    std::vector<ActionSuggestion> suggestions;
    EXPECT_TRUE(grammar_actions.SuggestActions(
        {/*messages=*/{{/*user_id=*/0,
                        /*text=*/"text Stop to cancel your order"}}},
        &suggestions));
    EXPECT_THAT(suggestions, ElementsAre(IsSmartReply("Stop")));
  }
}
|
|
|
|
// A "call_phone" action whose capturing group 0 is annotated as "phone" is
// expected to carry one annotation covering the matched phone number.
TEST_F(GrammarActionsTest, ProducesAnnotationsForActions) {
  // Create test rules.
  // Rule: please dial <phone>
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);
  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);

  rules.Add(
      "<call_phone>", {"please", "dial", "<phone>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/
      AddRuleMatch({AddActionSpec("call_phone", /*response_text=*/"",
                                  /*annotations=*/{{0 /*value*/, "phone"}},
                                  &action_grammar_rules)},
                   &action_grammar_rules));
  // phone ::= +00 00 000 00 00
  rules.AddValueMapping("<phone>",
                        {"+", "<2_digits>", "<2_digits>", "<3_digits>",
                         "<2_digits>", "<2_digits>"},
                        /*value=*/0);
  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules.rules.get());
  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get());

  std::vector<ActionSuggestion> result;
  EXPECT_TRUE(grammar_actions.SuggestActions(
      {/*messages=*/{{/*user_id=*/0, /*text=*/"Please dial +41 79 123 45 67"}}},
      &result));

  // Fatal assertion: result.front() below is undefined on an empty vector, so
  // the test must stop here if no (or an unexpected number of) actions came
  // back.
  ASSERT_THAT(result, ElementsAre(IsActionOfType("call_phone")));
  EXPECT_THAT(result.front().annotations,
              ElementsAre(IsActionSuggestionAnnotation(
                  "phone", "+41 79 123 45 67", CodepointSpan{12, 28})));
}
|
|
|
|
// Rules placed in a locale-restricted shard are expected to fire only for
// messages whose detected language matches: "en" gets the default reply only,
// "fr-CH" gets the default reply plus the French one.
TEST_F(GrammarActionsTest, HandlesLocales) {
  RulesModel_::GrammarRulesT rules_model;
  SetTokenizerOptions(&rules_model);
  rules_model.rules.reset(new grammar::RulesSetT);
  LocaleShardMap locale_shard_map =
      LocaleShardMap::CreateLocaleShardMap({"", "fr-CH"});
  grammar::Rules rules(locale_shard_map);

  const grammar::CallbackId root_rule_callback =
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule);

  // Rule: ^knock knock.?$ -> "Who's there?"
  const int default_reply_id =
      AddSmartReplySpec("Who's there?", &rules_model);
  const int default_match_id = AddRuleMatch({default_reply_id}, &rules_model);
  rules.Add("<knock>", {"<^>", "knock", "knock", ".?", "<$>"},
            /*callback=*/root_rule_callback,
            /*callback_param=*/default_match_id);

  // Same trigger, French reply, added to shard 1.
  const int french_reply_id = AddSmartReplySpec("Qui est là?", &rules_model);
  const int french_match_id = AddRuleMatch({french_reply_id}, &rules_model);
  rules.Add("<toc>", {"<knock>"},
            /*callback=*/root_rule_callback,
            /*callback_param=*/french_match_id,
            /*max_whitespace_gap=*/-1,
            /*case_sensitive=*/false,
            /*shard=*/1);
  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             rules_model.rules.get());

  // Set locales for rules: tag the last serialized shard as "fr".
  auto& last_shard_locales = rules_model.rules->rules.back()->locale;
  last_shard_locales.emplace_back(new LanguageTagT);
  last_shard_locales.back()->language = "fr";

  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> packed_model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&rules_model));
  TestGrammarActions grammar_actions(unilib_.get(), packed_model.get());

  // Check default.
  {
    std::vector<ActionSuggestion> suggestions;
    EXPECT_TRUE(grammar_actions.SuggestActions(
        {/*messages=*/{{/*user_id=*/0, /*text=*/"knock knock",
                        /*reference_time_ms_utc=*/0,
                        /*reference_timezone=*/"UTC", /*annotations=*/{},
                        /*detected_text_language_tags=*/"en"}}},
        &suggestions));

    EXPECT_THAT(suggestions, ElementsAre(IsSmartReply("Who's there?")));
  }

  // Check fr.
  {
    std::vector<ActionSuggestion> suggestions;
    EXPECT_TRUE(grammar_actions.SuggestActions(
        {/*messages=*/{{/*user_id=*/0, /*text=*/"knock knock",
                        /*reference_time_ms_utc=*/0,
                        /*reference_timezone=*/"UTC", /*annotations=*/{},
                        /*detected_text_language_tags=*/"fr-CH"}}},
        &suggestions));

    EXPECT_THAT(suggestions, ElementsAre(IsSmartReply("Who's there?"),
                                         IsSmartReply("Qui est là?")));
  }
}
|
|
|
|
// A negative context assertion after <flight> is expected to suppress matches
// that are followed by more digits: "LX38" and "aa 44" are reported, while
// "LX 38.38" is not.
TEST_F(GrammarActionsTest, HandlesAssertions) {
  // Create test rules.
  // Rule: <flight> -> Track flight.
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);
  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);
  rules.Add("<carrier>", {"lx"});
  rules.Add("<carrier>", {"aa"});
  rules.Add("<flight_code>", {"<2_digits>"});
  rules.Add("<flight_code>", {"<3_digits>"});
  rules.Add("<flight_code>", {"<4_digits>"});

  // Capture flight code.
  rules.AddValueMapping("<flight>", {"<carrier>", "<flight_code>"},
                        /*value=*/0);

  // Flight: carrier + flight code and check right context.
  rules.Add(
      "<track_flight>", {"<flight>", "<context_assertion>?"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/
      AddRuleMatch({AddActionSpec("track_flight", /*response_text=*/"",
                                  /*annotations=*/{{0 /*value*/, "flight"}},
                                  &action_grammar_rules)},
                   &action_grammar_rules));

  // Exclude matches like: LX 38.00 etc.
  rules.AddAssertion("<context_assertion>", {".?", "<digits>"},
                     /*negative=*/true);

  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules.rules.get());

  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get());

  std::vector<ActionSuggestion> result;
  EXPECT_TRUE(grammar_actions.SuggestActions(
      {/*messages=*/{{/*user_id=*/0, /*text=*/"LX38 aa 44 LX 38.38"}}},
      &result));

  // Fatal assertion: result[0] and result[1] are indexed below, which is
  // undefined unless exactly these two actions were produced.
  ASSERT_THAT(result, ElementsAre(IsActionOfType("track_flight"),
                                  IsActionOfType("track_flight")));
  EXPECT_THAT(result[0].annotations,
              ElementsAre(IsActionSuggestionAnnotation("flight", "LX38",
                                                       CodepointSpan{0, 4})));
  EXPECT_THAT(result[1].annotations,
              ElementsAre(IsActionSuggestionAnnotation("flight", "aa 44",
                                                       CodepointSpan{5, 10})));
}
|
|
|
|
// Entity data attached statically to an action spec (both the serialized
// flatbuffer and the `entity_data` text) is expected to show up in the
// serialized entity data of the produced suggestion.
TEST_F(GrammarActionsTest, SetsFixedEntityData) {
  // Create test rules.
  // Rule: ^hello there$
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);
  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);

  // Create smart reply and static entity data.
  const int spec_id =
      AddSmartReplySpec("General Kenobi!", &action_grammar_rules);
  std::unique_ptr<MutableFlatbuffer> entity_data =
      entity_data_builder_->NewRoot();
  entity_data->Set("person", "Kenobi");
  action_grammar_rules.actions[spec_id]->action->serialized_entity_data =
      entity_data->Serialize();
  action_grammar_rules.actions[spec_id]->action->entity_data.reset(
      new ActionsEntityDataT);
  action_grammar_rules.actions[spec_id]->action->entity_data->text =
      "I have the high ground.";

  rules.Add(
      "<greeting>", {"<^>", "hello", "there", "<$>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/
      AddRuleMatch({spec_id}, &action_grammar_rules));
  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules.rules.get());
  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get(),
                                     entity_data_builder_.get());

  std::vector<ActionSuggestion> result;
  EXPECT_TRUE(grammar_actions.SuggestActions(
      {/*messages=*/{{/*user_id=*/0, /*text=*/"Hello there"}}}, &result));

  // Check the produced smart replies. Fatal, because result[0] is read below
  // and indexing an empty vector is undefined.
  ASSERT_THAT(result, ElementsAre(IsSmartReply("General Kenobi!")));

  // Check entity data. An empty buffer cannot be parsed as a flatbuffer root,
  // so stop before touching it.
  ASSERT_FALSE(result[0].serialized_entity_data.empty());
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          result[0].serialized_entity_data.data()));

  // Fields are looked up by raw vtable offset in the test entity data schema.
  // NOTE(review): the offsets (4, 8) are assumed to be the text/greeting and
  // person fields, as implied by the expected values - confirm against the
  // schema. A missing field is returned as nullptr, so check before
  // dereferencing.
  const flatbuffers::String* text_field =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/4);
  const flatbuffers::String* person_field =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/8);
  ASSERT_TRUE(text_field != nullptr);
  ASSERT_TRUE(person_field != nullptr);
  EXPECT_THAT(text_field->str(), "I have the high ground.");
  EXPECT_THAT(person_field->str(), "Kenobi");
}
|
|
|
|
// Capturing groups mapped to entity fields are expected to write the matched
// text into those fields ("greeting" <- "Hello there", "location" <- "there"),
// next to the statically configured "person" value.
TEST_F(GrammarActionsTest, SetsEntityDataFromCapturingMatches) {
  // Create test rules.
  // Rule: ^hello there$
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);
  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);

  // Create smart reply and static entity data.
  const int spec_id =
      AddSmartReplySpec("General Kenobi!", &action_grammar_rules);
  std::unique_ptr<MutableFlatbuffer> entity_data =
      entity_data_builder_->NewRoot();
  entity_data->Set("person", "Kenobi");
  action_grammar_rules.actions[spec_id]->action->serialized_entity_data =
      entity_data->Serialize();

  // Specify results for capturing matches.
  const int greeting_match_id = 0;
  const int location_match_id = 1;
  {
    // Capturing match for the whole greeting -> entity field "greeting".
    action_grammar_rules.actions[spec_id]->capturing_group.emplace_back(
        new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
    RulesModel_::RuleActionSpec_::RuleCapturingGroupT* group =
        action_grammar_rules.actions[spec_id]->capturing_group.back().get();
    group->group_id = greeting_match_id;
    group->entity_field.reset(new FlatbufferFieldPathT);
    group->entity_field->field.emplace_back(new FlatbufferFieldT);
    group->entity_field->field.back()->field_name = "greeting";
  }
  {
    // Capturing match for the location word -> entity field "location".
    action_grammar_rules.actions[spec_id]->capturing_group.emplace_back(
        new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
    RulesModel_::RuleActionSpec_::RuleCapturingGroupT* group =
        action_grammar_rules.actions[spec_id]->capturing_group.back().get();
    group->group_id = location_match_id;
    group->entity_field.reset(new FlatbufferFieldPathT);
    group->entity_field->field.emplace_back(new FlatbufferFieldT);
    group->entity_field->field.back()->field_name = "location";
  }

  rules.Add("<location>", {"there"});
  rules.Add("<location>", {"here"});
  rules.AddValueMapping("<captured_location>", {"<location>"},
                        /*value=*/location_match_id);
  rules.AddValueMapping("<greeting>", {"hello", "<captured_location>"},
                        /*value=*/greeting_match_id);
  rules.Add(
      "<test>", {"<^>", "<greeting>", "<$>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/
      AddRuleMatch({spec_id}, &action_grammar_rules));
  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules.rules.get());
  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get(),
                                     entity_data_builder_.get());

  std::vector<ActionSuggestion> result;
  EXPECT_TRUE(grammar_actions.SuggestActions(
      {/*messages=*/{{/*user_id=*/0, /*text=*/"Hello there"}}}, &result));

  // Check the produced smart replies. Fatal, because result[0] is read below
  // and indexing an empty vector is undefined.
  ASSERT_THAT(result, ElementsAre(IsSmartReply("General Kenobi!")));

  // Check entity data. An empty buffer cannot be parsed as a flatbuffer root,
  // so stop before touching it.
  ASSERT_FALSE(result[0].serialized_entity_data.empty());
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          result[0].serialized_entity_data.data()));

  // Fields are looked up by raw vtable offset in the test entity data schema.
  // NOTE(review): offsets 4/6/8 are assumed to be greeting/location/person,
  // as implied by the expected values - confirm against the schema. A missing
  // field is returned as nullptr, so check before dereferencing.
  const flatbuffers::String* greeting_field =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/4);
  const flatbuffers::String* location_field =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/6);
  const flatbuffers::String* person_field =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/8);
  ASSERT_TRUE(greeting_field != nullptr);
  ASSERT_TRUE(location_field != nullptr);
  ASSERT_TRUE(person_field != nullptr);
  EXPECT_THAT(greeting_field->str(), "Hello there");
  EXPECT_THAT(location_field->str(), "there");
  EXPECT_THAT(person_field->str(), "Kenobi");
}
|
|
|
|
// Fixed entity data configured on a capturing group is expected to end up in
// the suggestion's serialized entity data once that group matches.
TEST_F(GrammarActionsTest, SetsFixedEntityDataFromCapturingGroups) {
  // Create test rules.
  // Rule: ^hello there$
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);
  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);

  // Create smart reply.
  const int spec_id =
      AddSmartReplySpec("General Kenobi!", &action_grammar_rules);

  // Attach fixed entity data to capturing group 0 of that reply.
  action_grammar_rules.actions[spec_id]->capturing_group.emplace_back(
      new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
  RulesModel_::RuleActionSpec_::RuleCapturingGroupT* group =
      action_grammar_rules.actions[spec_id]->capturing_group.back().get();
  group->group_id = 0;
  group->entity_data.reset(new ActionsEntityDataT);
  group->entity_data->text = "You are a bold one.";

  rules.AddValueMapping("<greeting>", {"<^>", "hello", "there", "<$>"},
                        /*value=*/0);
  rules.Add(
      "<test>", {"<greeting>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/
      AddRuleMatch({spec_id}, &action_grammar_rules));
  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules.rules.get());
  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get(),
                                     entity_data_builder_.get());

  std::vector<ActionSuggestion> result;
  EXPECT_TRUE(grammar_actions.SuggestActions(
      {/*messages=*/{{/*user_id=*/0, /*text=*/"Hello there"}}}, &result));

  // Check the produced smart replies. Fatal, because result[0] is read below
  // and indexing an empty vector is undefined.
  ASSERT_THAT(result, ElementsAre(IsSmartReply("General Kenobi!")));

  // Check entity data. An empty buffer cannot be parsed as a flatbuffer root,
  // so stop before touching it.
  ASSERT_FALSE(result[0].serialized_entity_data.empty());
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          result[0].serialized_entity_data.data()));

  // The field is looked up by raw vtable offset; a missing field is returned
  // as nullptr, so check before dereferencing.
  const flatbuffers::String* text_field =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/4);
  ASSERT_TRUE(text_field != nullptr);
  EXPECT_THAT(text_field->str(), "You are a bold one.");
}
|
|
|
|
// A rule whose <phone> nonterminal is backed by an externally supplied
// "phone" annotation is expected to produce nothing when the message carries
// no annotations, and a "call_phone" action when the annotation is present.
TEST_F(GrammarActionsTest, ProducesActionsWithAnnotations) {
  // Create test rules.
  // Rule: please dial <phone>
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);
  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);
  rules.Add(
      "<call_phone>", {"please", "dial", "<phone>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/
      AddRuleMatch({AddActionSpec("call_phone", /*response_text=*/"",
                                  /*annotations=*/
                                  {{0 /*value*/, "phone",
                                    /*use_annotation_match=*/true}},
                                  &action_grammar_rules)},
                   &action_grammar_rules));
  rules.AddValueMapping("<phone>", {"<phone_annotation>"},
                        /*value=*/0);

  grammar::Ir ir = rules.Finalize(
      /*predefined_nonterminals=*/{"<phone_annotation>"});
  ir.Serialize(/*include_debug_information=*/false,
               action_grammar_rules.rules.get());

  // Map "phone" annotation to "<phone_annotation>" nonterminal.
  action_grammar_rules.rules->nonterminals->annotation_nt.emplace_back(
      new grammar::RulesSet_::Nonterminals_::AnnotationNtEntryT);
  action_grammar_rules.rules->nonterminals->annotation_nt.back()->key = "phone";
  action_grammar_rules.rules->nonterminals->annotation_nt.back()->value =
      ir.GetNonterminalForName("<phone_annotation>");

  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get());

  std::vector<ActionSuggestion> result;

  // Sanity check that no results are produced when no annotations are
  // provided.
  EXPECT_TRUE(grammar_actions.SuggestActions(
      {/*messages=*/{{/*user_id=*/0, /*text=*/"Please dial +41 79 123 45 67"}}},
      &result));
  EXPECT_THAT(result, IsEmpty());

  // Same text, now with a "phone" annotation over the number.
  EXPECT_TRUE(grammar_actions.SuggestActions(
      {/*messages=*/{
          {/*user_id=*/0,
           /*text=*/"Please dial +41 79 123 45 67",
           /*reference_time_ms_utc=*/0,
           /*reference_timezone=*/"UTC",
           /*annotations=*/
           {{CodepointSpan{12, 28}, {ClassificationResult{"phone", 1.0}}}}}}},
      &result));

  // Fatal assertion: result.front() below is undefined on an empty vector.
  ASSERT_THAT(result, ElementsAre(IsActionOfType("call_phone")));
  EXPECT_THAT(result.front().annotations,
              ElementsAre(IsActionSuggestionAnnotation(
                  "phone", "+41 79 123 45 67", CodepointSpan{12, 28})));
}
|
|
|
|
// A rule built with AddWithExclusion is expected to match any two tokens
// except the excluded phrase "be safe": three reminder texts trigger
// "set_reminder", while "do not forget to be safe!" yields nothing.
TEST_F(GrammarActionsTest, HandlesExclusions) {
  RulesModel_::GrammarRulesT rules_model;
  SetTokenizerOptions(&rules_model);
  rules_model.rules.reset(new grammar::RulesSetT);

  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);

  // <tokens_but_not_excluded> ::= <token> <token>, unless it is "be safe".
  rules.Add("<excluded>", {"be", "safe"});
  rules.AddWithExclusion("<tokens_but_not_excluded>", {"<token>", "<token>"},
                         /*excluded_nonterminal=*/"<excluded>");

  // Reminder action without a response text or annotations.
  const int reminder_spec_id =
      AddActionSpec("set_reminder", /*response_text=*/"",
                    /*annotations=*/{}, &rules_model);
  const int rule_match_id = AddRuleMatch({reminder_spec_id}, &rules_model);

  const grammar::CallbackId root_rule_callback =
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule);
  rules.Add("<set_reminder>",
            {"do", "not", "forget", "to", "<tokens_but_not_excluded>"},
            /*callback=*/root_rule_callback,
            /*callback_param=*/rule_match_id);

  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             rules_model.rules.get());
  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> packed_model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&rules_model));
  TestGrammarActions grammar_actions(unilib_.get(), packed_model.get(),
                                     entity_data_builder_.get());

  // Neither token is part of the excluded phrase.
  {
    std::vector<ActionSuggestion> suggestions;
    EXPECT_TRUE(grammar_actions.SuggestActions(
        {/*messages=*/{
            {/*user_id=*/0, /*text=*/"do not forget to bring milk"}}},
        &suggestions));
    EXPECT_THAT(suggestions, ElementsAre(IsActionOfType("set_reminder")));
  }

  // Only the first token ("be") overlaps with the excluded phrase.
  {
    std::vector<ActionSuggestion> suggestions;
    EXPECT_TRUE(grammar_actions.SuggestActions(
        {/*messages=*/{{/*user_id=*/0, /*text=*/"do not forget to be there!"}}},
        &suggestions));
    EXPECT_THAT(suggestions, ElementsAre(IsActionOfType("set_reminder")));
  }

  // Only the second token ("safe") overlaps with the excluded phrase.
  {
    std::vector<ActionSuggestion> suggestions;
    EXPECT_TRUE(grammar_actions.SuggestActions(
        {/*messages=*/{
            {/*user_id=*/0, /*text=*/"do not forget to buy safe or vault!"}}},
        &suggestions));
    EXPECT_THAT(suggestions, ElementsAre(IsActionOfType("set_reminder")));
  }

  // The full excluded phrase: no action expected.
  {
    std::vector<ActionSuggestion> suggestions;
    EXPECT_TRUE(grammar_actions.SuggestActions(
        {/*messages=*/{{/*user_id=*/0, /*text=*/"do not forget to be safe!"}}},
        &suggestions));
    EXPECT_THAT(suggestions, IsEmpty());
  }
}
|
|
|
|
} // namespace
|
|
} // namespace libtextclassifier3
|