You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
213 lines
7.7 KiB
213 lines
7.7 KiB
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
|
|
#include "actions/lua-actions.h"
|
|
|
|
#include <map>
|
|
#include <string>
|
|
|
|
#include "actions/test-utils.h"
|
|
#include "actions/types.h"
|
|
#include "utils/tflite-model-executor.h"
|
|
#include "gmock/gmock.h"
|
|
#include "gtest/gtest.h"
|
|
|
|
namespace libtextclassifier3 {
|
|
namespace {
|
|
|
|
using testing::ElementsAre;
|
|
|
|
// Checks that a Lua snippet returning a single action table is surfaced as
// exactly one suggestion whose type matches the one set in the snippet.
TEST(LuaActions, SimpleAction) {
  Conversation conversation;
  const std::string test_snippet = R"(
    return {{ type = "test_action" }}
  )";
  std::vector<ActionSuggestion> actions;

  // Hold the factory result in a local and verify it before use, so that a
  // null result shows up as a test failure instead of a null dereference.
  const auto suggestions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
      test_snippet, conversation,
      /*model_executor=*/nullptr,
      /*model_spec=*/nullptr,
      /*interpreter=*/nullptr,
      /*actions_entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr);
  ASSERT_TRUE(suggestions != nullptr);

  EXPECT_TRUE(suggestions->SuggestActions(&actions));
  EXPECT_THAT(actions, ElementsAre(IsActionOfType("test_action")));
}
|
|
|
|
// Checks that the snippet can read the conversation through the `messages`
// global: it looks at each message together with its successor and emits a
// text reply when a known pair of texts is found. Only the first pair is
// present in this conversation, so exactly one smart reply is expected.
TEST(LuaActions, ConversationActions) {
  Conversation conversation;
  conversation.messages.push_back({/*user_id=*/0, "hello there!"});
  conversation.messages.push_back({/*user_id=*/1, "general kenobi!"});
  const std::string test_snippet = R"(
    local actions = {}
    for i, message in pairs(messages) do
      if i < #messages then
        if message.text == "hello there!" and
           messages[i+1].text == "general kenobi!" then
           table.insert(actions, {
             type = "text_reply",
             response_text = "you are a bold one!"
           })
        end
        if message.text == "i am the senate!" and
           messages[i+1].text == "not yet!" then
           table.insert(actions, {
             type = "text_reply",
             response_text = "it's treason then"
           })
        end
      end
    end
    return actions;
  )";
  std::vector<ActionSuggestion> actions;

  // Hold the factory result in a local and verify it before use, so that a
  // null result shows up as a test failure instead of a null dereference.
  const auto suggestions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
      test_snippet, conversation,
      /*model_executor=*/nullptr,
      /*model_spec=*/nullptr,
      /*interpreter=*/nullptr,
      /*actions_entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr);
  ASSERT_TRUE(suggestions != nullptr);

  EXPECT_TRUE(suggestions->SuggestActions(&actions));
  EXPECT_THAT(actions, ElementsAre(IsSmartReply("you are a bold one!")));
}
|
|
|
|
// Checks that the snippet can query `model.actions_scores` when no model
// executor, model spec or interpreter is supplied: the test expects the length
// to be 0, in which case the snippet returns a single "test_action".
TEST(LuaActions, SimpleModelAction) {
  Conversation conversation;
  const std::string test_snippet = R"(
    if #model.actions_scores == 0 then
      return {{ type = "test_action" }}
    end
    return {}
  )";
  std::vector<ActionSuggestion> actions;

  // Hold the factory result in a local and verify it before use, so that a
  // null result shows up as a test failure instead of a null dereference.
  const auto suggestions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
      test_snippet, conversation,
      /*model_executor=*/nullptr,
      /*model_spec=*/nullptr,
      /*interpreter=*/nullptr,
      /*actions_entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr);
  ASSERT_TRUE(suggestions != nullptr);

  EXPECT_TRUE(suggestions->SuggestActions(&actions));
  EXPECT_THAT(actions, ElementsAre(IsActionOfType("test_action")));
}
|
|
|
|
// Checks that the snippet can query `model.reply` when no model executor,
// model spec or interpreter is supplied: the test expects the length to be 0,
// in which case the snippet returns a single "test_action".
TEST(LuaActions, SimpleModelRepliesAction) {
  Conversation conversation;
  const std::string test_snippet = R"(
    if #model.reply == 0 then
      return {{ type = "test_action" }}
    end
    return {}
  )";
  std::vector<ActionSuggestion> actions;

  // Hold the factory result in a local and verify it before use, so that a
  // null result shows up as a test failure instead of a null dereference.
  const auto suggestions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
      test_snippet, conversation,
      /*model_executor=*/nullptr,
      /*model_spec=*/nullptr,
      /*interpreter=*/nullptr,
      /*actions_entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr);
  ASSERT_TRUE(suggestions != nullptr);

  EXPECT_TRUE(suggestions->SuggestActions(&actions));
  EXPECT_THAT(actions, ElementsAre(IsActionOfType("test_action")));
}
|
|
|
|
// Checks that message annotations are visible to the snippet and that
// annotations attached to a returned action are carried into the suggestion.
// The single message carries an "address" annotation over characters [11, 15)
// of "are you at home?", i.e. the word "home".
TEST(LuaActions, AnnotationActions) {
  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {ClassificationResult("address", 1.0)};
  Conversation conversation = {{{/*user_id=*/1, "are you at home?",
                                 /*reference_time_ms_utc=*/0,
                                 /*reference_timezone=*/"Europe/Zurich",
                                 /*annotations=*/{annotation},
                                 /*locales=*/"en"}}};
  // In the snippet, `begin + 1` turns the 0-based span start into Lua's
  // 1-based string index; the exclusive 0-based end already equals the
  // inclusive 1-based end expected by string.sub.
  const std::string test_snippet = R"(
    local actions = {}
    local last_message = messages[#messages]
    for i, annotation in pairs(last_message.annotation) do
      if #annotation.classification > 0 then
        if annotation.classification[1].collection == "address" then
           local text = string.sub(last_message.text,
                            annotation.span["begin"] + 1,
                            annotation.span["end"])
           table.insert(actions, {
             type = "text_reply",
             response_text = "i am at " .. text,
             annotation = {{
               name = "location",
               span = {
                 text = text
               },
               entity = annotation.classification[1]
             }},
           })
        end
      end
    end
    return actions;
  )";
  std::vector<ActionSuggestion> actions;

  // Hold the factory result in a local and verify it before use, so that a
  // null result shows up as a test failure instead of a null dereference.
  const auto suggestions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
      test_snippet, conversation,
      /*model_executor=*/nullptr,
      /*model_spec=*/nullptr,
      /*interpreter=*/nullptr,
      /*actions_entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr);
  ASSERT_TRUE(suggestions != nullptr);

  EXPECT_TRUE(suggestions->SuggestActions(&actions));

  // Fatal assertions: the indexing below would be out of bounds if no action
  // or no annotation was produced.
  ASSERT_THAT(actions, ElementsAre(IsSmartReply("i am at home")));
  ASSERT_FALSE(actions[0].annotations.empty());
  EXPECT_EQ("address", actions[0].annotations[0].entity.collection);
}
|
|
|
|
// Checks that an `entity` table returned by the snippet is serialized into the
// suggestion's `serialized_entity_data` flatbuffer according to the actions
// entity data schema passed to the factory.
TEST(LuaActions, EntityData) {
  std::string test_schema = TestEntityDataSchema();
  Conversation conversation = {{{/*user_id=*/1, "hello there"}}};
  const std::string test_snippet = R"(
    return {{
      type = "test",
      entity = {
        greeting = "hello",
        location = "there",
        person = "Kenobi",
      },
    }};
  )";
  std::vector<ActionSuggestion> actions;

  // Hold the factory result in a local and verify it before use, so that a
  // null result shows up as a test failure instead of a null dereference.
  const auto suggestions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
      test_snippet, conversation,
      /*model_executor=*/nullptr,
      /*model_spec=*/nullptr,
      /*interpreter=*/nullptr,
      /*actions_entity_data_schema=*/
      flatbuffers::GetRoot<reflection::Schema>(test_schema.data()),
      /*annotations_entity_data_schema=*/nullptr);
  ASSERT_TRUE(suggestions != nullptr);

  EXPECT_TRUE(suggestions->SuggestActions(&actions));

  // Fatal assertions: front() on an empty vector and parsing an empty buffer
  // are both undefined behaviour.
  ASSERT_THAT(actions, testing::SizeIs(1));
  EXPECT_EQ("test", actions.front().type);
  ASSERT_FALSE(actions.front().serialized_entity_data.empty());

  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          actions.front().serialized_entity_data.data()));
  ASSERT_TRUE(entity != nullptr);

  // NOTE(review): 4, 6 and 8 are the field offsets handed to GetPointer;
  // presumably they address the greeting, location and person fields of the
  // test schema -- confirm against TestEntityDataSchema().
  // Each pointer is checked before use so a missing field fails the test
  // instead of crashing it.
  const auto* greeting =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/4);
  ASSERT_TRUE(greeting != nullptr);
  EXPECT_EQ(greeting->str(), "hello");

  const auto* location =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/6);
  ASSERT_TRUE(location != nullptr);
  EXPECT_EQ(location->str(), "there");

  const auto* person =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/8);
  ASSERT_TRUE(person != nullptr);
  EXPECT_EQ(person->str(), "Kenobi");
}
|
|
|
|
} // namespace
|
|
} // namespace libtextclassifier3
|