You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
663 lines
28 KiB
663 lines
28 KiB
/*
|
|
* Copyright (C) 2018 The Android Open Source Project
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include <jni.h>
|
|
|
|
#include <memory>
|
|
#include <vector>
|
|
|
|
#include "utils/flatbuffers/mutable.h"
|
|
#include "utils/intents/intent-generator.h"
|
|
#include "utils/intents/remote-action-template.h"
|
|
#include "utils/java/jni-helper.h"
|
|
#include "utils/jvm-test-utils.h"
|
|
#include "utils/resources_generated.h"
|
|
#include "utils/testing/logging_event_listener.h"
|
|
#include "utils/variant.h"
|
|
#include "gmock/gmock.h"
|
|
#include "gtest/gtest.h"
|
|
#include "flatbuffers/reflection.h"
|
|
|
|
namespace libtextclassifier3 {
|
|
namespace {
|
|
|
|
using ::testing::ElementsAre;
|
|
using ::testing::IsEmpty;
|
|
using ::testing::SizeIs;
|
|
|
|
// Builds a serialized IntentFactoryModel containing exactly one Lua template
// generator. The generator is registered for `entity_type` and its template
// source is `generator_code`. The returned buffer owns the flatbuffer bytes.
flatbuffers::DetachedBuffer BuildTestIntentFactoryModel(
    const std::string& entity_type, const std::string& generator_code) {
  IntentFactoryModelT model;
  model.generator.emplace_back(new IntentFactoryModel_::IntentGeneratorT());
  IntentFactoryModel_::IntentGeneratorT& generator = *model.generator.back();
  generator.type = entity_type;
  // Store the Lua source as raw bytes.
  generator.lua_template_generator.assign(generator_code.begin(),
                                          generator_code.end());

  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(IntentFactoryModel::Pack(builder, &model));
  return builder.Release();
}
|
|
|
|
// Builds the serialized string resources used by the tests.
// Locale index 0 is "en" and locale index 1 is "de".
// Entries: `add_calendar_event` and `add_calendar_event_desc` (en only);
// `map` and `map_desc` (en and de).
flatbuffers::DetachedBuffer BuildTestResources() {
  constexpr int kEnglish = 0;
  constexpr int kGerman = 1;

  ResourcePoolT pool;
  for (const char* language : {"en", "de"}) {
    pool.locale.emplace_back(new LanguageTagT);
    pool.locale.back()->language = language;
  }

  // Appends an empty resource entry called `name` and returns it.
  auto add_entry = [&pool](const char* name) -> ResourceEntryT& {
    pool.resource_entry.emplace_back(new ResourceEntryT);
    ResourceEntryT& entry = *pool.resource_entry.back();
    entry.name = name;
    return entry;
  };
  // Adds `content` to `entry` as its string for locale `locale_index`.
  auto add_string = [](ResourceEntryT& entry, const char* content,
                       int locale_index) {
    entry.resource.emplace_back(new ResourceT);
    ResourceT& resource = *entry.resource.back();
    resource.content = content;
    resource.locale.push_back(locale_index);
  };

  add_string(add_entry("add_calendar_event"), "Schedule", kEnglish);
  add_string(add_entry("add_calendar_event_desc"),
             "Schedule event for selected time", kEnglish);

  ResourceEntryT& map = add_entry("map");
  add_string(map, "Map", kEnglish);
  add_string(map, "Karte", kGerman);

  ResourceEntryT& map_desc = add_entry("map_desc");
  add_string(map_desc, "Locate selected address", kEnglish);
  add_string(map_desc, "Ausgewählte Adresse finden", kGerman);

  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(ResourcePool::Pack(builder, &pool));
  return builder.Release();
}
|
|
|
|
// Common methods for intent generator tests.
|
|
class IntentGeneratorTest : public testing::Test {
|
|
protected:
|
|
explicit IntentGeneratorTest()
|
|
: jni_cache_(JniCache::Create(GetJenv())),
|
|
resource_buffer_(BuildTestResources()),
|
|
resources_(
|
|
flatbuffers::GetRoot<ResourcePool>(resource_buffer_.data())) {}
|
|
|
|
const std::shared_ptr<JniCache> jni_cache_;
|
|
const flatbuffers::DetachedBuffer resource_buffer_;
|
|
const ResourcePool* resources_;
|
|
};
|
|
|
|
// A default-constructed classification is not expected to match the only
// generator (registered for "unused"). Generation should still succeed and
// produce no intents.
TEST_F(IntentGeneratorTest, HandlesDefaultClassification) {
  flatbuffers::DetachedBuffer intent_factory_model =
      BuildTestIntentFactoryModel("unused", "");
  std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
      /*options=*/flatbuffers::GetRoot<IntentFactoryModel>(
          intent_factory_model.data()),
      /*resources=*/resources_,
      /*jni_cache=*/jni_cache_);
  // Fatal: dereferencing a null generator below would crash the test binary.
  ASSERT_NE(generator, nullptr);
  ClassificationResult classification;
  std::vector<RemoteActionTemplate> intents;
  EXPECT_TRUE(generator->GenerateIntents(
      /*device_locales=*/nullptr, classification, /*reference_time_ms_utc=*/0,
      /*text=*/"", /*selection_indices=*/{kInvalidIndex, kInvalidIndex},
      /*context=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr, &intents));
  EXPECT_THAT(intents, IsEmpty());
}
|
|
|
|
// The template reads `external.android.package_name`, but no Android context is
// passed (`context` is nullptr). Generation is therefore expected to report
// failure and leave `intents` empty.
TEST_F(IntentGeneratorTest, FailsGracefully) {
  flatbuffers::DetachedBuffer intent_factory_model =
      BuildTestIntentFactoryModel("test", R"lua(
return {
  {
    -- Should fail, as no app GetAndroidContext() is provided.
    data = external.android.package_name,
  }
})lua");
  std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
      flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()),
      /*resources=*/resources_, jni_cache_);
  // Fatal: dereferencing a null generator below would crash the test binary.
  ASSERT_NE(generator, nullptr);
  ClassificationResult classification = {"test", 1.0};
  std::vector<RemoteActionTemplate> intents;
  EXPECT_FALSE(generator->GenerateIntents(
      JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
      classification,
      /*reference_time_ms_utc=*/0, "test", {0, 4}, /*context=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr, &intents));
  EXPECT_THAT(intents, IsEmpty());
}
|
|
|
|
// An "address" classification runs the matching template. The template
// resolves the English `map`/`map_desc` resources and URL-encodes the entity
// text into a geo: URI.
TEST_F(IntentGeneratorTest, HandlesEntityIntentGeneration) {
  flatbuffers::DetachedBuffer intent_factory_model =
      BuildTestIntentFactoryModel("address", R"lua(
return {
  {
    title_without_entity = external.android.R.map,
    title_with_entity = external.entity.text,
    description = external.android.R.map_desc,
    action = "android.intent.action.VIEW",
    data = "geo:0,0?q=" ..
    external.android.urlencode(external.entity.text),
  }
})lua");
  std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
      flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()),
      /*resources=*/resources_, jni_cache_);
  // Fatal: dereferencing a null generator below would crash the test binary.
  ASSERT_NE(generator, nullptr);
  ClassificationResult classification = {"address", 1.0};
  std::vector<RemoteActionTemplate> intents;
  EXPECT_TRUE(generator->GenerateIntents(
      JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
      classification,
      /*reference_time_ms_utc=*/0, "333 E Wonderview Ave", {0, 20},
      GetAndroidContext(),
      /*annotations_entity_data_schema=*/nullptr, &intents));
  // Fatal: the checks below index intents[0], which is UB on an empty vector.
  ASSERT_THAT(intents, SizeIs(1));
  EXPECT_EQ(intents[0].title_without_entity.value(), "Map");
  EXPECT_EQ(intents[0].title_with_entity.value(), "333 E Wonderview Ave");
  EXPECT_EQ(intents[0].description.value(), "Locate selected address");
  EXPECT_EQ(intents[0].action.value(), "android.intent.action.VIEW");
  EXPECT_EQ(intents[0].data.value(), "geo:0,0?q=333%20E%20Wonderview%20Ave");
}
|
|
|
|
// Exercises the helper callbacks that a Lua template can call and checks the
// values they produce:
//   - URL encoding, scheme and host extraction
//   - package name and user restrictions
//   - device locales
//   - string formatting and hashing
TEST_F(IntentGeneratorTest, HandlesCallbacks) {
  flatbuffers::DetachedBuffer intent_factory_model =
      BuildTestIntentFactoryModel("test", R"lua(
local test = external.entity["text"]
return {
  {
    data = "encoded=" .. external.android.urlencode(test),
    category = { "test_category" },
    extra = {
      { name = "package", string_value = external.android.package_name},
      { name = "scheme",
        string_value = external.android.url_schema("https://google.com")},
      { name = "host",
        string_value = external.android.url_host("https://google.com/search")},
      { name = "permission",
        bool_value = external.android.user_restrictions["no_sms"] },
      { name = "language",
        string_value = external.android.device_locales[1].language },
      { name = "description",
        string_value = external.format("$1 $0", "hello", "world") },
    },
    request_code = external.hash(test)
  }
})lua");
  std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
      flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()),
      /*resources=*/resources_, jni_cache_);
  // Fatal: dereferencing a null generator below would crash the test binary.
  ASSERT_NE(generator, nullptr);
  ClassificationResult classification = {"test", 1.0};
  std::vector<RemoteActionTemplate> intents;
  EXPECT_TRUE(generator->GenerateIntents(
      JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
      classification,
      /*reference_time_ms_utc=*/0, "this is a test", {0, 14},
      GetAndroidContext(),
      /*annotations_entity_data_schema=*/nullptr, &intents));
  // Fatal: the checks below index intents[0], which is UB on an empty vector.
  ASSERT_THAT(intents, SizeIs(1));
  EXPECT_EQ(intents[0].data.value(), "encoded=this%20is%20a%20test");
  EXPECT_THAT(intents[0].category, ElementsAre("test_category"));
  EXPECT_THAT(intents[0].extra, SizeIs(6));
  EXPECT_EQ(intents[0].extra["package"].ConstRefValue<std::string>(),
            "com.google.android.textclassifier.tests");
  EXPECT_EQ(intents[0].extra["scheme"].ConstRefValue<std::string>(), "https");
  EXPECT_EQ(intents[0].extra["host"].ConstRefValue<std::string>(),
            "google.com");
  EXPECT_FALSE(intents[0].extra["permission"].Value<bool>());
  EXPECT_EQ(intents[0].extra["language"].ConstRefValue<std::string>(), "en");
  EXPECT_TRUE(intents[0].request_code.has_value());
  EXPECT_EQ(intents[0].extra["description"].ConstRefValue<std::string>(),
            "world hello");
}
|
|
|
|
// An action suggestion with a "location" annotation. The template reads the
// annotation by name via `external.entity.annotation["location"]` and builds a
// geo: intent from it.
TEST_F(IntentGeneratorTest, HandlesActionIntentGeneration) {
  flatbuffers::DetachedBuffer intent_factory_model =
      BuildTestIntentFactoryModel("view_map", R"lua(
return {
  {
    title_without_entity = external.android.R.map,
    description = external.android.R.map_desc,
    description_with_app_name = external.android.R.map,
    action = "android.intent.action.VIEW",
    data = "geo:0,0?q=" ..
    external.android.urlencode(external.entity.annotation["location"].text),
  }
})lua");
  std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
      flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()),
      /*resources=*/resources_, jni_cache_);
  // Fatal: dereferencing a null generator below would crash the test binary.
  ASSERT_NE(generator, nullptr);
  Conversation conversation = {{{/*user_id=*/1, "hello there"}}};
  ActionSuggestionAnnotation annotation;
  annotation.entity = {"address", 1.0};
  annotation.span = {/*message_index=*/0,
                     /*span=*/{6, 11},
                     /*text=*/"there"};
  annotation.name = "location";
  ActionSuggestion suggestion = {/*response_text=""*/ "",
                                 /*type=*/"view_map",
                                 /*score=*/1.0,
                                 /*priority_score=*/0.0,
                                 /*annotations=*/
                                 {annotation}};
  std::vector<RemoteActionTemplate> intents;
  EXPECT_TRUE(generator->GenerateIntents(
      JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
      suggestion, conversation, GetAndroidContext(),
      /*annotations_entity_data_schema=*/nullptr,
      /*actions_entity_data_schema=*/nullptr, &intents));
  // Fatal: the checks below index intents[0], which is UB on an empty vector.
  ASSERT_THAT(intents, SizeIs(1));
  EXPECT_EQ(intents[0].title_without_entity.value(), "Map");
  EXPECT_EQ(intents[0].description.value(), "Locate selected address");
  EXPECT_EQ(intents[0].description_with_app_name.value(), "Map");
  EXPECT_EQ(intents[0].action.value(), "android.intent.action.VIEW");
  EXPECT_EQ(intents[0].data.value(), "geo:0,0?q=there");
}
|
|
|
|
// The template reads `external.conversation` and reports three values: the
// timezone and time of its last message, and the number of messages.
TEST_F(IntentGeneratorTest, HandlesTimezoneAndReferenceTime) {
  flatbuffers::DetachedBuffer intent_factory_model =
      BuildTestIntentFactoryModel("test", R"lua(
local conversation = external.conversation
return {
  {
    extra = {
      { name = "timezone", string_value = conversation[#conversation].timezone },
      { name = "num_messages", int_value = #conversation },
      { name = "reference_time", long_value = conversation[#conversation].time_ms_utc }
    },
  }
})lua");
  std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
      flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()),
      /*resources=*/resources_, jni_cache_);
  // Fatal: dereferencing a null generator below would crash the test binary.
  ASSERT_NE(generator, nullptr);
  Conversation conversation = {
      {{/*user_id=*/0, "hello there", /*reference_time_ms_utc=*/0,
        /*reference_timezone=*/"Testing/Test"},
       {/*user_id=*/1, "general retesti", /*reference_time_ms_utc=*/1000,
        /*reference_timezone=*/"Europe/Zurich"}}};
  ActionSuggestion suggestion = {/*response_text=""*/ "",
                                 /*type=*/"test",
                                 /*score=*/1.0,
                                 /*priority_score=*/0.0,
                                 /*annotations=*/
                                 {}};
  std::vector<RemoteActionTemplate> intents;
  EXPECT_TRUE(generator->GenerateIntents(
      JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
      suggestion, conversation, GetAndroidContext(),
      /*annotations_entity_data_schema=*/nullptr,
      /*actions_entity_data_schema=*/nullptr, &intents));
  // Fatal: the checks below index intents[0], which is UB on an empty vector.
  ASSERT_THAT(intents, SizeIs(1));
  EXPECT_EQ(intents[0].extra["timezone"].ConstRefValue<std::string>(),
            "Europe/Zurich");
  EXPECT_EQ(intents[0].extra["num_messages"].Value<int>(), 2);
  EXPECT_EQ(intents[0].extra["reference_time"].Value<int64>(), 1000);
}
|
|
|
|
// Two named annotations ("location" and "time") are each looked up by name in
// the template. Their texts appear as separate extras.
TEST_F(IntentGeneratorTest, HandlesActionIntentGenerationMultipleAnnotations) {
  flatbuffers::DetachedBuffer intent_factory_model =
      BuildTestIntentFactoryModel("create_event", R"lua(
return {
  {
    title_without_entity = external.android.R.add_calendar_event,
    description = external.android.R.add_calendar_event_desc,
    extra = {
      {name = "time", string_value =
        external.entity.annotation["time"].text},
      {name = "location",
        string_value = external.entity.annotation["location"].text},
    }
  }
})lua");
  std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
      flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()),
      /*resources=*/resources_, jni_cache_);
  // Fatal: dereferencing a null generator below would crash the test binary.
  ASSERT_NE(generator, nullptr);
  Conversation conversation = {{{/*user_id=*/1, "hello there at 1pm"}}};
  ActionSuggestionAnnotation location_annotation, time_annotation;
  location_annotation.entity = {"address", 1.0};
  location_annotation.span = {/*message_index=*/0,
                              /*span=*/{6, 11},
                              /*text=*/"there"};
  location_annotation.name = "location";
  time_annotation.entity = {"datetime", 1.0};
  time_annotation.span = {/*message_index=*/0,
                          /*span=*/{15, 18},
                          /*text=*/"1pm"};
  time_annotation.name = "time";
  ActionSuggestion suggestion = {/*response_text=""*/ "",
                                 /*type=*/"create_event",
                                 /*score=*/1.0,
                                 /*priority_score=*/0.0,
                                 /*annotations=*/
                                 {location_annotation, time_annotation}};
  std::vector<RemoteActionTemplate> intents;
  EXPECT_TRUE(generator->GenerateIntents(
      JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
      suggestion, conversation, GetAndroidContext(),
      /*annotations_entity_data_schema=*/nullptr,
      /*actions_entity_data_schema=*/nullptr, &intents));
  // Fatal: the checks below index intents[0], which is UB on an empty vector.
  ASSERT_THAT(intents, SizeIs(1));
  EXPECT_EQ(intents[0].title_without_entity.value(), "Schedule");
  EXPECT_THAT(intents[0].extra, SizeIs(2));
  EXPECT_EQ(intents[0].extra["location"].ConstRefValue<std::string>(), "there");
  EXPECT_EQ(intents[0].extra["time"].ConstRefValue<std::string>(), "1pm");
}
|
|
|
|
// Unnamed annotations are addressed by position through
// `external.entity.annotation[i]`. Lua indices are 1-based, so [1] is the first
// annotation in the suggestion.
TEST_F(IntentGeneratorTest,
       HandlesActionIntentGenerationMultipleAnnotationsWithIndices) {
  flatbuffers::DetachedBuffer intent_factory_model =
      BuildTestIntentFactoryModel("time_range", R"lua(
return {
  {
    title_without_entity = "test",
    description = "test",
    extra = {
      {name = "from", string_value = external.entity.annotation[1].text},
      {name = "to", string_value = external.entity.annotation[2].text},
    }
  }
})lua");
  std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
      flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()),
      /*resources=*/resources_, jni_cache_);
  // Fatal: dereferencing a null generator below would crash the test binary.
  ASSERT_NE(generator, nullptr);
  Conversation conversation = {{{/*user_id=*/1, "from 1pm to 2pm"}}};
  ActionSuggestionAnnotation from_annotation, to_annotation;
  from_annotation.entity = {"datetime", 1.0};
  from_annotation.span = {/*message_index=*/0,
                          /*span=*/{5, 8},
                          /*text=*/"1pm"};
  to_annotation.entity = {"datetime", 1.0};
  to_annotation.span = {/*message_index=*/0,
                        /*span=*/{12, 15},
                        /*text=*/"2pm"};
  ActionSuggestion suggestion = {/*response_text=""*/ "",
                                 /*type=*/"time_range",
                                 /*score=*/1.0,
                                 /*priority_score=*/0.0,
                                 /*annotations=*/
                                 {from_annotation, to_annotation}};
  std::vector<RemoteActionTemplate> intents;
  EXPECT_TRUE(generator->GenerateIntents(
      JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
      suggestion, conversation, GetAndroidContext(),
      /*annotations_entity_data_schema=*/nullptr,
      /*actions_entity_data_schema=*/nullptr, &intents));
  // Fatal: the checks below index intents[0], which is UB on an empty vector.
  ASSERT_THAT(intents, SizeIs(1));
  EXPECT_THAT(intents[0].extra, SizeIs(2));
  EXPECT_EQ(intents[0].extra["from"].ConstRefValue<std::string>(), "1pm");
  EXPECT_EQ(intents[0].extra["to"].ConstRefValue<std::string>(), "2pm");
}
|
|
|
|
// With a "de-DE" device locale, the `map`/`map_desc` resources resolve to their
// German translations.
TEST_F(IntentGeneratorTest, HandlesResources) {
  flatbuffers::DetachedBuffer intent_factory_model =
      BuildTestIntentFactoryModel("address", R"lua(
return {
  {
    title_without_entity = external.android.R.map,
    description = external.android.R.map_desc,
    action = "android.intent.action.VIEW",
    data = "geo:0,0?q=" ..
    external.android.urlencode(external.entity.text),
  }
})lua");
  std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
      flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()),
      resources_, jni_cache_);
  // Fatal: dereferencing a null generator below would crash the test binary.
  ASSERT_NE(generator, nullptr);
  ClassificationResult classification = {"address", 1.0};
  std::vector<RemoteActionTemplate> intents;
  EXPECT_TRUE(generator->GenerateIntents(
      JniHelper::NewStringUTF(GetJenv(), "de-DE").ValueOrDie().get(),
      classification,
      /*reference_time_ms_utc=*/0, "333 E Wonderview Ave", {0, 20},
      GetAndroidContext(),
      /*annotations_entity_data_schema=*/nullptr, &intents));
  // Fatal: the checks below index intents[0], which is UB on an empty vector.
  ASSERT_THAT(intents, SizeIs(1));
  EXPECT_EQ(intents[0].title_without_entity.value(), "Karte");
  EXPECT_EQ(intents[0].description.value(), "Ausgewählte Adresse finden");
  EXPECT_EQ(intents[0].action.value(), "android.intent.action.VIEW");
  EXPECT_EQ(intents[0].data.value(), "geo:0,0?q=333%20E%20Wonderview%20Ave");
}
|
|
|
|
// `external.entity.annotation` supports both the Lua length operator and
// pairs() iteration. Each annotation becomes an extra named after it, plus one
// "length" extra.
TEST_F(IntentGeneratorTest, HandlesIteration) {
  flatbuffers::DetachedBuffer intent_factory_model =
      BuildTestIntentFactoryModel("iteration_test", R"lua(
local extra = {{ name = "length", int_value = #external.entity.annotation }}
for annotation_id, annotation in pairs(external.entity.annotation) do
  table.insert(extra,
    { name = annotation.name,
      string_value = annotation.text })
end
return {{ extra = extra }})lua");
  std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
      flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()),
      /*resources=*/resources_, jni_cache_);
  // Fatal: dereferencing a null generator below would crash the test binary.
  ASSERT_NE(generator, nullptr);
  Conversation conversation = {{{/*user_id=*/1, "hello there"}}};
  ActionSuggestionAnnotation location_annotation;
  location_annotation.entity = {"address", 1.0};
  location_annotation.span = {/*message_index=*/0,
                              /*span=*/{6, 11},
                              /*text=*/"there"};
  location_annotation.name = "location";
  ActionSuggestionAnnotation greeting_annotation;
  greeting_annotation.entity = {"greeting", 1.0};
  greeting_annotation.span = {/*message_index=*/0,
                              /*span=*/{0, 5},
                              /*text=*/"hello"};
  greeting_annotation.name = "greeting";
  ActionSuggestion suggestion = {/*response_text=""*/ "",
                                 /*type=*/"iteration_test",
                                 /*score=*/1.0,
                                 /*priority_score=*/0.0,
                                 /*annotations=*/
                                 {location_annotation, greeting_annotation}};
  std::vector<RemoteActionTemplate> intents;
  EXPECT_TRUE(generator->GenerateIntents(
      JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
      suggestion, conversation, GetAndroidContext(),
      /*annotations_entity_data_schema=*/nullptr,
      /*actions_entity_data_schema=*/nullptr, &intents));
  // Fatal: the checks below index intents[0], which is UB on an empty vector.
  ASSERT_THAT(intents, SizeIs(1));
  EXPECT_EQ(intents[0].extra["length"].Value<int>(), 2);
  EXPECT_EQ(intents[0].extra["location"].ConstRefValue<std::string>(), "there");
  EXPECT_EQ(intents[0].extra["greeting"].ConstRefValue<std::string>(), "hello");
}
|
|
|
|
// Serialized entity data on the classification is decoded with the supplied
// reflection schema. Its fields are then readable from Lua as
// `external.entity.<field>`.
TEST_F(IntentGeneratorTest, HandlesEntityDataLookups) {
  flatbuffers::DetachedBuffer intent_factory_model =
      BuildTestIntentFactoryModel("fake", R"lua(
local person = external.entity.person
return {
  {
    title_without_entity = "Add to contacts",
    extra = {
      {name = "name", string_value = string.lower(person.name)},
      {name = "encoded_phone", string_value = external.android.urlencode(person.phone)},
      {name = "age", int_value = person.age_years},
    }
  }
})lua");

  // Create fake entity data schema meta data.
  // Cannot use object oriented API here as that is not available for the
  // reflection schema.
  flatbuffers::FlatBufferBuilder schema_builder;
  std::vector<flatbuffers::Offset<reflection::Field>> person_fields = {
      reflection::CreateField(
          schema_builder,
          /*name=*/schema_builder.CreateString("name"),
          /*type=*/
          reflection::CreateType(schema_builder,
                                 /*base_type=*/reflection::String),
          /*id=*/0,
          /*offset=*/4),
      reflection::CreateField(
          schema_builder,
          /*name=*/schema_builder.CreateString("phone"),
          /*type=*/
          reflection::CreateType(schema_builder,
                                 /*base_type=*/reflection::String),
          /*id=*/1,
          /*offset=*/6),
      reflection::CreateField(
          schema_builder,
          /*name=*/schema_builder.CreateString("age_years"),
          /*type=*/
          reflection::CreateType(schema_builder,
                                 /*base_type=*/reflection::Int),
          /*id=*/2,
          /*offset=*/8),
  };
  std::vector<flatbuffers::Offset<reflection::Field>> entity_data_fields = {
      reflection::CreateField(
          schema_builder,
          /*name=*/schema_builder.CreateString("person"),
          /*type=*/
          reflection::CreateType(schema_builder,
                                 /*base_type=*/reflection::Obj,
                                 /*element=*/reflection::None,
                                 /*index=*/1),
          /*id=*/0,
          /*offset=*/4)};
  std::vector<flatbuffers::Offset<reflection::Enum>> enums;
  std::vector<flatbuffers::Offset<reflection::Object>> objects = {
      reflection::CreateObject(
          schema_builder,
          /*name=*/schema_builder.CreateString("EntityData"),
          /*fields=*/
          schema_builder.CreateVectorOfSortedTables(&entity_data_fields)),
      reflection::CreateObject(
          schema_builder,
          /*name=*/schema_builder.CreateString("person"),
          /*fields=*/
          schema_builder.CreateVectorOfSortedTables(&person_fields))};
  // CreateVectorOfSortedTables sorts `objects` in place by name. "EntityData"
  // sorts before "person", so objects[0] stays the root table and index 1
  // (referenced by the `person` field type above) is the person table.
  schema_builder.Finish(reflection::CreateSchema(
      schema_builder, schema_builder.CreateVectorOfSortedTables(&objects),
      schema_builder.CreateVectorOfSortedTables(&enums),
      /*(unused) file_ident=*/0,
      /*(unused) file_ext=*/0,
      /*root_table*/ objects[0]));
  // Points into `schema_builder`'s buffer, which lives until the end of the
  // test.
  const reflection::Schema* entity_data_schema =
      flatbuffers::GetRoot<reflection::Schema>(
          schema_builder.GetBufferPointer());

  std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
      flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()),
      /*resources=*/resources_, jni_cache_);
  // Fatal: dereferencing a null generator below would crash the test binary.
  ASSERT_NE(generator, nullptr);

  ClassificationResult classification = {"fake", 1.0};

  // Build test entity data.
  MutableFlatbufferBuilder entity_data_builder(entity_data_schema);
  std::unique_ptr<MutableFlatbuffer> entity_data_buffer =
      entity_data_builder.NewRoot();
  MutableFlatbuffer* person = entity_data_buffer->Mutable("person");
  // Fatal: `person` is dereferenced immediately below.
  ASSERT_NE(person, nullptr);
  person->Set("name", "Kenobi");
  person->Set("phone", "1 800 HIGHGROUND");
  person->Set("age_years", 38);
  classification.serialized_entity_data = entity_data_buffer->Serialize();

  std::vector<RemoteActionTemplate> intents;
  EXPECT_TRUE(generator->GenerateIntents(
      JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
      classification,
      /*reference_time_ms_utc=*/0, "highground", {0, 10}, GetAndroidContext(),
      /*annotations_entity_data_schema=*/entity_data_schema, &intents));
  // Fatal: the checks below index intents[0], which is UB on an empty vector.
  ASSERT_THAT(intents, SizeIs(1));
  EXPECT_THAT(intents[0].extra, SizeIs(3));
  EXPECT_EQ(intents[0].extra["name"].ConstRefValue<std::string>(), "kenobi");
  EXPECT_EQ(intents[0].extra["encoded_phone"].ConstRefValue<std::string>(),
            "1%20800%20HIGHGROUND");
  EXPECT_EQ(intents[0].extra["age"].Value<int>(), 38);
}
|
|
|
|
// Array- and bundle-valued extras returned by a template are converted into
// Variants holding the matching C++ containers:
//   - string/float/int arrays become std::vector
//   - named variant arrays become std::map<std::string, Variant>
TEST_F(IntentGeneratorTest, ReadExtras) {
  flatbuffers::DetachedBuffer intent_factory_model =
      BuildTestIntentFactoryModel("test", R"lua(
return {
  {
    extra = {
      { name = "languages", string_array_value = {"en", "zh"}},
      { name = "scores", float_array_value = {0.6, 0.4}},
      { name = "ints", int_array_value = {7, 2, 1}},
      { name = "bundle",
        named_variant_array_value =
        {
          { name = "inner_string", string_value = "a" },
          { name = "inner_int", int_value = 42 }
        }
      }
    }
  }}
)lua");
  std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
      flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()),
      /*resources=*/resources_, jni_cache_);
  // Fatal: dereferencing a null generator below would crash the test binary.
  ASSERT_NE(generator, nullptr);
  const ClassificationResult classification = {"test", 1.0};
  std::vector<RemoteActionTemplate> intents;

  EXPECT_TRUE(generator->GenerateIntents(
      JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
      classification,
      /*reference_time_ms_utc=*/0, "test", {0, 4}, GetAndroidContext(),
      /*annotations_entity_data_schema=*/nullptr, &intents));

  // Fatal: intents[0] is accessed below, which is UB on an empty vector.
  ASSERT_THAT(intents, SizeIs(1));
  // Borrow rather than copy; `extra` is accessed via the non-const operator[].
  RemoteActionTemplate& intent = intents[0];
  EXPECT_THAT(intent.extra, SizeIs(4));
  EXPECT_THAT(
      intent.extra["languages"].ConstRefValue<std::vector<std::string>>(),
      ElementsAre("en", "zh"));
  // Compare as floats with ULP tolerance rather than relying on an implicit
  // double -> float conversion inside the matcher.
  EXPECT_THAT(intent.extra["scores"].ConstRefValue<std::vector<float>>(),
              ElementsAre(::testing::FloatEq(0.6f), ::testing::FloatEq(0.4f)));
  EXPECT_THAT(intent.extra["ints"].ConstRefValue<std::vector<int>>(),
              ElementsAre(7, 2, 1));
  const std::map<std::string, Variant>& map =
      intent.extra["bundle"].ConstRefValue<std::map<std::string, Variant>>();
  EXPECT_THAT(map, SizeIs(2));
  EXPECT_EQ(map.at("inner_string").ConstRefValue<std::string>(), "a");
  EXPECT_EQ(map.at("inner_int").Value<int>(), 42);
}
|
|
|
|
} // namespace
|
|
} // namespace libtextclassifier3
|