You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

663 lines
28 KiB

/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <jni.h>
#include <memory>
#include <vector>
#include "utils/flatbuffers/mutable.h"
#include "utils/intents/intent-generator.h"
#include "utils/intents/remote-action-template.h"
#include "utils/java/jni-helper.h"
#include "utils/jvm-test-utils.h"
#include "utils/resources_generated.h"
#include "utils/testing/logging_event_listener.h"
#include "utils/variant.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "flatbuffers/reflection.h"
namespace libtextclassifier3 {
namespace {
using ::testing::ElementsAre;
using ::testing::IsEmpty;
using ::testing::SizeIs;
flatbuffers::DetachedBuffer BuildTestIntentFactoryModel(
const std::string& entity_type, const std::string& generator_code) {
// Test intent generation options.
IntentFactoryModelT options;
options.generator.emplace_back(new IntentFactoryModel_::IntentGeneratorT());
options.generator.back()->type = entity_type;
options.generator.back()->lua_template_generator = std::vector<unsigned char>(
generator_code.data(), generator_code.data() + generator_code.size());
flatbuffers::FlatBufferBuilder builder;
builder.Finish(IntentFactoryModel::Pack(builder, &options));
return builder.Release();
}
flatbuffers::DetachedBuffer BuildTestResources() {
// Custom string resources.
ResourcePoolT test_resources;
test_resources.locale.emplace_back(new LanguageTagT);
test_resources.locale.back()->language = "en";
test_resources.locale.emplace_back(new LanguageTagT);
test_resources.locale.back()->language = "de";
// Add `add_calendar_event`
test_resources.resource_entry.emplace_back(new ResourceEntryT);
test_resources.resource_entry.back()->name = "add_calendar_event";
// en
test_resources.resource_entry.back()->resource.emplace_back(new ResourceT);
test_resources.resource_entry.back()->resource.back()->content = "Schedule";
test_resources.resource_entry.back()->resource.back()->locale.push_back(0);
// Add `add_calendar_event_desc`
test_resources.resource_entry.emplace_back(new ResourceEntryT);
test_resources.resource_entry.back()->name = "add_calendar_event_desc";
// en
test_resources.resource_entry.back()->resource.emplace_back(new ResourceT);
test_resources.resource_entry.back()->resource.back()->content =
"Schedule event for selected time";
test_resources.resource_entry.back()->resource.back()->locale.push_back(0);
// Add `map`.
test_resources.resource_entry.emplace_back(new ResourceEntryT);
test_resources.resource_entry.back()->name = "map";
// en
test_resources.resource_entry.back()->resource.emplace_back(new ResourceT);
test_resources.resource_entry.back()->resource.back()->content = "Map";
test_resources.resource_entry.back()->resource.back()->locale.push_back(0);
// de
test_resources.resource_entry.back()->resource.emplace_back(new ResourceT);
test_resources.resource_entry.back()->resource.back()->content = "Karte";
test_resources.resource_entry.back()->resource.back()->locale.push_back(1);
// Add `map_desc`.
test_resources.resource_entry.emplace_back(new ResourceEntryT);
test_resources.resource_entry.back()->name = "map_desc";
// en
test_resources.resource_entry.back()->resource.emplace_back(new ResourceT);
test_resources.resource_entry.back()->resource.back()->content =
"Locate selected address";
test_resources.resource_entry.back()->resource.back()->locale.push_back(0);
// de
test_resources.resource_entry.back()->resource.emplace_back(new ResourceT);
test_resources.resource_entry.back()->resource.back()->content =
"Ausgewählte Adresse finden";
test_resources.resource_entry.back()->resource.back()->locale.push_back(1);
flatbuffers::FlatBufferBuilder builder;
builder.Finish(ResourcePool::Pack(builder, &test_resources));
return builder.Release();
}
// Common methods for intent generator tests.
class IntentGeneratorTest : public testing::Test {
protected:
explicit IntentGeneratorTest()
: jni_cache_(JniCache::Create(GetJenv())),
resource_buffer_(BuildTestResources()),
resources_(
flatbuffers::GetRoot<ResourcePool>(resource_buffer_.data())) {}
const std::shared_ptr<JniCache> jni_cache_;
const flatbuffers::DetachedBuffer resource_buffer_;
const ResourcePool* resources_;
};
TEST_F(IntentGeneratorTest, HandlesDefaultClassification) {
flatbuffers::DetachedBuffer intent_factory_model =
BuildTestIntentFactoryModel("unused", "");
std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
/*options=*/flatbuffers::GetRoot<IntentFactoryModel>(
intent_factory_model.data()),
/*resources=*/resources_,
/*jni_cache=*/jni_cache_);
ClassificationResult classification;
std::vector<RemoteActionTemplate> intents;
EXPECT_TRUE(generator->GenerateIntents(
/*device_locales=*/nullptr, classification, /*reference_time_ms_utc=*/0,
/*text=*/"", /*selection_indices=*/{kInvalidIndex, kInvalidIndex},
/*context=*/nullptr,
/*annotations_entity_data_schema=*/nullptr, &intents));
EXPECT_THAT(intents, IsEmpty());
}
TEST_F(IntentGeneratorTest, FailsGracefully) {
flatbuffers::DetachedBuffer intent_factory_model =
BuildTestIntentFactoryModel("test", R"lua(
return {
{
-- Should fail, as no app GetAndroidContext() is provided.
data = external.android.package_name,
}
})lua");
std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()),
/*resources=*/resources_, jni_cache_);
ClassificationResult classification = {"test", 1.0};
std::vector<RemoteActionTemplate> intents;
EXPECT_FALSE(generator->GenerateIntents(
JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
classification,
/*reference_time_ms_utc=*/0, "test", {0, 4}, /*context=*/nullptr,
/*annotations_entity_data_schema=*/nullptr, &intents));
EXPECT_THAT(intents, IsEmpty());
}
TEST_F(IntentGeneratorTest, HandlesEntityIntentGeneration) {
flatbuffers::DetachedBuffer intent_factory_model =
BuildTestIntentFactoryModel("address", R"lua(
return {
{
title_without_entity = external.android.R.map,
title_with_entity = external.entity.text,
description = external.android.R.map_desc,
action = "android.intent.action.VIEW",
data = "geo:0,0?q=" ..
external.android.urlencode(external.entity.text),
}
})lua");
std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()),
/*resources=*/resources_, jni_cache_);
ClassificationResult classification = {"address", 1.0};
std::vector<RemoteActionTemplate> intents;
EXPECT_TRUE(generator->GenerateIntents(
JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
classification,
/*reference_time_ms_utc=*/0, "333 E Wonderview Ave", {0, 20},
GetAndroidContext(),
/*annotations_entity_data_schema=*/nullptr, &intents));
EXPECT_THAT(intents, SizeIs(1));
EXPECT_EQ(intents[0].title_without_entity.value(), "Map");
EXPECT_EQ(intents[0].title_with_entity.value(), "333 E Wonderview Ave");
EXPECT_EQ(intents[0].description.value(), "Locate selected address");
EXPECT_EQ(intents[0].action.value(), "android.intent.action.VIEW");
EXPECT_EQ(intents[0].data.value(), "geo:0,0?q=333%20E%20Wonderview%20Ave");
}
TEST_F(IntentGeneratorTest, HandlesCallbacks) {
flatbuffers::DetachedBuffer intent_factory_model =
BuildTestIntentFactoryModel("test", R"lua(
local test = external.entity["text"]
return {
{
data = "encoded=" .. external.android.urlencode(test),
category = { "test_category" },
extra = {
{ name = "package", string_value = external.android.package_name},
{ name = "scheme",
string_value = external.android.url_schema("https://google.com")},
{ name = "host",
string_value = external.android.url_host("https://google.com/search")},
{ name = "permission",
bool_value = external.android.user_restrictions["no_sms"] },
{ name = "language",
string_value = external.android.device_locales[1].language },
{ name = "description",
string_value = external.format("$1 $0", "hello", "world") },
},
request_code = external.hash(test)
}
})lua");
std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()),
/*resources=*/resources_, jni_cache_);
ClassificationResult classification = {"test", 1.0};
std::vector<RemoteActionTemplate> intents;
EXPECT_TRUE(generator->GenerateIntents(
JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
classification,
/*reference_time_ms_utc=*/0, "this is a test", {0, 14},
GetAndroidContext(),
/*annotations_entity_data_schema=*/nullptr, &intents));
EXPECT_THAT(intents, SizeIs(1));
EXPECT_EQ(intents[0].data.value(), "encoded=this%20is%20a%20test");
EXPECT_THAT(intents[0].category, ElementsAre("test_category"));
EXPECT_THAT(intents[0].extra, SizeIs(6));
EXPECT_EQ(intents[0].extra["package"].ConstRefValue<std::string>(),
"com.google.android.textclassifier.tests"
);
EXPECT_EQ(intents[0].extra["scheme"].ConstRefValue<std::string>(), "https");
EXPECT_EQ(intents[0].extra["host"].ConstRefValue<std::string>(),
"google.com");
EXPECT_FALSE(intents[0].extra["permission"].Value<bool>());
EXPECT_EQ(intents[0].extra["language"].ConstRefValue<std::string>(), "en");
EXPECT_TRUE(intents[0].request_code.has_value());
EXPECT_EQ(intents[0].extra["description"].ConstRefValue<std::string>(),
"world hello");
}
TEST_F(IntentGeneratorTest, HandlesActionIntentGeneration) {
flatbuffers::DetachedBuffer intent_factory_model =
BuildTestIntentFactoryModel("view_map", R"lua(
return {
{
title_without_entity = external.android.R.map,
description = external.android.R.map_desc,
description_with_app_name = external.android.R.map,
action = "android.intent.action.VIEW",
data = "geo:0,0?q=" ..
external.android.urlencode(external.entity.annotation["location"].text),
}
})lua");
std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()),
/*resources=*/resources_, jni_cache_);
Conversation conversation = {{{/*user_id=*/1, "hello there"}}};
ActionSuggestionAnnotation annotation;
annotation.entity = {"address", 1.0};
annotation.span = {/*message_index=*/0,
/*span=*/{6, 11},
/*text=*/"there"};
annotation.name = "location";
ActionSuggestion suggestion = {/*response_text=""*/ "",
/*type=*/"view_map",
/*score=*/1.0,
/*priority_score=*/0.0,
/*annotations=*/
{annotation}};
std::vector<RemoteActionTemplate> intents;
EXPECT_TRUE(generator->GenerateIntents(
JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
suggestion, conversation, GetAndroidContext(),
/*annotations_entity_data_schema=*/nullptr,
/*actions_entity_data_schema=*/nullptr, &intents));
EXPECT_THAT(intents, SizeIs(1));
EXPECT_EQ(intents[0].title_without_entity.value(), "Map");
EXPECT_EQ(intents[0].description.value(), "Locate selected address");
EXPECT_EQ(intents[0].description_with_app_name.value(), "Map");
EXPECT_EQ(intents[0].action.value(), "android.intent.action.VIEW");
EXPECT_EQ(intents[0].data.value(), "geo:0,0?q=there");
}
TEST_F(IntentGeneratorTest, HandlesTimezoneAndReferenceTime) {
flatbuffers::DetachedBuffer intent_factory_model =
BuildTestIntentFactoryModel("test", R"lua(
local conversation = external.conversation
return {
{
extra = {
{ name = "timezone", string_value = conversation[#conversation].timezone },
{ name = "num_messages", int_value = #conversation },
{ name = "reference_time", long_value = conversation[#conversation].time_ms_utc }
},
}
})lua");
std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()),
/*resources=*/resources_, jni_cache_);
Conversation conversation = {
{{/*user_id=*/0, "hello there", /*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Testing/Test"},
{/*user_id=*/1, "general retesti", /*reference_time_ms_utc=*/1000,
/*reference_timezone=*/"Europe/Zurich"}}};
ActionSuggestion suggestion = {/*response_text=""*/ "",
/*type=*/"test",
/*score=*/1.0,
/*priority_score=*/0.0,
/*annotations=*/
{}};
std::vector<RemoteActionTemplate> intents;
EXPECT_TRUE(generator->GenerateIntents(
JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
suggestion, conversation, GetAndroidContext(),
/*annotations_entity_data_schema=*/nullptr,
/*actions_entity_data_schema=*/nullptr, &intents));
EXPECT_THAT(intents, SizeIs(1));
EXPECT_EQ(intents[0].extra["timezone"].ConstRefValue<std::string>(),
"Europe/Zurich");
EXPECT_EQ(intents[0].extra["num_messages"].Value<int>(), 2);
EXPECT_EQ(intents[0].extra["reference_time"].Value<int64>(), 1000);
}
TEST_F(IntentGeneratorTest, HandlesActionIntentGenerationMultipleAnnotations) {
flatbuffers::DetachedBuffer intent_factory_model =
BuildTestIntentFactoryModel("create_event", R"lua(
return {
{
title_without_entity = external.android.R.add_calendar_event,
description = external.android.R.add_calendar_event_desc,
extra = {
{name = "time", string_value =
external.entity.annotation["time"].text},
{name = "location",
string_value = external.entity.annotation["location"].text},
}
}
})lua");
std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()),
/*resources=*/resources_, jni_cache_);
Conversation conversation = {{{/*user_id=*/1, "hello there at 1pm"}}};
ActionSuggestionAnnotation location_annotation, time_annotation;
location_annotation.entity = {"address", 1.0};
location_annotation.span = {/*message_index=*/0,
/*span=*/{6, 11},
/*text=*/"there"};
location_annotation.name = "location";
time_annotation.entity = {"datetime", 1.0};
time_annotation.span = {/*message_index=*/0,
/*span=*/{15, 18},
/*text=*/"1pm"};
time_annotation.name = "time";
ActionSuggestion suggestion = {/*response_text=""*/ "",
/*type=*/"create_event",
/*score=*/1.0,
/*priority_score=*/0.0,
/*annotations=*/
{location_annotation, time_annotation}};
std::vector<RemoteActionTemplate> intents;
EXPECT_TRUE(generator->GenerateIntents(
JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
suggestion, conversation, GetAndroidContext(),
/*annotations_entity_data_schema=*/nullptr,
/*actions_entity_data_schema=*/nullptr, &intents));
EXPECT_THAT(intents, SizeIs(1));
EXPECT_EQ(intents[0].title_without_entity.value(), "Schedule");
EXPECT_THAT(intents[0].extra, SizeIs(2));
EXPECT_EQ(intents[0].extra["location"].ConstRefValue<std::string>(), "there");
EXPECT_EQ(intents[0].extra["time"].ConstRefValue<std::string>(), "1pm");
}
TEST_F(IntentGeneratorTest,
HandlesActionIntentGenerationMultipleAnnotationsWithIndices) {
flatbuffers::DetachedBuffer intent_factory_model =
BuildTestIntentFactoryModel("time_range", R"lua(
return {
{
title_without_entity = "test",
description = "test",
extra = {
{name = "from", string_value = external.entity.annotation[1].text},
{name = "to", string_value = external.entity.annotation[2].text},
}
}
})lua");
std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()),
/*resources=*/resources_, jni_cache_);
Conversation conversation = {{{/*user_id=*/1, "from 1pm to 2pm"}}};
ActionSuggestionAnnotation from_annotation, to_annotation;
from_annotation.entity = {"datetime", 1.0};
from_annotation.span = {/*message_index=*/0,
/*span=*/{5, 8},
/*text=*/"1pm"};
to_annotation.entity = {"datetime", 1.0};
to_annotation.span = {/*message_index=*/0,
/*span=*/{12, 15},
/*text=*/"2pm"};
ActionSuggestion suggestion = {/*response_text=""*/ "",
/*type=*/"time_range",
/*score=*/1.0,
/*priority_score=*/0.0,
/*annotations=*/
{from_annotation, to_annotation}};
std::vector<RemoteActionTemplate> intents;
EXPECT_TRUE(generator->GenerateIntents(
JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
suggestion, conversation, GetAndroidContext(),
/*annotations_entity_data_schema=*/nullptr,
/*actions_entity_data_schema=*/nullptr, &intents));
EXPECT_THAT(intents, SizeIs(1));
EXPECT_THAT(intents[0].extra, SizeIs(2));
EXPECT_EQ(intents[0].extra["from"].ConstRefValue<std::string>(), "1pm");
EXPECT_EQ(intents[0].extra["to"].ConstRefValue<std::string>(), "2pm");
}
TEST_F(IntentGeneratorTest, HandlesResources) {
flatbuffers::DetachedBuffer intent_factory_model =
BuildTestIntentFactoryModel("address", R"lua(
return {
{
title_without_entity = external.android.R.map,
description = external.android.R.map_desc,
action = "android.intent.action.VIEW",
data = "geo:0,0?q=" ..
external.android.urlencode(external.entity.text),
}
})lua");
std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()),
resources_, jni_cache_);
ClassificationResult classification = {"address", 1.0};
std::vector<RemoteActionTemplate> intents;
EXPECT_TRUE(generator->GenerateIntents(
JniHelper::NewStringUTF(GetJenv(), "de-DE").ValueOrDie().get(),
classification,
/*reference_time_ms_utc=*/0, "333 E Wonderview Ave", {0, 20},
GetAndroidContext(),
/*annotations_entity_data_schema=*/nullptr, &intents));
EXPECT_THAT(intents, SizeIs(1));
EXPECT_EQ(intents[0].title_without_entity.value(), "Karte");
EXPECT_EQ(intents[0].description.value(), "Ausgewählte Adresse finden");
EXPECT_EQ(intents[0].action.value(), "android.intent.action.VIEW");
EXPECT_EQ(intents[0].data.value(), "geo:0,0?q=333%20E%20Wonderview%20Ave");
}
TEST_F(IntentGeneratorTest, HandlesIteration) {
flatbuffers::DetachedBuffer intent_factory_model =
BuildTestIntentFactoryModel("iteration_test", R"lua(
local extra = {{ name = "length", int_value = #external.entity.annotation }}
for annotation_id, annotation in pairs(external.entity.annotation) do
table.insert(extra,
{ name = annotation.name,
string_value = annotation.text })
end
return {{ extra = extra }})lua");
std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()),
/*resources=*/resources_, jni_cache_);
Conversation conversation = {{{/*user_id=*/1, "hello there"}}};
ActionSuggestionAnnotation location_annotation;
location_annotation.entity = {"address", 1.0};
location_annotation.span = {/*message_index=*/0,
/*span=*/{6, 11},
/*text=*/"there"};
location_annotation.name = "location";
ActionSuggestionAnnotation greeting_annotation;
greeting_annotation.entity = {"greeting", 1.0};
greeting_annotation.span = {/*message_index=*/0,
/*span=*/{0, 5},
/*text=*/"hello"};
greeting_annotation.name = "greeting";
ActionSuggestion suggestion = {/*response_text=""*/ "",
/*type=*/"iteration_test",
/*score=*/1.0,
/*priority_score=*/0.0,
/*annotations=*/
{location_annotation, greeting_annotation}};
std::vector<RemoteActionTemplate> intents;
EXPECT_TRUE(generator->GenerateIntents(
JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
suggestion, conversation, GetAndroidContext(),
/*annotations_entity_data_schema=*/nullptr,
/*actions_entity_data_schema=*/nullptr, &intents));
EXPECT_THAT(intents, SizeIs(1));
EXPECT_EQ(intents[0].extra["length"].Value<int>(), 2);
EXPECT_EQ(intents[0].extra["location"].ConstRefValue<std::string>(), "there");
EXPECT_EQ(intents[0].extra["greeting"].ConstRefValue<std::string>(), "hello");
}
TEST_F(IntentGeneratorTest, HandlesEntityDataLookups) {
flatbuffers::DetachedBuffer intent_factory_model =
BuildTestIntentFactoryModel("fake", R"lua(
local person = external.entity.person
return {
{
title_without_entity = "Add to contacts",
extra = {
{name = "name", string_value = string.lower(person.name)},
{name = "encoded_phone", string_value = external.android.urlencode(person.phone)},
{name = "age", int_value = person.age_years},
}
}
})lua");
// Create fake entity data schema meta data.
// Cannot use object oriented API here as that is not available for the
// reflection schema.
flatbuffers::FlatBufferBuilder schema_builder;
std::vector<flatbuffers::Offset<reflection::Field>> person_fields = {
reflection::CreateField(
schema_builder,
/*name=*/schema_builder.CreateString("name"),
/*type=*/
reflection::CreateType(schema_builder,
/*base_type=*/reflection::String),
/*id=*/0,
/*offset=*/4),
reflection::CreateField(
schema_builder,
/*name=*/schema_builder.CreateString("phone"),
/*type=*/
reflection::CreateType(schema_builder,
/*base_type=*/reflection::String),
/*id=*/1,
/*offset=*/6),
reflection::CreateField(
schema_builder,
/*name=*/schema_builder.CreateString("age_years"),
/*type=*/
reflection::CreateType(schema_builder,
/*base_type=*/reflection::Int),
/*id=*/2,
/*offset=*/8),
};
std::vector<flatbuffers::Offset<reflection::Field>> entity_data_fields = {
reflection::CreateField(
schema_builder,
/*name=*/schema_builder.CreateString("person"),
/*type=*/
reflection::CreateType(schema_builder,
/*base_type=*/reflection::Obj,
/*element=*/reflection::None,
/*index=*/1),
/*id=*/0,
/*offset=*/4)};
std::vector<flatbuffers::Offset<reflection::Enum>> enums;
std::vector<flatbuffers::Offset<reflection::Object>> objects = {
reflection::CreateObject(
schema_builder,
/*name=*/schema_builder.CreateString("EntityData"),
/*fields=*/
schema_builder.CreateVectorOfSortedTables(&entity_data_fields)),
reflection::CreateObject(
schema_builder,
/*name=*/schema_builder.CreateString("person"),
/*fields=*/
schema_builder.CreateVectorOfSortedTables(&person_fields))};
schema_builder.Finish(reflection::CreateSchema(
schema_builder, schema_builder.CreateVectorOfSortedTables(&objects),
schema_builder.CreateVectorOfSortedTables(&enums),
/*(unused) file_ident=*/0,
/*(unused) file_ext=*/0,
/*root_table*/ objects[0]));
const reflection::Schema* entity_data_schema =
flatbuffers::GetRoot<reflection::Schema>(
schema_builder.GetBufferPointer());
std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()),
/*resources=*/resources_, jni_cache_);
ClassificationResult classification = {"fake", 1.0};
// Build test entity data.
MutableFlatbufferBuilder entity_data_builder(entity_data_schema);
std::unique_ptr<MutableFlatbuffer> entity_data_buffer =
entity_data_builder.NewRoot();
MutableFlatbuffer* person = entity_data_buffer->Mutable("person");
person->Set("name", "Kenobi");
person->Set("phone", "1 800 HIGHGROUND");
person->Set("age_years", 38);
classification.serialized_entity_data = entity_data_buffer->Serialize();
std::vector<RemoteActionTemplate> intents;
EXPECT_TRUE(generator->GenerateIntents(
JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
classification,
/*reference_time_ms_utc=*/0, "highground", {0, 10}, GetAndroidContext(),
/*annotations_entity_data_schema=*/entity_data_schema, &intents));
EXPECT_THAT(intents, SizeIs(1));
EXPECT_THAT(intents[0].extra, SizeIs(3));
EXPECT_EQ(intents[0].extra["name"].ConstRefValue<std::string>(), "kenobi");
EXPECT_EQ(intents[0].extra["encoded_phone"].ConstRefValue<std::string>(),
"1%20800%20HIGHGROUND");
EXPECT_EQ(intents[0].extra["age"].Value<int>(), 38);
}
TEST_F(IntentGeneratorTest, ReadExtras) {
flatbuffers::DetachedBuffer intent_factory_model =
BuildTestIntentFactoryModel("test", R"lua(
return {
{
extra = {
{ name = "languages", string_array_value = {"en", "zh"}},
{ name = "scores", float_array_value = {0.6, 0.4}},
{ name = "ints", int_array_value = {7, 2, 1}},
{ name = "bundle",
named_variant_array_value =
{
{ name = "inner_string", string_value = "a" },
{ name = "inner_int", int_value = 42 }
}
}
}
}}
)lua");
std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()),
/*resources=*/resources_, jni_cache_);
const ClassificationResult classification = {"test", 1.0};
std::vector<RemoteActionTemplate> intents;
EXPECT_TRUE(generator->GenerateIntents(
JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
classification,
/*reference_time_ms_utc=*/0, "test", {0, 4}, GetAndroidContext(),
/*annotations_entity_data_schema=*/nullptr, &intents));
EXPECT_THAT(intents, SizeIs(1));
RemoteActionTemplate intent = intents[0];
EXPECT_THAT(intent.extra, SizeIs(4));
EXPECT_THAT(
intent.extra["languages"].ConstRefValue<std::vector<std::string>>(),
ElementsAre("en", "zh"));
EXPECT_THAT(intent.extra["scores"].ConstRefValue<std::vector<float>>(),
ElementsAre(0.6, 0.4));
EXPECT_THAT(intent.extra["ints"].ConstRefValue<std::vector<int>>(),
ElementsAre(7, 2, 1));
const std::map<std::string, Variant>& map =
intent.extra["bundle"].ConstRefValue<std::map<std::string, Variant>>();
EXPECT_THAT(map, SizeIs(2));
EXPECT_EQ(map.at("inner_string").ConstRefValue<std::string>(), "a");
EXPECT_EQ(map.at("inner_int").Value<int>(), 42);
}
} // namespace
} // namespace libtextclassifier3