/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "actions/lua-actions.h"

#include <string>
#include <vector>

#include "actions/test-utils.h"
#include "actions/types.h"
#include "utils/tflite-model-executor.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"

namespace libtextclassifier3 {
namespace {

using testing::ElementsAre;
using testing::SizeIs;

// Runs the Lua `snippet` against `conversation` with no model attached
// (executor, model spec, interpreter and annotations schema are all null) and
// stores the produced suggestions in `actions`.
//
// `actions_entity_data_schema` may be null when the snippet emits no entity
// data. Returns false when the suggestions object could not be created or
// when SuggestActions() reports failure, so a creation failure shows up as a
// failed expectation instead of a null dereference.
bool SuggestActionsFromSnippet(
    const std::string& snippet, const Conversation& conversation,
    const reflection::Schema* actions_entity_data_schema,
    std::vector<ActionSuggestion>* actions) {
  const auto suggestions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
      snippet, conversation,
      /*model_executor=*/nullptr,
      /*model_spec=*/nullptr,
      /*interpreter=*/nullptr, actions_entity_data_schema,
      /*annotations_entity_data_schema=*/nullptr);
  if (suggestions == nullptr) {
    return false;
  }
  return suggestions->SuggestActions(actions);
}

// A snippet returning a single action table yields exactly that action.
TEST(LuaActions, SimpleAction) {
  Conversation conversation;
  const std::string test_snippet = R"(
    return {{ type = "test_action" }}
  )";
  std::vector<ActionSuggestion> actions;
  EXPECT_TRUE(SuggestActionsFromSnippet(
      test_snippet, conversation,
      /*actions_entity_data_schema=*/nullptr, &actions));
  EXPECT_THAT(actions, ElementsAre(IsActionOfType("test_action")));
}

// The snippet reads the conversation through the `messages` table. Only the
// "hello there!" / "general kenobi!" pair is present, so only the first rule
// fires and a single reply is expected.
TEST(LuaActions, ConversationActions) {
  Conversation conversation;
  conversation.messages.push_back({/*user_id=*/0, "hello there!"});
  conversation.messages.push_back({/*user_id=*/1, "general kenobi!"});
  const std::string test_snippet = R"(
    local actions = {}
    for i, message in pairs(messages) do
      if i < #messages then
        if message.text == "hello there!" and
           messages[i+1].text == "general kenobi!" then
           table.insert(actions, {
             type = "text_reply",
             response_text = "you are a bold one!"
           })
        end
        if message.text == "i am the senate!" and
           messages[i+1].text == "not yet!" then
           table.insert(actions, {
             type = "text_reply",
             response_text = "it's treason then"
           })
        end
      end
    end
    return actions;
  )";
  std::vector<ActionSuggestion> actions;
  EXPECT_TRUE(SuggestActionsFromSnippet(
      test_snippet, conversation,
      /*actions_entity_data_schema=*/nullptr, &actions));
  EXPECT_THAT(actions, ElementsAre(IsSmartReply("you are a bold one!")));
}

// With no model attached, the snippet is expected to see an empty
// `model.actions_scores`, which makes it emit the test action.
TEST(LuaActions, SimpleModelAction) {
  Conversation conversation;
  const std::string test_snippet = R"(
    if #model.actions_scores == 0 then
      return {{ type = "test_action" }}
    end
    return {}
  )";
  std::vector<ActionSuggestion> actions;
  EXPECT_TRUE(SuggestActionsFromSnippet(
      test_snippet, conversation,
      /*actions_entity_data_schema=*/nullptr, &actions));
  EXPECT_THAT(actions, ElementsAre(IsActionOfType("test_action")));
}

// Same as SimpleModelAction, but for the `model.reply` table.
TEST(LuaActions, SimpleModelRepliesAction) {
  Conversation conversation;
  const std::string test_snippet = R"(
    if #model.reply == 0 then
      return {{ type = "test_action" }}
    end
    return {}
  )";
  std::vector<ActionSuggestion> actions;
  EXPECT_TRUE(SuggestActionsFromSnippet(
      test_snippet, conversation,
      /*actions_entity_data_schema=*/nullptr, &actions));
  EXPECT_THAT(actions, ElementsAre(IsActionOfType("test_action")));
}

// The snippet turns an "address" annotation on the last message into a reply
// that quotes the annotated text and carries the annotation along.
TEST(LuaActions, AnnotationActions) {
  AnnotatedSpan annotation;
  // Code points [11, 15) of "are you at home?" spell "home".
  annotation.span = {11, 15};
  annotation.classification = {ClassificationResult("address", 1.0)};
  Conversation conversation = {{{/*user_id=*/1, "are you at home?",
                                 /*reference_time_ms_utc=*/0,
                                 /*reference_timezone=*/"Europe/Zurich",
                                 /*annotations=*/{annotation},
                                 /*locales=*/"en"}}};
  // string.sub is 1-based and inclusive, hence the `+ 1` on the span start
  // only.
  const std::string test_snippet = R"(
    local actions = {}
    local last_message = messages[#messages]
    for i, annotation in pairs(last_message.annotation) do
      if #annotation.classification > 0 then
        if annotation.classification[1].collection == "address" then
           local text = string.sub(last_message.text,
                            annotation.span["begin"] + 1,
                            annotation.span["end"])
           table.insert(actions, {
             type = "text_reply",
             response_text = "i am at " .. text,
             annotation = {{
               name = "location",
               span = {
                 text = text
               },
               entity = annotation.classification[1]
             }},
           })
        end
      end
    end
    return actions;
  )";
  std::vector<ActionSuggestion> actions;
  EXPECT_TRUE(SuggestActionsFromSnippet(
      test_snippet, conversation,
      /*actions_entity_data_schema=*/nullptr, &actions));
  // Fatal assertions: the element accesses below are only valid when exactly
  // one action with at least one annotation was produced.
  ASSERT_THAT(actions, ElementsAre(IsSmartReply("i am at home")));
  ASSERT_FALSE(actions[0].annotations.empty());
  EXPECT_EQ("address", actions[0].annotations[0].entity.collection);
}

// The `entity` table returned by the snippet is serialized into the action's
// entity data according to the test schema.
TEST(LuaActions, EntityData) {
  const std::string test_schema = TestEntityDataSchema();
  Conversation conversation = {{{/*user_id=*/1, "hello there"}}};
  const std::string test_snippet = R"(
    return {{
      type = "test",
      entity = {
        greeting = "hello",
        location = "there",
        person = "Kenobi",
      },
    }};
  )";
  std::vector<ActionSuggestion> actions;
  EXPECT_TRUE(SuggestActionsFromSnippet(
      test_snippet, conversation,
      /*actions_entity_data_schema=*/
      flatbuffers::GetRoot<reflection::Schema>(test_schema.data()), &actions));
  // Fatal assertions: front() and the buffer parse below must not run on an
  // empty result or an empty buffer.
  ASSERT_THAT(actions, SizeIs(1));
  EXPECT_EQ("test", actions.front().type);
  ASSERT_FALSE(actions.front().serialized_entity_data.empty());

  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          actions.front().serialized_entity_data.data()));
  ASSERT_TRUE(entity != nullptr);

  // Fields are addressed by vtable offset; GetPointer returns null for a field
  // that was not written, so each pointer is checked before use.
  const flatbuffers::String* greeting =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/4);
  ASSERT_TRUE(greeting != nullptr);
  EXPECT_EQ(greeting->str(), "hello");

  const flatbuffers::String* location =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/6);
  ASSERT_TRUE(location != nullptr);
  EXPECT_EQ(location->str(), "there");

  const flatbuffers::String* person =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/8);
  ASSERT_TRUE(person != nullptr);
  EXPECT_EQ(person->str(), "Kenobi");
}

}  // namespace
}  // namespace libtextclassifier3