You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

160 lines
5.7 KiB

/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "actions/utils.h"

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "annotator/collections.h"
#include "utils/base/logging.h"
#include "utils/normalization.h"
#include "utils/strings/stringpiece.h"
namespace libtextclassifier3 {
// Collection name given to a datetime annotation that has a time component
// but no date component (see ConvertDatetimeToTime).
//
// The string is allocated with `new` and never deleted; presumably this is
// deliberate so the object has no destructor to run during static
// destruction. The reference stays valid for the life of the process.
const std::string& kTimeAnnotation = *new std::string("time");
// Applies an action suggestion spec to `suggestion`.
//
// When `action` is non-null:
//   - `score` and `priority_score` are always copied.
//   - `type` and `response_text` are copied only if the spec sets them.
//   - Serialized and/or inline entity data from the spec are merged into
//     `entity_data`. `entity_data` must be non-null whenever the spec
//     carries either kind (enforced with TC3_CHECK_NE).
// Regardless of `action`, if `entity_data` is non-null and has explicitly
// set fields, it is serialized into `suggestion->serialized_entity_data`.
void FillSuggestionFromSpec(const ActionSuggestionSpec* action,
                            MutableFlatbuffer* entity_data,
                            ActionSuggestion* suggestion) {
  if (action != nullptr) {
    suggestion->score = action->score();
    suggestion->priority_score = action->priority_score();

    if (const auto* spec_type = action->type()) {
      suggestion->type = spec_type->str();
    }
    if (const auto* spec_response = action->response_text()) {
      suggestion->response_text = spec_response->str();
    }

    // Entity data may come pre-serialized and/or as an inline flatbuffer
    // table; both are merged into the caller-provided buffer.
    if (const auto* serialized = action->serialized_entity_data()) {
      TC3_CHECK_NE(entity_data, nullptr);
      entity_data->MergeFromSerializedFlatbuffer(
          StringPiece(serialized->data(), serialized->size()));
    }
    if (const auto* inline_entity_data = action->entity_data()) {
      TC3_CHECK_NE(entity_data, nullptr);
      entity_data->MergeFrom(
          reinterpret_cast<const flatbuffers::Table*>(inline_entity_data));
    }
  }

  const bool should_serialize =
      entity_data != nullptr && entity_data->HasExplicitlySetFields();
  if (should_serialize) {
    suggestion->serialized_entity_data = entity_data->Serialize();
  }
}
// Appends a smart-reply suggestion for a capturing group that carries a
// `text_reply` spec; does nothing if the group has no such spec.
//
// The suggestion starts with `match_text` (as UTF-8) as its response text and
// `smart_reply_action_type` as its type. FillSuggestionFromSpec then applies
// the `text_reply` spec, which overwrites type/response text if the spec sets
// them. When `entity_data_builder` is non-null, a fresh entity-data root is
// created for the suggestion; otherwise no entity data is attached.
void SuggestTextRepliesFromCapturingMatch(
    const MutableFlatbufferBuilder* entity_data_builder,
    const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
    const UnicodeText& match_text, const std::string& smart_reply_action_type,
    std::vector<ActionSuggestion>* actions) {
  if (group->text_reply() == nullptr) {
    return;
  }

  ActionSuggestion suggestion;
  suggestion.response_text = match_text.ToUTF8String();
  suggestion.type = smart_reply_action_type;
  std::unique_ptr<MutableFlatbuffer> entity_data =
      entity_data_builder != nullptr ? entity_data_builder->NewRoot()
                                     : nullptr;
  FillSuggestionFromSpec(group->text_reply(), entity_data.get(), &suggestion);

  // `suggestion` is dead after this point; move it to avoid copying its
  // strings and serialized entity data.
  actions->push_back(std::move(suggestion));
}
// Convenience overload for UTF-8 input: wraps `match_text` in a UnicodeText
// without copying its bytes (do_copy=false), then delegates to the
// UnicodeText overload, which applies the group's normalization options.
// NOTE(review): the intermediate UnicodeText aliases `match_text`'s storage;
// confirm the returned value owns its data (i.e. UnicodeText copies are deep)
// before letting it outlive `match_text`.
UnicodeText NormalizeMatchText(
    const UniLib& unilib,
    const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
    StringPiece match_text) {
  return NormalizeMatchText(
      unilib, group,
      UTF8ToUnicodeText(match_text, /*do_copy=*/false));
}
// Normalizes `match_text` with the capturing group's normalization options.
// If the group defines no options, `match_text` is returned unchanged.
UnicodeText NormalizeMatchText(
    const UniLib& unilib,
    const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
    const UnicodeText match_text) {
  const auto* options = group->normalization_options();
  if (options != nullptr) {
    return NormalizeText(unilib, options, match_text);
  }
  return match_text;
}
// Fills `annotation` from a capturing-group match.
//
// Returns false, leaving `annotation` untouched, when the group declares
// neither an annotation name nor an annotation type. Otherwise:
//   - the annotation's span, message index and text are set from `span`,
//     `message_index` and `match_text`;
//   - `name` and `entity.collection` are set from whichever of the group's
//     annotation name/type are present;
// and the function returns true.
bool FillAnnotationFromCapturingMatch(
    const CodepointSpan& span,
    const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
    const int message_index, StringPiece match_text,
    ActionSuggestionAnnotation* annotation) {
  const auto* name_spec = group->annotation_name();
  const auto* type_spec = group->annotation_type();
  const bool declares_annotation = name_spec != nullptr || type_spec != nullptr;
  if (!declares_annotation) {
    return false;
  }

  auto& target_span = annotation->span;
  target_span.span = span;
  target_span.message_index = message_index;
  target_span.text = match_text.ToString();

  if (name_spec != nullptr) {
    annotation->name = name_spec->str();
  }
  if (type_spec != nullptr) {
    annotation->entity.collection = type_spec->str();
  }
  return true;
}
// Merges entity data derived from a capturing-group match into `buffer`.
//
// Two optional sources are applied in order:
//   1. If the group names an entity field, `match_text` is parsed and stored
//      into that field.
//   2. If the group carries a static entity-data table, it is merged in.
// Logs an error and returns false as soon as either step fails; returns true
// otherwise, including when the group specifies neither source.
bool MergeEntityDataFromCapturingMatch(
    const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
    StringPiece match_text, MutableFlatbuffer* buffer) {
  const auto* field_path = group->entity_field();
  if (field_path != nullptr &&
      !buffer->ParseAndSet(field_path, match_text.ToString())) {
    TC3_LOG(ERROR) << "Could not set entity data from rule capturing group.";
    return false;
  }

  const auto* static_entity_data = group->entity_data();
  if (static_entity_data != nullptr &&
      !buffer->MergeFrom(
          reinterpret_cast<const flatbuffers::Table*>(static_entity_data))) {
    TC3_LOG(ERROR) << "Could not set entity data for capturing match.";
    return false;
  }

  return true;
}
void ConvertDatetimeToTime(std::vector<AnnotatedSpan>* annotations) {
for (int i = 0; i < annotations->size(); i++) {
ClassificationResult* classification =
&(*annotations)[i].classification.front();
// Specialize datetime annotation to time annotation if no date
// component is present.
if (classification->collection == Collections::DateTime() &&
classification->datetime_parse_result.IsSet()) {
bool has_only_time = true;
for (const DatetimeComponent& component :
classification->datetime_parse_result.datetime_components) {
if (component.component_type !=
DatetimeComponent::ComponentType::UNSPECIFIED &&
component.component_type < DatetimeComponent::ComponentType::HOUR) {
has_only_time = false;
break;
}
}
if (has_only_time) {
classification->collection = kTimeAnnotation;
}
}
}
}
} // namespace libtextclassifier3