You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
362 lines
12 KiB
362 lines
12 KiB
/*
|
|
* Copyright (C) 2018 The Android Open Source Project
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include "actions/ranker.h"
|
|
|
|
#include <algorithm>
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
|
|
|
|
#if !defined(TC3_DISABLE_LUA)
|
|
#include "actions/lua-ranker.h"
|
|
#endif
|
|
#include "actions/zlib-utils.h"
|
|
#include "annotator/types.h"
|
|
#include "utils/base/logging.h"
|
|
#if !defined(TC3_DISABLE_LUA)
|
|
#include "utils/lua-utils.h"
|
|
#endif
|
|
|
|
namespace libtextclassifier3 {
|
|
namespace {
|
|
|
|
void SortByScoreAndType(std::vector<ActionSuggestion>* actions) {
|
|
std::sort(actions->begin(), actions->end(),
|
|
[](const ActionSuggestion& a, const ActionSuggestion& b) {
|
|
return a.score > b.score ||
|
|
(a.score >= b.score && a.type < b.type);
|
|
});
|
|
}
|
|
|
|
// Generic three-way comparison built on operator< and operator>.
// Returns -1 if `left` is smaller, 1 if `left` is larger and 0 otherwise.
template <typename T>
int Compare(const T& left, const T& right) {
  return (left < right) ? -1 : ((left > right) ? 1 : 0);
}
|
|
|
|
template <>
|
|
int Compare(const std::string& left, const std::string& right) {
|
|
return left.compare(right);
|
|
}
|
|
|
|
template <>
|
|
int Compare(const MessageTextSpan& span, const MessageTextSpan& other) {
|
|
if (const int value = Compare(span.message_index, other.message_index)) {
|
|
return value;
|
|
}
|
|
if (const int value = Compare(span.span.first, other.span.first)) {
|
|
return value;
|
|
}
|
|
if (const int value = Compare(span.span.second, other.span.second)) {
|
|
return value;
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
bool IsSameSpan(const MessageTextSpan& span, const MessageTextSpan& other) {
|
|
return Compare(span, other) == 0;
|
|
}
|
|
|
|
bool TextSpansIntersect(const MessageTextSpan& span,
|
|
const MessageTextSpan& other) {
|
|
return span.message_index == other.message_index &&
|
|
SpansOverlap(span.span, other.span);
|
|
}
|
|
|
|
template <>
|
|
int Compare(const ActionSuggestionAnnotation& annotation,
|
|
const ActionSuggestionAnnotation& other) {
|
|
if (const int value = Compare(annotation.span, other.span)) {
|
|
return value;
|
|
}
|
|
if (const int value = Compare(annotation.name, other.name)) {
|
|
return value;
|
|
}
|
|
if (const int value =
|
|
Compare(annotation.entity.collection, other.entity.collection)) {
|
|
return value;
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
// Checks whether two annotations can be considered equivalent.
|
|
bool IsEquivalentActionAnnotation(const ActionSuggestionAnnotation& annotation,
|
|
const ActionSuggestionAnnotation& other) {
|
|
return Compare(annotation, other) == 0;
|
|
}
|
|
|
|
// Compares actions based on annotations.
|
|
int CompareAnnotationsOnly(const ActionSuggestion& action,
|
|
const ActionSuggestion& other) {
|
|
if (const int value =
|
|
Compare(action.annotations.size(), other.annotations.size())) {
|
|
return value;
|
|
}
|
|
for (int i = 0; i < action.annotations.size(); i++) {
|
|
if (const int value =
|
|
Compare(action.annotations[i], other.annotations[i])) {
|
|
return value;
|
|
}
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
// Checks whether two actions have the same annotations.
|
|
bool HaveEquivalentAnnotations(const ActionSuggestion& action,
|
|
const ActionSuggestion& other) {
|
|
return CompareAnnotationsOnly(action, other) == 0;
|
|
}
|
|
|
|
template <>
|
|
int Compare(const ActionSuggestion& action, const ActionSuggestion& other) {
|
|
if (const int value = Compare(action.type, other.type)) {
|
|
return value;
|
|
}
|
|
if (const int value = Compare(action.response_text, other.response_text)) {
|
|
return value;
|
|
}
|
|
if (const int value = Compare(action.serialized_entity_data,
|
|
other.serialized_entity_data)) {
|
|
return value;
|
|
}
|
|
return CompareAnnotationsOnly(action, other);
|
|
}
|
|
|
|
// Checks whether two action suggestions can be considered equivalent.
|
|
bool IsEquivalentActionSuggestion(const ActionSuggestion& action,
|
|
const ActionSuggestion& other) {
|
|
return Compare(action, other) == 0;
|
|
}
|
|
|
|
// Checks whether any action is equivalent to the given one.
|
|
bool IsAnyActionEquivalent(const ActionSuggestion& action,
|
|
const std::vector<ActionSuggestion>& actions) {
|
|
for (const ActionSuggestion& other : actions) {
|
|
if (IsEquivalentActionSuggestion(action, other)) {
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool IsConflicting(const ActionSuggestionAnnotation& annotation,
|
|
const ActionSuggestionAnnotation& other) {
|
|
// Two annotations are conflicting if they are different but refer to
|
|
// overlapping spans in the conversation.
|
|
return (!IsEquivalentActionAnnotation(annotation, other) &&
|
|
TextSpansIntersect(annotation.span, other.span));
|
|
}
|
|
|
|
// Checks whether two action suggestions can be considered conflicting.
|
|
bool IsConflictingActionSuggestion(const ActionSuggestion& action,
|
|
const ActionSuggestion& other) {
|
|
// Actions are considered conflicting, iff they refer to the same text span,
|
|
// but were not generated from the same annotation.
|
|
if (action.annotations.empty() || other.annotations.empty()) {
|
|
return false;
|
|
}
|
|
for (const ActionSuggestionAnnotation& annotation : action.annotations) {
|
|
for (const ActionSuggestionAnnotation& other_annotation :
|
|
other.annotations) {
|
|
if (IsConflicting(annotation, other_annotation)) {
|
|
return true;
|
|
}
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
// Checks whether any action is considered conflicting with the given one.
|
|
bool IsAnyActionConflicting(const ActionSuggestion& action,
|
|
const std::vector<ActionSuggestion>& actions) {
|
|
for (const ActionSuggestion& other : actions) {
|
|
if (IsConflictingActionSuggestion(action, other)) {
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
} // namespace
|
|
|
|
std::unique_ptr<ActionsSuggestionsRanker>
|
|
ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
|
|
const RankingOptions* options, ZlibDecompressor* decompressor,
|
|
const std::string& smart_reply_action_type) {
|
|
auto ranker = std::unique_ptr<ActionsSuggestionsRanker>(
|
|
new ActionsSuggestionsRanker(options, smart_reply_action_type));
|
|
|
|
if (!ranker->InitializeAndValidate(decompressor)) {
|
|
TC3_LOG(ERROR) << "Could not initialize action ranker.";
|
|
return nullptr;
|
|
}
|
|
|
|
return ranker;
|
|
}
|
|
|
|
// Validates the ranking options and, unless Lua support is compiled out,
// precompiles the Lua ranking script into `lua_bytecode_`.
// Returns false if no options were provided or if a non-empty script fails to
// compile; returns true otherwise. A script that cannot be obtained, or that
// is empty, is skipped without error.
bool ActionsSuggestionsRanker::InitializeAndValidate(
    ZlibDecompressor* decompressor) {
  if (options_ == nullptr) {
    TC3_LOG(ERROR) << "No ranking options specified.";
    return false;
  }

#if !defined(TC3_DISABLE_LUA)
  std::string lua_ranking_script;
  const bool has_script =
      GetUncompressedString(options_->lua_ranking_script(),
                            options_->compressed_lua_ranking_script(),
                            decompressor, &lua_ranking_script) &&
      !lua_ranking_script.empty();
  if (has_script && !Compile(lua_ranking_script, &lua_bytecode_)) {
    TC3_LOG(ERROR) << "Could not precompile lua ranking snippet.";
    return false;
  }
#endif

  return true;
}
|
|
|
|
bool ActionsSuggestionsRanker::RankActions(
|
|
const Conversation& conversation, ActionsSuggestionsResponse* response,
|
|
const reflection::Schema* entity_data_schema,
|
|
const reflection::Schema* annotations_entity_data_schema) const {
|
|
if (options_->deduplicate_suggestions() ||
|
|
options_->deduplicate_suggestions_by_span()) {
|
|
// First order suggestions by priority score for deduplication.
|
|
std::sort(
|
|
response->actions.begin(), response->actions.end(),
|
|
[](const ActionSuggestion& a, const ActionSuggestion& b) {
|
|
return a.priority_score > b.priority_score ||
|
|
(a.priority_score >= b.priority_score && a.score > b.score);
|
|
});
|
|
|
|
// Deduplicate, keeping the higher score actions.
|
|
if (options_->deduplicate_suggestions()) {
|
|
std::vector<ActionSuggestion> deduplicated_actions;
|
|
for (const ActionSuggestion& candidate : response->actions) {
|
|
// Check whether we already have an equivalent action.
|
|
if (!IsAnyActionEquivalent(candidate, deduplicated_actions)) {
|
|
deduplicated_actions.push_back(std::move(candidate));
|
|
}
|
|
}
|
|
response->actions = std::move(deduplicated_actions);
|
|
}
|
|
|
|
// Resolve conflicts between conflicting actions referring to the same
|
|
// text span.
|
|
if (options_->deduplicate_suggestions_by_span()) {
|
|
std::vector<ActionSuggestion> deduplicated_actions;
|
|
for (const ActionSuggestion& candidate : response->actions) {
|
|
// Check whether we already have a conflicting action.
|
|
if (!IsAnyActionConflicting(candidate, deduplicated_actions)) {
|
|
deduplicated_actions.push_back(std::move(candidate));
|
|
}
|
|
}
|
|
response->actions = std::move(deduplicated_actions);
|
|
}
|
|
}
|
|
|
|
// Suppress smart replies if actions are present.
|
|
if (options_->suppress_smart_replies_with_actions()) {
|
|
std::vector<ActionSuggestion> non_smart_reply_actions;
|
|
for (const ActionSuggestion& action : response->actions) {
|
|
if (action.type != smart_reply_action_type_) {
|
|
non_smart_reply_actions.push_back(std::move(action));
|
|
}
|
|
}
|
|
response->actions = std::move(non_smart_reply_actions);
|
|
}
|
|
|
|
// Group by annotation if specified.
|
|
if (options_->group_by_annotations()) {
|
|
auto group_id = std::map<
|
|
ActionSuggestion, int,
|
|
std::function<bool(const ActionSuggestion&, const ActionSuggestion&)>>{
|
|
[](const ActionSuggestion& action, const ActionSuggestion& other) {
|
|
return (CompareAnnotationsOnly(action, other) < 0);
|
|
}};
|
|
typedef std::vector<ActionSuggestion> ActionSuggestionGroup;
|
|
std::vector<ActionSuggestionGroup> groups;
|
|
|
|
// Group actions by the annotation set they are based of.
|
|
for (const ActionSuggestion& action : response->actions) {
|
|
// Treat actions with no annotations idependently.
|
|
if (action.annotations.empty()) {
|
|
groups.emplace_back(1, action);
|
|
continue;
|
|
}
|
|
|
|
auto it = group_id.find(action);
|
|
if (it != group_id.end()) {
|
|
groups[it->second].push_back(action);
|
|
} else {
|
|
group_id[action] = groups.size();
|
|
groups.emplace_back(1, action);
|
|
}
|
|
}
|
|
|
|
// Sort within each group by score.
|
|
for (std::vector<ActionSuggestion>& group : groups) {
|
|
SortByScoreAndType(&group);
|
|
}
|
|
|
|
// Sort groups by maximum score.
|
|
std::sort(groups.begin(), groups.end(),
|
|
[](const std::vector<ActionSuggestion>& a,
|
|
const std::vector<ActionSuggestion>& b) {
|
|
return a.begin()->score > b.begin()->score ||
|
|
(a.begin()->score >= b.begin()->score &&
|
|
a.begin()->type < b.begin()->type);
|
|
});
|
|
|
|
// Flatten result.
|
|
const size_t num_actions = response->actions.size();
|
|
response->actions.clear();
|
|
response->actions.reserve(num_actions);
|
|
for (const std::vector<ActionSuggestion>& actions : groups) {
|
|
response->actions.insert(response->actions.end(), actions.begin(),
|
|
actions.end());
|
|
}
|
|
|
|
} else {
|
|
// Order suggestions independently by score.
|
|
SortByScoreAndType(&response->actions);
|
|
}
|
|
|
|
#if !defined(TC3_DISABLE_LUA)
|
|
// Run lua ranking snippet, if provided.
|
|
if (!lua_bytecode_.empty()) {
|
|
auto lua_ranker = ActionsSuggestionsLuaRanker::Create(
|
|
conversation, lua_bytecode_, entity_data_schema,
|
|
annotations_entity_data_schema, response);
|
|
if (lua_ranker == nullptr || !lua_ranker->RankActions()) {
|
|
TC3_LOG(ERROR) << "Could not run lua ranking snippet.";
|
|
return false;
|
|
}
|
|
}
|
|
#endif
|
|
|
|
return true;
|
|
}
|
|
|
|
} // namespace libtextclassifier3
|