You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
118 lines
3.6 KiB
118 lines
3.6 KiB
4 months ago
|
/*
|
||
|
* Copyright (C) 2018 The Android Open Source Project
|
||
|
*
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
*/
|
||
|
|
||
|
#include "actions/lua-ranker.h"
|
||
|
|
||
|
#include "utils/base/logging.h"
|
||
|
#include "utils/lua-utils.h"
|
||
|
|
||
|
#ifdef __cplusplus
|
||
|
extern "C" {
|
||
|
#endif
|
||
|
#include "lauxlib.h"
|
||
|
#include "lualib.h"
|
||
|
#ifdef __cplusplus
|
||
|
}
|
||
|
#endif
|
||
|
|
||
|
namespace libtextclassifier3 {
|
||
|
|
||
|
std::unique_ptr<ActionsSuggestionsLuaRanker>
|
||
|
ActionsSuggestionsLuaRanker::Create(
|
||
|
const Conversation& conversation, const std::string& ranker_code,
|
||
|
const reflection::Schema* entity_data_schema,
|
||
|
const reflection::Schema* annotations_entity_data_schema,
|
||
|
ActionsSuggestionsResponse* response) {
|
||
|
auto ranker = std::unique_ptr<ActionsSuggestionsLuaRanker>(
|
||
|
new ActionsSuggestionsLuaRanker(
|
||
|
conversation, ranker_code, entity_data_schema,
|
||
|
annotations_entity_data_schema, response));
|
||
|
if (!ranker->Initialize()) {
|
||
|
TC3_LOG(ERROR) << "Could not initialize lua environment for ranker.";
|
||
|
return nullptr;
|
||
|
}
|
||
|
return ranker;
|
||
|
}
|
||
|
|
||
|
bool ActionsSuggestionsLuaRanker::Initialize() {
|
||
|
return RunProtected([this] {
|
||
|
LoadDefaultLibraries();
|
||
|
|
||
|
// Expose generated actions.
|
||
|
PushActions(&response_->actions, actions_entity_data_schema_,
|
||
|
annotations_entity_data_schema_);
|
||
|
lua_setglobal(state_, "actions");
|
||
|
|
||
|
// Expose conversation message stream.
|
||
|
PushConversation(&conversation_.messages,
|
||
|
annotations_entity_data_schema_);
|
||
|
lua_setglobal(state_, "messages");
|
||
|
return LUA_OK;
|
||
|
}) == LUA_OK;
|
||
|
}
|
||
|
|
||
|
int ActionsSuggestionsLuaRanker::ReadActionsRanking() {
|
||
|
if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
|
||
|
TC3_LOG(ERROR) << "Expected actions table, got: "
|
||
|
<< lua_type(state_, /*idx=*/-1);
|
||
|
lua_pop(state_, 1);
|
||
|
lua_error(state_);
|
||
|
return LUA_ERRRUN;
|
||
|
}
|
||
|
std::vector<ActionSuggestion> ranked_actions;
|
||
|
lua_pushnil(state_);
|
||
|
while (Next(/*index=*/-2)) {
|
||
|
const int action_id = Read<int>(/*index=*/-1) - 1;
|
||
|
lua_pop(state_, 1);
|
||
|
if (action_id < 0 || action_id >= response_->actions.size()) {
|
||
|
TC3_LOG(ERROR) << "Invalid action index: " << action_id;
|
||
|
lua_error(state_);
|
||
|
return LUA_ERRRUN;
|
||
|
}
|
||
|
ranked_actions.push_back(response_->actions[action_id]);
|
||
|
}
|
||
|
lua_pop(state_, 1);
|
||
|
response_->actions = ranked_actions;
|
||
|
return LUA_OK;
|
||
|
}
|
||
|
|
||
|
bool ActionsSuggestionsLuaRanker::RankActions() {
|
||
|
if (response_->actions.empty()) {
|
||
|
// Nothing to do.
|
||
|
return true;
|
||
|
}
|
||
|
|
||
|
if (luaL_loadbuffer(state_, ranker_code_.data(), ranker_code_.size(),
|
||
|
/*name=*/nullptr) != LUA_OK) {
|
||
|
TC3_LOG(ERROR) << "Could not load compiled ranking snippet.";
|
||
|
return false;
|
||
|
}
|
||
|
|
||
|
if (lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0) != LUA_OK) {
|
||
|
TC3_LOG(ERROR) << "Could not run ranking snippet.";
|
||
|
return false;
|
||
|
}
|
||
|
|
||
|
if (RunProtected([this] { return ReadActionsRanking(); }, /*num_args=*/1) !=
|
||
|
LUA_OK) {
|
||
|
TC3_LOG(ERROR) << "Could not read lua result.";
|
||
|
return false;
|
||
|
}
|
||
|
return true;
|
||
|
}
|
||
|
|
||
|
} // namespace libtextclassifier3
|