You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
658 lines
25 KiB
658 lines
25 KiB
/*
|
|
* Copyright (C) 2018 The Android Open Source Project
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include "utils/lua-utils.h"
|
|
|
|
namespace libtextclassifier3 {
|
|
namespace {
|
|
// Standard Lua libraries opened by LoadDefaultLibraries(). The list is
// terminated by an entry whose function pointer is null.
constexpr luaL_Reg defaultlibs[] = {
    {"_G", luaopen_base},
    {LUA_TABLIBNAME, luaopen_table},
    {LUA_STRLIBNAME, luaopen_string},
    {LUA_MATHLIBNAME, luaopen_math},
    {nullptr, nullptr},
};

// Names of the fields of the Lua tables exchanged with scripts.
constexpr char kTextKey[] = "text";
// NOTE: despite "Usec" in the identifier, this key carries a millisecond
// timestamp (see the key string itself).
constexpr char kTimeUsecKey[] = "parsed_time_ms_utc";
constexpr char kGranularityKey[] = "granularity";
constexpr char kCollectionKey[] = "collection";
constexpr char kNameKey[] = "name";
constexpr char kScoreKey[] = "score";
constexpr char kPriorityScoreKey[] = "priority_score";
constexpr char kTypeKey[] = "type";
constexpr char kResponseTextKey[] = "response_text";
constexpr char kAnnotationKey[] = "annotation";
constexpr char kSpanKey[] = "span";
constexpr char kMessageKey[] = "message";
constexpr char kBeginKey[] = "begin";
constexpr char kEndKey[] = "end";
constexpr char kClassificationKey[] = "classification";
constexpr char kSerializedEntity[] = "serialized_entity";
constexpr char kEntityKey[] = "entity";
|
|
|
|
// Implementation of a lua_Writer that appends the data to a string.
|
|
// lua_Writer callback: appends each chunk handed over by lua_dump to the
// std::string passed in as the `result` user-data pointer. Always reports
// success.
int LuaStringWriter(lua_State* state, const void* data, size_t size,
                    void* result) {
  auto* const output = static_cast<std::string*>(result);
  output->append(static_cast<const char*>(data), size);
  return LUA_OK;
}
|
|
|
|
} // namespace
|
|
|
|
// Creates a fresh Lua state. No libraries are opened here; see
// LoadDefaultLibraries().
LuaEnvironment::LuaEnvironment() : state_(luaL_newstate()) {}
|
|
|
|
// Releases the Lua state, if one was successfully created.
LuaEnvironment::~LuaEnvironment() {
  if (state_ == nullptr) {
    return;  // luaL_newstate may have failed; nothing to close.
  }
  lua_close(state_);
}
|
|
|
|
// Pushes a lazy Lua object backed by the flatbuffer `table` of type `type`.
// Fields are decoded on access via GetField. Any arguments the lazy-object
// machinery passes to the getter are ignored, as with the previous std::bind.
void LuaEnvironment::PushFlatbuffer(const reflection::Schema* schema,
                                    const reflection::Object* type,
                                    const flatbuffers::Table* table) const {
  PushLazyObject([this, schema, type, table](auto&&... /*unused*/) {
    return GetField(schema, type, table);
  });
}
|
|
|
|
// Lazy-object getter: pushes the value of the flatbuffer field whose name is
// on top of the Lua stack.
//
// Scalars and strings are pushed directly. Sub-tables are pushed as lazy
// objects, and vectors as repeated fields.
//
// Returns the number of pushed values: 1 normally, or 0 when an object field
// is unset (Lua sees nil). Unknown fields, non-string keys and unsupported
// types raise a Lua error.
int LuaEnvironment::GetField(const reflection::Schema* schema,
                             const reflection::Object* type,
                             const flatbuffers::Table* table) const {
  // lua_tostring yields nullptr for keys that are neither strings nor
  // numbers (e.g. `obj[true]`). Such keys must not reach LookupByKey, which
  // compares the key with strcmp.
  const char* field_name = lua_tostring(state_, /*idx=*/kIndexStackTop);
  if (field_name == nullptr) {
    TC3_LOG(ERROR) << "Expected string field name, got: "
                   << lua_type(state_, /*idx=*/kIndexStackTop);
    lua_error(state_);
    return 0;
  }
  const reflection::Field* field = type->fields()->LookupByKey(field_name);
  if (field == nullptr) {
    // The key (still on top of the stack) becomes the error value.
    lua_error(state_);
    return 0;
  }

  // Provide primitive fields directly.
  const reflection::BaseType field_type = field->type()->base_type();
  switch (field_type) {
    case reflection::Bool:
      Push(table->GetField<bool>(field->offset(), field->default_integer()));
      break;
    case reflection::UByte:
      Push(table->GetField<uint8>(field->offset(), field->default_integer()));
      break;
    case reflection::Byte:
      Push(table->GetField<int8>(field->offset(), field->default_integer()));
      break;
    case reflection::Int:
      Push(table->GetField<int32>(field->offset(), field->default_integer()));
      break;
    case reflection::UInt:
      Push(table->GetField<uint32>(field->offset(), field->default_integer()));
      break;
    case reflection::Long:
      Push(table->GetField<int64>(field->offset(), field->default_integer()));
      break;
    case reflection::ULong:
      Push(table->GetField<uint64>(field->offset(), field->default_integer()));
      break;
    case reflection::Float:
      Push(table->GetField<float>(field->offset(), field->default_real()));
      break;
    case reflection::Double:
      Push(table->GetField<double>(field->offset(), field->default_real()));
      break;
    case reflection::String: {
      Push(table->GetPointer<const flatbuffers::String*>(field->offset()));
      break;
    }
    case reflection::Obj: {
      const flatbuffers::Table* field_table =
          table->GetPointer<const flatbuffers::Table*>(field->offset());
      if (field_table == nullptr) {
        // Field was not set in entity data.
        return 0;
      }
      const reflection::Object* field_object =
          schema->objects()->Get(field->type()->index());
      PushFlatbuffer(schema, field_object, field_table);
      break;
    }
    case reflection::Vector: {
      const flatbuffers::Vector<flatbuffers::Offset<void>>* field_vector =
          table->GetPointer<
              const flatbuffers::Vector<flatbuffers::Offset<void>>*>(
              field->offset());
      if (field_vector == nullptr) {
        // Repeated field was not set in flatbuffer.
        PushEmptyVector();
        break;
      }
      switch (field->type()->element()) {
        case reflection::Bool:
          PushRepeatedField(table->GetPointer<const flatbuffers::Vector<bool>*>(
              field->offset()));
          break;
        case reflection::UByte:
          PushRepeatedField(
              table->GetPointer<const flatbuffers::Vector<uint8>*>(
                  field->offset()));
          break;
        case reflection::Byte:
          PushRepeatedField(table->GetPointer<const flatbuffers::Vector<int8>*>(
              field->offset()));
          break;
        case reflection::Int:
          PushRepeatedField(
              table->GetPointer<const flatbuffers::Vector<int32>*>(
                  field->offset()));
          break;
        case reflection::UInt:
          PushRepeatedField(
              table->GetPointer<const flatbuffers::Vector<uint32>*>(
                  field->offset()));
          break;
        case reflection::Long:
          PushRepeatedField(
              table->GetPointer<const flatbuffers::Vector<int64>*>(
                  field->offset()));
          break;
        case reflection::ULong:
          PushRepeatedField(
              table->GetPointer<const flatbuffers::Vector<uint64>*>(
                  field->offset()));
          break;
        case reflection::Float:
          PushRepeatedField(
              table->GetPointer<const flatbuffers::Vector<float>*>(
                  field->offset()));
          break;
        case reflection::Double:
          PushRepeatedField(
              table->GetPointer<const flatbuffers::Vector<double>*>(
                  field->offset()));
          break;
        case reflection::String:
          PushRepeatedField(
              table->GetPointer<const flatbuffers::Vector<
                  flatbuffers::Offset<flatbuffers::String>>*>(field->offset()));
          break;
        case reflection::Obj:
          PushRepeatedFlatbufferField(
              schema, schema->objects()->Get(field->type()->index()),
              table->GetPointer<const flatbuffers::Vector<
                  flatbuffers::Offset<flatbuffers::Table>>*>(field->offset()));
          break;
        default:
          TC3_LOG(ERROR) << "Unsupported repeated type: "
                         << field->type()->element();
          lua_error(state_);
          return 0;
      }
      break;
    }
    default:
      TC3_LOG(ERROR) << "Unsupported type: " << field_type;
      lua_error(state_);
      return 0;
  }
  return 1;
}
|
|
|
|
int LuaEnvironment::ReadFlatbuffer(const int index,
|
|
MutableFlatbuffer* buffer) const {
|
|
if (buffer == nullptr) {
|
|
TC3_LOG(ERROR) << "Called ReadFlatbuffer with null buffer: " << index;
|
|
lua_error(state_);
|
|
return LUA_ERRRUN;
|
|
}
|
|
if (lua_type(state_, /*idx=*/index) != LUA_TTABLE) {
|
|
TC3_LOG(ERROR) << "Expected table, got: "
|
|
<< lua_type(state_, /*idx=*/kIndexStackTop);
|
|
lua_error(state_);
|
|
return LUA_ERRRUN;
|
|
}
|
|
|
|
lua_pushnil(state_);
|
|
while (Next(index - 1)) {
|
|
const StringPiece key = ReadString(/*index=*/index - 1);
|
|
const reflection::Field* field = buffer->GetFieldOrNull(key);
|
|
if (field == nullptr) {
|
|
TC3_LOG(ERROR) << "Unknown field: " << key;
|
|
lua_error(state_);
|
|
return LUA_ERRRUN;
|
|
}
|
|
switch (field->type()->base_type()) {
|
|
case reflection::Obj:
|
|
ReadFlatbuffer(/*index=*/kIndexStackTop, buffer->Mutable(field));
|
|
break;
|
|
case reflection::Bool:
|
|
buffer->Set(field, Read<bool>(/*index=*/kIndexStackTop));
|
|
break;
|
|
case reflection::Byte:
|
|
buffer->Set(field, Read<int8>(/*index=*/kIndexStackTop));
|
|
break;
|
|
case reflection::UByte:
|
|
buffer->Set(field, Read<uint8>(/*index=*/kIndexStackTop));
|
|
break;
|
|
case reflection::Int:
|
|
buffer->Set(field, Read<int32>(/*index=*/kIndexStackTop));
|
|
break;
|
|
case reflection::UInt:
|
|
buffer->Set(field, Read<uint32>(/*index=*/kIndexStackTop));
|
|
break;
|
|
case reflection::Long:
|
|
buffer->Set(field, Read<int64>(/*index=*/kIndexStackTop));
|
|
break;
|
|
case reflection::ULong:
|
|
buffer->Set(field, Read<uint64>(/*index=*/kIndexStackTop));
|
|
break;
|
|
case reflection::Float:
|
|
buffer->Set(field, Read<float>(/*index=*/kIndexStackTop));
|
|
break;
|
|
case reflection::Double:
|
|
buffer->Set(field, Read<double>(/*index=*/kIndexStackTop));
|
|
break;
|
|
case reflection::String: {
|
|
buffer->Set(field, ReadString(/*index=*/kIndexStackTop));
|
|
break;
|
|
}
|
|
case reflection::Vector: {
|
|
// Read repeated field.
|
|
switch (field->type()->element()) {
|
|
case reflection::Bool:
|
|
ReadRepeatedField<bool>(/*index=*/kIndexStackTop,
|
|
buffer->Repeated(field));
|
|
break;
|
|
case reflection::Byte:
|
|
ReadRepeatedField<int8>(/*index=*/kIndexStackTop,
|
|
buffer->Repeated(field));
|
|
break;
|
|
case reflection::UByte:
|
|
ReadRepeatedField<uint8>(/*index=*/kIndexStackTop,
|
|
buffer->Repeated(field));
|
|
break;
|
|
case reflection::Int:
|
|
ReadRepeatedField<int32>(/*index=*/kIndexStackTop,
|
|
buffer->Repeated(field));
|
|
break;
|
|
case reflection::UInt:
|
|
ReadRepeatedField<uint32>(/*index=*/kIndexStackTop,
|
|
buffer->Repeated(field));
|
|
break;
|
|
case reflection::Long:
|
|
ReadRepeatedField<int64>(/*index=*/kIndexStackTop,
|
|
buffer->Repeated(field));
|
|
break;
|
|
case reflection::ULong:
|
|
ReadRepeatedField<uint64>(/*index=*/kIndexStackTop,
|
|
buffer->Repeated(field));
|
|
break;
|
|
case reflection::Float:
|
|
ReadRepeatedField<float>(/*index=*/kIndexStackTop,
|
|
buffer->Repeated(field));
|
|
break;
|
|
case reflection::Double:
|
|
ReadRepeatedField<double>(/*index=*/kIndexStackTop,
|
|
buffer->Repeated(field));
|
|
break;
|
|
case reflection::String:
|
|
ReadRepeatedField<std::string>(/*index=*/kIndexStackTop,
|
|
buffer->Repeated(field));
|
|
break;
|
|
case reflection::Obj:
|
|
ReadRepeatedField<MutableFlatbuffer>(/*index=*/kIndexStackTop,
|
|
buffer->Repeated(field));
|
|
break;
|
|
default:
|
|
TC3_LOG(ERROR) << "Unsupported repeated field type: "
|
|
<< field->type()->element();
|
|
lua_error(state_);
|
|
return LUA_ERRRUN;
|
|
}
|
|
break;
|
|
}
|
|
default:
|
|
TC3_LOG(ERROR) << "Unsupported type: " << field->type()->base_type();
|
|
lua_error(state_);
|
|
return LUA_ERRRUN;
|
|
}
|
|
lua_pop(state_, 1);
|
|
}
|
|
return LUA_OK;
|
|
}
|
|
|
|
void LuaEnvironment::LoadDefaultLibraries() {
|
|
for (const luaL_Reg* lib = defaultlibs; lib->func; lib++) {
|
|
luaL_requiref(state_, lib->name, lib->func, 1);
|
|
lua_pop(state_, 1); // Remove lib.
|
|
}
|
|
}
|
|
|
|
// Returns a view of the string at stack slot `index`.
// The view borrows Lua-owned memory and is only valid while that value stays
// on the stack. lua_tolstring converts numbers in place.
StringPiece LuaEnvironment::ReadString(const int index) const {
  size_t num_bytes = 0;
  const char* const chars = lua_tolstring(state_, index, &num_bytes);
  return StringPiece(chars, num_bytes);
}
|
|
|
|
// Pushes `str` as a Lua string, copying exactly str.size() bytes.
void LuaEnvironment::PushString(const StringPiece str) const {
  const size_t num_bytes = str.size();
  lua_pushlstring(state_, str.data(), num_bytes);
}
|
|
|
|
// Compiles `snippet` and appends its stripped bytecode to `*bytecode`.
// Returns false (after logging) if compilation or dumping fails.
// The Lua stack is left unchanged in all cases.
bool LuaEnvironment::Compile(StringPiece snippet, std::string* bytecode) const {
  const int load_status = luaL_loadbuffer(state_, snippet.data(),
                                          snippet.size(), /*name=*/nullptr);
  if (load_status != LUA_OK) {
    // The compiler's error message is on top of the stack.
    TC3_LOG(ERROR) << "Could not compile lua snippet: "
                   << ReadString(/*index=*/kIndexStackTop);
    lua_pop(state_, 1);
    return false;
  }

  // The compiled chunk is now on top of the stack.
  const bool dumped =
      lua_dump(state_, LuaStringWriter, bytecode, /*strip*/ 1) == LUA_OK;
  if (!dumped) {
    TC3_LOG(ERROR) << "Could not dump compiled lua snippet.";
  }
  lua_pop(state_, 1);  // Remove the compiled chunk.
  return dumped;
}
|
|
|
|
void LuaEnvironment::PushAnnotation(
|
|
const ClassificationResult& classification,
|
|
const reflection::Schema* entity_data_schema) const {
|
|
if (entity_data_schema == nullptr ||
|
|
classification.serialized_entity_data.empty()) {
|
|
// Empty table.
|
|
lua_newtable(state_);
|
|
} else {
|
|
PushFlatbuffer(entity_data_schema,
|
|
flatbuffers::GetRoot<flatbuffers::Table>(
|
|
classification.serialized_entity_data.data()));
|
|
}
|
|
Push(classification.datetime_parse_result.time_ms_utc);
|
|
lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kTimeUsecKey);
|
|
Push(classification.datetime_parse_result.granularity);
|
|
lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kGranularityKey);
|
|
Push(classification.collection);
|
|
lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kCollectionKey);
|
|
Push(classification.score);
|
|
lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kScoreKey);
|
|
Push(classification.serialized_entity_data);
|
|
lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kSerializedEntity);
|
|
}
|
|
|
|
void LuaEnvironment::PushAnnotation(
|
|
const ClassificationResult& classification, StringPiece text,
|
|
const reflection::Schema* entity_data_schema) const {
|
|
PushAnnotation(classification, entity_data_schema);
|
|
Push(text);
|
|
lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kTextKey);
|
|
}
|
|
|
|
void LuaEnvironment::PushAnnotation(
|
|
const ActionSuggestionAnnotation& annotation,
|
|
const reflection::Schema* entity_data_schema) const {
|
|
PushAnnotation(annotation.entity, annotation.span.text, entity_data_schema);
|
|
PushString(annotation.name);
|
|
lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kNameKey);
|
|
{
|
|
lua_newtable(state_);
|
|
Push(annotation.span.message_index);
|
|
lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kMessageKey);
|
|
Push(annotation.span.span.first);
|
|
lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kBeginKey);
|
|
Push(annotation.span.span.second);
|
|
lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kEndKey);
|
|
}
|
|
lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kSpanKey);
|
|
}
|
|
|
|
void LuaEnvironment::PushAnnotatedSpan(
|
|
const AnnotatedSpan& annotated_span,
|
|
const reflection::Schema* entity_data_schema) const {
|
|
lua_newtable(state_);
|
|
{
|
|
lua_newtable(state_);
|
|
Push(annotated_span.span.first);
|
|
lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kBeginKey);
|
|
Push(annotated_span.span.second);
|
|
lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kEndKey);
|
|
}
|
|
lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kSpanKey);
|
|
PushAnnotations(&annotated_span.classification, entity_data_schema);
|
|
lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kClassificationKey);
|
|
}
|
|
|
|
void LuaEnvironment::PushAnnotatedSpans(
|
|
const std::vector<AnnotatedSpan>* annotated_spans,
|
|
const reflection::Schema* entity_data_schema) const {
|
|
PushIterator(annotated_spans ? annotated_spans->size() : 0,
|
|
[this, annotated_spans, entity_data_schema](const int64 index) {
|
|
PushAnnotatedSpan(annotated_spans->at(index),
|
|
entity_data_schema);
|
|
return 1;
|
|
});
|
|
}
|
|
|
|
// Reads a span table (message, begin, end, text) from the top of the stack
// into a MessageTextSpan.
//
// Unknown or non-string keys are logged and ignored. A non-table value is
// logged and yields a default span. The stack is left as it was.
MessageTextSpan LuaEnvironment::ReadSpan() const {
  MessageTextSpan span;
  // lua_next requires a table; traversing anything else is invalid.
  if (lua_type(state_, /*idx=*/kIndexStackTop) != LUA_TTABLE) {
    TC3_LOG(ERROR) << "Expected span table, got: "
                   << lua_type(state_, /*idx=*/kIndexStackTop);
    return span;
  }
  lua_pushnil(state_);
  while (Next(/*index=*/kIndexStackTop - 1)) {
    // Only read string keys: lua_tolstring on a numeric key converts it in
    // place, which breaks the following lua_next call.
    if (lua_type(state_, /*idx=*/kIndexStackTop - 1) != LUA_TSTRING) {
      TC3_LOG(INFO) << "Ignoring span field with non-string key of type: "
                    << lua_type(state_, /*idx=*/kIndexStackTop - 1);
      lua_pop(state_, 1);
      continue;
    }
    const StringPiece key = ReadString(/*index=*/kIndexStackTop - 1);
    if (key.Equals(kMessageKey)) {
      span.message_index = Read<int>(/*index=*/kIndexStackTop);
    } else if (key.Equals(kBeginKey)) {
      span.span.first = Read<int>(/*index=*/kIndexStackTop);
    } else if (key.Equals(kEndKey)) {
      span.span.second = Read<int>(/*index=*/kIndexStackTop);
    } else if (key.Equals(kTextKey)) {
      span.text = Read<std::string>(/*index=*/kIndexStackTop);
    } else {
      TC3_LOG(INFO) << "Unknown span field: " << key;
    }
    lua_pop(state_, 1);  // Drop the value; keep the key for lua_next.
  }
  return span;
}
|
|
|
|
int LuaEnvironment::ReadAnnotations(
|
|
const reflection::Schema* entity_data_schema,
|
|
std::vector<ActionSuggestionAnnotation>* annotations) const {
|
|
if (lua_type(state_, /*idx=*/kIndexStackTop) != LUA_TTABLE) {
|
|
TC3_LOG(ERROR) << "Expected annotations table, got: "
|
|
<< lua_type(state_, /*idx=*/kIndexStackTop);
|
|
lua_pop(state_, 1);
|
|
lua_error(state_);
|
|
return LUA_ERRRUN;
|
|
}
|
|
|
|
// Read actions.
|
|
lua_pushnil(state_);
|
|
while (Next(/*index=*/kIndexStackTop - 1)) {
|
|
if (lua_type(state_, /*idx=*/kIndexStackTop) != LUA_TTABLE) {
|
|
TC3_LOG(ERROR) << "Expected annotation table, got: "
|
|
<< lua_type(state_, /*idx=*/kIndexStackTop);
|
|
lua_pop(state_, 1);
|
|
continue;
|
|
}
|
|
annotations->push_back(ReadAnnotation(entity_data_schema));
|
|
lua_pop(state_, 1);
|
|
}
|
|
return LUA_OK;
|
|
}
|
|
|
|
// Reads the annotation table on top of the stack (name, span, entity) into
// an ActionSuggestionAnnotation. Unknown or non-string keys are logged and
// ignored.
ActionSuggestionAnnotation LuaEnvironment::ReadAnnotation(
    const reflection::Schema* entity_data_schema) const {
  ActionSuggestionAnnotation annotation;
  lua_pushnil(state_);
  while (Next(/*index=*/kIndexStackTop - 1)) {
    // Only read string keys: lua_tolstring on a numeric key converts it in
    // place, which breaks the following lua_next call.
    if (lua_type(state_, /*idx=*/kIndexStackTop - 1) != LUA_TSTRING) {
      TC3_LOG(ERROR) << "Ignoring annotation field with non-string key of "
                        "type: "
                     << lua_type(state_, /*idx=*/kIndexStackTop - 1);
      lua_pop(state_, 1);
      continue;
    }
    const StringPiece key = ReadString(/*index=*/kIndexStackTop - 1);
    if (key.Equals(kNameKey)) {
      annotation.name = Read<std::string>(/*index=*/kIndexStackTop);
    } else if (key.Equals(kSpanKey)) {
      annotation.span = ReadSpan();
    } else if (key.Equals(kEntityKey)) {
      annotation.entity = ReadClassificationResult(entity_data_schema);
    } else {
      TC3_LOG(ERROR) << "Unknown annotation field: " << key;
    }
    lua_pop(state_, 1);  // Drop the value; keep the key for lua_next.
  }
  return annotation;
}
|
|
|
|
// Reads a classification table from the top of the stack.
//
// An "entity" sub-table is serialized with `entity_data_schema` into
// serialized_entity_data. Unknown or non-string keys are logged and ignored.
// A non-table value is logged and yields a default result. The stack is left
// as it was.
ClassificationResult LuaEnvironment::ReadClassificationResult(
    const reflection::Schema* entity_data_schema) const {
  ClassificationResult classification;
  // lua_next requires a table; traversing anything else is invalid.
  if (lua_type(state_, /*idx=*/kIndexStackTop) != LUA_TTABLE) {
    TC3_LOG(ERROR) << "Expected classification table, got: "
                   << lua_type(state_, /*idx=*/kIndexStackTop);
    return classification;
  }
  lua_pushnil(state_);
  while (Next(/*index=*/kIndexStackTop - 1)) {
    // Only read string keys: lua_tolstring on a numeric key converts it in
    // place, which breaks the following lua_next call.
    if (lua_type(state_, /*idx=*/kIndexStackTop - 1) != LUA_TSTRING) {
      TC3_LOG(INFO) << "Ignoring classification result field with "
                       "non-string key of type: "
                    << lua_type(state_, /*idx=*/kIndexStackTop - 1);
      lua_pop(state_, 1);
      continue;
    }
    const StringPiece key = ReadString(/*index=*/kIndexStackTop - 1);
    if (key.Equals(kCollectionKey)) {
      classification.collection = Read<std::string>(/*index=*/kIndexStackTop);
    } else if (key.Equals(kScoreKey)) {
      classification.score = Read<float>(/*index=*/kIndexStackTop);
    } else if (key.Equals(kTimeUsecKey)) {
      classification.datetime_parse_result.time_ms_utc =
          Read<int64>(/*index=*/kIndexStackTop);
    } else if (key.Equals(kGranularityKey)) {
      classification.datetime_parse_result.granularity =
          static_cast<DatetimeGranularity>(
              lua_tonumber(state_, /*idx=*/kIndexStackTop));
    } else if (key.Equals(kSerializedEntity)) {
      classification.serialized_entity_data =
          Read<std::string>(/*index=*/kIndexStackTop);
    } else if (key.Equals(kEntityKey)) {
      auto buffer = MutableFlatbufferBuilder(entity_data_schema).NewRoot();
      ReadFlatbuffer(/*index=*/kIndexStackTop, buffer.get());
      classification.serialized_entity_data = buffer->Serialize();
    } else {
      TC3_LOG(INFO) << "Unknown classification result field: " << key;
    }
    lua_pop(state_, 1);  // Drop the value; keep the key for lua_next.
  }
  return classification;
}
|
|
|
|
void LuaEnvironment::PushAction(
|
|
const ActionSuggestion& action,
|
|
const reflection::Schema* actions_entity_data_schema,
|
|
const reflection::Schema* annotations_entity_data_schema) const {
|
|
if (actions_entity_data_schema == nullptr ||
|
|
action.serialized_entity_data.empty()) {
|
|
// Empty table.
|
|
lua_newtable(state_);
|
|
} else {
|
|
PushFlatbuffer(actions_entity_data_schema,
|
|
flatbuffers::GetRoot<flatbuffers::Table>(
|
|
action.serialized_entity_data.data()));
|
|
}
|
|
PushString(action.type);
|
|
lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kTypeKey);
|
|
PushString(action.response_text);
|
|
lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kResponseTextKey);
|
|
Push(action.score);
|
|
lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kScoreKey);
|
|
Push(action.priority_score);
|
|
lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kPriorityScoreKey);
|
|
PushAnnotations(&action.annotations, annotations_entity_data_schema);
|
|
lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kAnnotationKey);
|
|
}
|
|
|
|
void LuaEnvironment::PushActions(
|
|
const std::vector<ActionSuggestion>* actions,
|
|
const reflection::Schema* actions_entity_data_schema,
|
|
const reflection::Schema* annotations_entity_data_schema) const {
|
|
PushIterator(actions ? actions->size() : 0,
|
|
[this, actions, actions_entity_data_schema,
|
|
annotations_entity_data_schema](const int64 index) {
|
|
PushAction(actions->at(index), actions_entity_data_schema,
|
|
annotations_entity_data_schema);
|
|
return 1;
|
|
});
|
|
}
|
|
|
|
// Reads the action table on top of the stack into an ActionSuggestion.
//
// The action's own "entity" is serialized with `actions_entity_data_schema`.
// Its annotations are read with `annotations_entity_data_schema`, the schema
// their entity data is written with when pushed. Unknown or non-string keys
// are logged and ignored.
ActionSuggestion LuaEnvironment::ReadAction(
    const reflection::Schema* actions_entity_data_schema,
    const reflection::Schema* annotations_entity_data_schema) const {
  ActionSuggestion action;
  lua_pushnil(state_);
  while (Next(/*index=*/kIndexStackTop - 1)) {
    // Only read string keys: lua_tolstring on a numeric key converts it in
    // place, which breaks the following lua_next call.
    if (lua_type(state_, /*idx=*/kIndexStackTop - 1) != LUA_TSTRING) {
      TC3_LOG(INFO) << "Ignoring action field with non-string key of type: "
                    << lua_type(state_, /*idx=*/kIndexStackTop - 1);
      lua_pop(state_, 1);
      continue;
    }
    const StringPiece key = ReadString(/*index=*/kIndexStackTop - 1);
    if (key.Equals(kResponseTextKey)) {
      action.response_text = Read<std::string>(/*index=*/kIndexStackTop);
    } else if (key.Equals(kTypeKey)) {
      action.type = Read<std::string>(/*index=*/kIndexStackTop);
    } else if (key.Equals(kScoreKey)) {
      action.score = Read<float>(/*index=*/kIndexStackTop);
    } else if (key.Equals(kPriorityScoreKey)) {
      action.priority_score = Read<float>(/*index=*/kIndexStackTop);
    } else if (key.Equals(kAnnotationKey)) {
      // Annotation entities follow the annotations schema, not the actions
      // schema.
      ReadAnnotations(annotations_entity_data_schema, &action.annotations);
    } else if (key.Equals(kEntityKey)) {
      auto buffer =
          MutableFlatbufferBuilder(actions_entity_data_schema).NewRoot();
      ReadFlatbuffer(/*index=*/kIndexStackTop, buffer.get());
      action.serialized_entity_data = buffer->Serialize();
    } else {
      TC3_LOG(INFO) << "Unknown action field: " << key;
    }
    lua_pop(state_, 1);  // Drop the value; keep the key for lua_next.
  }
  return action;
}
|
|
|
|
int LuaEnvironment::ReadActions(
|
|
const reflection::Schema* actions_entity_data_schema,
|
|
const reflection::Schema* annotations_entity_data_schema,
|
|
std::vector<ActionSuggestion>* actions) const {
|
|
// Read actions.
|
|
lua_pushnil(state_);
|
|
while (Next(/*index=*/kIndexStackTop - 1)) {
|
|
if (lua_type(state_, /*idx=*/kIndexStackTop) != LUA_TTABLE) {
|
|
TC3_LOG(ERROR) << "Expected action table, got: "
|
|
<< lua_type(state_, /*idx=*/kIndexStackTop);
|
|
lua_pop(state_, 1);
|
|
continue;
|
|
}
|
|
actions->push_back(
|
|
ReadAction(actions_entity_data_schema, annotations_entity_data_schema));
|
|
lua_pop(state_, /*n=*/1);
|
|
}
|
|
lua_pop(state_, /*n=*/1);
|
|
|
|
return LUA_OK;
|
|
}
|
|
|
|
void LuaEnvironment::PushConversation(
|
|
const std::vector<ConversationMessage>* conversation,
|
|
const reflection::Schema* annotations_entity_data_schema) const {
|
|
PushIterator(
|
|
conversation ? conversation->size() : 0,
|
|
[this, conversation, annotations_entity_data_schema](const int64 index) {
|
|
const ConversationMessage& message = conversation->at(index);
|
|
lua_newtable(state_);
|
|
Push(message.user_id);
|
|
lua_setfield(state_, /*idx=*/kIndexStackTop - 1, "user_id");
|
|
Push(message.text);
|
|
lua_setfield(state_, /*idx=*/kIndexStackTop - 1, "text");
|
|
Push(message.reference_time_ms_utc);
|
|
lua_setfield(state_, /*idx=*/kIndexStackTop - 1, "time_ms_utc");
|
|
Push(message.reference_timezone);
|
|
lua_setfield(state_, /*idx=*/kIndexStackTop - 1, "timezone");
|
|
PushAnnotatedSpans(&message.annotations,
|
|
annotations_entity_data_schema);
|
|
lua_setfield(state_, /*idx=*/kIndexStackTop - 1, "annotation");
|
|
return 1;
|
|
});
|
|
}
|
|
|
|
// Convenience wrapper: compiles `snippet` into `*bytecode` using a
// short-lived LuaEnvironment. Returns false if compilation or dumping fails.
bool Compile(StringPiece snippet, std::string* bytecode) {
  const LuaEnvironment scratch_environment;
  return scratch_environment.Compile(snippet, bytecode);
}
|
|
|
|
} // namespace libtextclassifier3
|