You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
383 lines
16 KiB
383 lines
16 KiB
/*
|
|
* Copyright (C) 2018 The Android Open Source Project
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include "actions/ranker.h"
|
|
|
|
#include <string>
|
|
|
|
#include "actions/types.h"
|
|
#include "utils/zlib/zlib.h"
|
|
#include "gmock/gmock.h"
|
|
#include "gtest/gtest.h"
|
|
|
|
namespace libtextclassifier3 {
|
|
namespace {
|
|
|
|
// Matches an ActionSuggestion whose `type` and `response_text` match the given
// values (or gMock matchers) and whose score is approximately `score`.
//
// Scores are compared with testing::FloatEq (a few ULPs of tolerance) instead
// of exact floating-point equality. Both sides are cast to float so that the
// float-only FloatEq matcher applies regardless of the exact floating-point
// type of the score field. NOTE(review): assumes ActionSuggestion::score is a
// float; the cast is a no-op in that case.
MATCHER_P3(IsAction, type, response_text, score, "") {
  return testing::Value(arg.type, type) &&
         testing::Value(arg.response_text, response_text) &&
         testing::Value(static_cast<float>(arg.score),
                        testing::FloatEq(static_cast<float>(score)));
}
|
|
|
|
// Matches an ActionSuggestion by its `type` field alone. A plain value is
// compared for equality; a gMock matcher is applied as-is. All other fields
// are ignored.
MATCHER_P(IsActionType, type, "") {
  return testing::Matches(type)(arg.type);
}
|
|
|
|
// Two identical "hello there" smart replies: with deduplication enabled, only
// the copy with score 1.0 is expected to remain.
TEST(RankingTest, DeduplicationSmartReply) {
  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};

  ActionsSuggestionsResponse response;
  response.actions.push_back(
      {/*response_text=*/"hello there", /*type=*/"text_reply", /*score=*/1.0});
  response.actions.push_back(
      {/*response_text=*/"hello there", /*type=*/"text_reply", /*score=*/0.5});

  // Serialize the ranking options into a flatbuffer the ranker can read.
  RankingOptionsT ranking_options;
  ranking_options.deduplicate_suggestions = true;
  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(RankingOptions::Pack(fbb, &ranking_options));
  const RankingOptions* packed_options =
      flatbuffers::GetRoot<RankingOptions>(fbb.GetBufferPointer());

  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      packed_options, /*decompressor=*/nullptr,
      /*smart_reply_action_type=*/"text_reply");
  ranker->RankActions(conversation, &response);

  EXPECT_THAT(response.actions,
              testing::ElementsAre(IsAction("text_reply", "hello there", 1.0)));
}
|
|
|
|
// Deduplication only merges suggestions that are otherwise identical. The two
// plain "hello there" replies collapse into the one with score 1.0, while the
// copy that carries serialized entity data is kept as a separate suggestion.
TEST(RankingTest, DeduplicationExtraData) {
  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};

  ActionsSuggestionsResponse response;
  response.actions.push_back({/*response_text=*/"hello there",
                              /*type=*/"text_reply", /*score=*/1.0,
                              /*priority_score=*/0.0});
  response.actions.push_back({/*response_text=*/"hello there",
                              /*type=*/"text_reply", /*score=*/0.5,
                              /*priority_score=*/0.0});
  response.actions.push_back(
      {/*response_text=*/"hello there", /*type=*/"text_reply", /*score=*/0.6,
       /*priority_score=*/0.0, /*annotations=*/{},
       /*serialized_entity_data=*/"test"});

  // Serialize the ranking options into a flatbuffer the ranker can read.
  RankingOptionsT ranking_options;
  ranking_options.deduplicate_suggestions = true;
  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(RankingOptions::Pack(fbb, &ranking_options));
  const RankingOptions* packed_options =
      flatbuffers::GetRoot<RankingOptions>(fbb.GetBufferPointer());

  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      packed_options, /*decompressor=*/nullptr,
      /*smart_reply_action_type=*/"text_reply");
  ranker->RankActions(conversation, &response);

  EXPECT_THAT(response.actions,
              testing::ElementsAre(
                  IsAction("text_reply", "hello there", 1.0),
                  // Survives deduplication because its entity data differs.
                  IsAction("text_reply", "hello there", 0.6)));
}
|
|
|
|
// Two "view_map" actions annotate the same address span and are expected to
// collapse into one: the one with score 1.0 and priority 2.0. The
// "call_phone" action on the distinct phone-number span is kept.
TEST(RankingTest, DeduplicationAnnotations) {
  const Conversation conversation = {
      {{/*user_id=*/1, "742 Evergreen Terrace, the number is 1-800-TESTING"}}};

  // Two annotations over the same address span with different entity scores.
  ActionSuggestionAnnotation low_confidence_address;
  low_confidence_address.span = {/*message_index=*/0, /*span=*/{0, 21},
                                 /*text=*/"742 Evergreen Terrace"};
  low_confidence_address.entity = ClassificationResult("address", 0.5);

  ActionSuggestionAnnotation high_confidence_address;
  high_confidence_address.span = {/*message_index=*/0, /*span=*/{0, 21},
                                  /*text=*/"742 Evergreen Terrace"};
  high_confidence_address.entity = ClassificationResult("address", 1.0);

  // A non-overlapping annotation over the phone number.
  ActionSuggestionAnnotation phone_number;
  phone_number.span = {/*message_index=*/0, /*span=*/{37, 50},
                       /*text=*/"1-800-TESTING"};
  phone_number.entity = ClassificationResult("phone", 0.5);

  ActionsSuggestionsResponse response;
  response.actions.push_back({/*response_text=*/"", /*type=*/"view_map",
                              /*score=*/0.5, /*priority_score=*/1.0,
                              /*annotations=*/{low_confidence_address}});
  response.actions.push_back({/*response_text=*/"", /*type=*/"view_map",
                              /*score=*/1.0, /*priority_score=*/2.0,
                              /*annotations=*/{high_confidence_address}});
  response.actions.push_back({/*response_text=*/"", /*type=*/"call_phone",
                              /*score=*/0.5, /*priority_score=*/1.0,
                              /*annotations=*/{phone_number}});

  // Serialize the ranking options into a flatbuffer the ranker can read.
  RankingOptionsT ranking_options;
  ranking_options.deduplicate_suggestions = true;
  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(RankingOptions::Pack(fbb, &ranking_options));
  const RankingOptions* packed_options =
      flatbuffers::GetRoot<RankingOptions>(fbb.GetBufferPointer());

  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      packed_options, /*decompressor=*/nullptr,
      /*smart_reply_action_type=*/"text_reply");
  ranker->RankActions(conversation, &response);

  EXPECT_THAT(response.actions,
              testing::ElementsAre(IsAction("view_map", "", 1.0),
                                   IsAction("call_phone", "", 0.5)));
}
|
|
|
|
// When two actions annotate the same span, the one with the higher priority
// score is expected to win, even though its plain score is lower (0.6 beats
// 1.0 because its priority score is 2.0 vs. 1.0).
TEST(RankingTest, DeduplicationAnnotationsByPriorityScore) {
  const Conversation conversation = {
      {{/*user_id=*/1, "742 Evergreen Terrace, the number is 1-800-TESTING"}}};

  // Two annotations over the same address span.
  ActionSuggestionAnnotation prioritized_address;
  prioritized_address.span = {/*message_index=*/0, /*span=*/{0, 21},
                              /*text=*/"742 Evergreen Terrace"};
  prioritized_address.entity = ClassificationResult("address", 0.5);

  ActionSuggestionAnnotation high_score_address;
  high_score_address.span = {/*message_index=*/0, /*span=*/{0, 21},
                             /*text=*/"742 Evergreen Terrace"};
  high_score_address.entity = ClassificationResult("address", 1.0);

  // A non-overlapping annotation over the phone number.
  ActionSuggestionAnnotation phone_number;
  phone_number.span = {/*message_index=*/0, /*span=*/{37, 50},
                       /*text=*/"1-800-TESTING"};
  phone_number.entity = ClassificationResult("phone", 0.5);

  ActionsSuggestionsResponse response;
  response.actions.push_back({/*response_text=*/"", /*type=*/"view_map",
                              /*score=*/0.6, /*priority_score=*/2.0,
                              /*annotations=*/{prioritized_address}});
  response.actions.push_back({/*response_text=*/"", /*type=*/"view_map",
                              /*score=*/1.0, /*priority_score=*/1.0,
                              /*annotations=*/{high_score_address}});
  response.actions.push_back({/*response_text=*/"", /*type=*/"call_phone",
                              /*score=*/0.5, /*priority_score=*/1.0,
                              /*annotations=*/{phone_number}});

  // Serialize the ranking options into a flatbuffer the ranker can read.
  RankingOptionsT ranking_options;
  ranking_options.deduplicate_suggestions = true;
  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(RankingOptions::Pack(fbb, &ranking_options));
  const RankingOptions* packed_options =
      flatbuffers::GetRoot<RankingOptions>(fbb.GetBufferPointer());

  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      packed_options, /*decompressor=*/nullptr,
      /*smart_reply_action_type=*/"text_reply");
  ranker->RankActions(conversation, &response);

  EXPECT_THAT(response.actions,
              testing::ElementsAre(
                  // The lower score wins because its priority score is higher.
                  IsAction("view_map", "", 0.6),
                  IsAction("call_phone", "", 0.5)));
}
|
|
|
|
// Overlapping annotations conflict: the "911" span [7, 10) lies inside the
// "A-911" span [5, 10). With deduplication enabled, only the "copy_code"
// action (priority 2.0 vs. 1.0) is expected to survive.
TEST(RankingTest, DeduplicatesConflictingActions) {
  const Conversation conversation = {{{/*user_id=*/1, "code A-911"}}};

  ActionSuggestionAnnotation phone_number;
  phone_number.span = {/*message_index=*/0, /*span=*/{7, 10},
                       /*text=*/"911"};
  phone_number.entity = ClassificationResult("phone", 1.0);

  ActionSuggestionAnnotation code;
  code.span = {/*message_index=*/0, /*span=*/{5, 10},
               /*text=*/"A-911"};
  code.entity = ClassificationResult("code", 1.0);

  ActionsSuggestionsResponse response;
  response.actions.push_back({/*response_text=*/"", /*type=*/"call_phone",
                              /*score=*/1.0, /*priority_score=*/1.0,
                              /*annotations=*/{phone_number}});
  response.actions.push_back({/*response_text=*/"", /*type=*/"copy_code",
                              /*score=*/1.0, /*priority_score=*/2.0,
                              /*annotations=*/{code}});

  // Serialize the ranking options into a flatbuffer the ranker can read.
  RankingOptionsT ranking_options;
  ranking_options.deduplicate_suggestions = true;
  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(RankingOptions::Pack(fbb, &ranking_options));
  const RankingOptions* packed_options =
      flatbuffers::GetRoot<RankingOptions>(fbb.GetBufferPointer());

  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      packed_options, /*decompressor=*/nullptr,
      /*smart_reply_action_type=*/"text_reply");
  ranker->RankActions(conversation, &response);

  EXPECT_THAT(response.actions,
              testing::ElementsAre(IsAction("copy_code", "", 1.0)));
}
|
|
|
|
// Ranks actions with a zlib-compressed Lua ranking script. The script keeps
// every action whose type is not "text_reply", so only the two non-reply
// actions are expected in the result. A decompressor is passed to the ranker
// factory so it can read the compressed script.
TEST(RankingTest, HandlesCompressedLuaScript) {
  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
  ActionsSuggestionsResponse response;
  response.actions = {
      {/*response_text=*/"hello there", /*type=*/"text_reply",
       /*score=*/1.0},
      {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
      {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};

  // Returns the ids of all actions that are not smart replies.
  const std::string test_snippet = R"(
    local result = {}
    for id, action in pairs(actions) do
      if action.type ~= "text_reply" then
        table.insert(result, id)
      end
    end
    return result
  )";

  RankingOptionsT options;
  options.compressed_lua_ranking_script.reset(new CompressedBufferT);
  std::unique_ptr<ZlibCompressor> compressor = ZlibCompressor::Instance();
  // Fail cleanly instead of dereferencing a null compressor.
  ASSERT_NE(compressor, nullptr);
  compressor->Compress(test_snippet,
                       options.compressed_lua_ranking_script.get());
  options.deduplicate_suggestions = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));

  std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
  ASSERT_NE(decompressor, nullptr);
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      decompressor.get(), /*smart_reply_action_type=*/"text_reply");
  // NOTE(review): the factory presumably returns null when the script cannot
  // be decompressed or loaded; report that as a test failure, not a crash.
  ASSERT_NE(ranker, nullptr);

  ranker->RankActions(conversation, &response);
  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsActionType("share_location"),
                                         IsActionType("add_to_collection")}));
}
|
|
|
|
// With `suppress_smart_replies_with_actions` enabled, the "How are you?" smart
// reply is expected to be dropped because the response also contains another
// action (an annotated "call_phone").
TEST(RankingTest, SuppressSmartRepliesWithAction) {
  const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
  ActionsSuggestionsResponse response;
  {
    ActionSuggestionAnnotation annotation;
    // "911" occupies characters [14, 17) of "should i call 911"; keep the span
    // consistent with the annotation text.
    annotation.span = {/*message_index=*/0, /*span=*/{14, 17},
                       /*text=*/"911"};
    annotation.entity = ClassificationResult("phone", 1.0);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"call_phone",
                                /*score=*/1.0,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
  }
  response.actions.push_back({/*response_text=*/"How are you?",
                              /*type=*/"text_reply"});

  RankingOptionsT options;
  options.suppress_smart_replies_with_actions = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Report a failed ranker construction as a test failure, not a crash.
  ASSERT_NE(ranker, nullptr);

  ranker->RankActions(conversation, &response);

  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsAction("call_phone", "", 1.0)}));
}
|
|
|
|
// With `group_by_annotations` enabled, actions sharing the same annotation
// ("call_phone" and "add_contact" on "911") are kept together. The
// unannotated smart reply comes after them, even though it outscores
// "add_contact".
TEST(RankingTest, GroupsActionsByAnnotations) {
  const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
  ActionsSuggestionsResponse response;
  {
    ActionSuggestionAnnotation annotation;
    // "911" occupies characters [14, 17) of "should i call 911"; keep the span
    // consistent with the annotation text.
    annotation.span = {/*message_index=*/0, /*span=*/{14, 17},
                       /*text=*/"911"};
    annotation.entity = ClassificationResult("phone", 1.0);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"call_phone",
                                /*score=*/1.0,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"add_contact",
                                /*score=*/0.0,
                                /*priority_score=*/0.0,
                                /*annotations=*/{annotation}});
  }
  response.actions.push_back({/*response_text=*/"How are you?",
                              /*type=*/"text_reply",
                              /*score=*/0.5});

  RankingOptionsT options;
  options.group_by_annotations = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Report a failed ranker construction as a test failure, not a crash.
  ASSERT_NE(ranker, nullptr);

  ranker->RankActions(conversation, &response);

  // The text reply should be last, even though it has a higher score than the
  // `add_contact` action.
  EXPECT_THAT(
      response.actions,
      testing::ElementsAreArray({IsAction("call_phone", "", 1.0),
                                 IsAction("add_contact", "", 0.0),
                                 IsAction("text_reply", "How are you?", 0.5)}));
}
|
|
|
|
// With `group_by_annotations` disabled, actions are expected in plain
// descending score order, regardless of which annotation they share.
TEST(RankingTest, SortsActionsByScore) {
  const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
  ActionsSuggestionsResponse response;
  {
    ActionSuggestionAnnotation annotation;
    // "911" occupies characters [14, 17) of "should i call 911"; keep the span
    // consistent with the annotation text.
    annotation.span = {/*message_index=*/0, /*span=*/{14, 17},
                       /*text=*/"911"};
    annotation.entity = ClassificationResult("phone", 1.0);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"call_phone",
                                /*score=*/1.0,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"add_contact",
                                /*score=*/0.0,
                                /*priority_score=*/0.0,
                                /*annotations=*/{annotation}});
  }
  response.actions.push_back({/*response_text=*/"How are you?",
                              /*type=*/"text_reply",
                              /*score=*/0.5});

  RankingOptionsT options;
  // Don't group by annotation.
  options.group_by_annotations = false;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Report a failed ranker construction as a test failure, not a crash.
  ASSERT_NE(ranker, nullptr);

  ranker->RankActions(conversation, &response);

  EXPECT_THAT(
      response.actions,
      testing::ElementsAreArray({IsAction("call_phone", "", 1.0),
                                 IsAction("text_reply", "How are you?", 0.5),
                                 IsAction("add_contact", "", 0.0)}));
}
|
|
|
|
} // namespace
|
|
} // namespace libtextclassifier3
|