You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
122 lines
5.1 KiB
122 lines
5.1 KiB
/*
|
|
* Copyright (C) 2018 The Android Open Source Project
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include "annotator/datetime/grammar-parser.h"
|
|
|
|
#include <set>
|
|
#include <unordered_set>
|
|
|
|
#include "annotator/datetime/datetime-grounder.h"
|
|
#include "annotator/types.h"
|
|
#include "utils/grammar/analyzer.h"
|
|
#include "utils/grammar/evaluated-derivation.h"
|
|
#include "utils/grammar/parsing/derivation.h"
|
|
|
|
using ::libtextclassifier3::grammar::EvaluatedDerivation;
|
|
using ::libtextclassifier3::grammar::datetime::UngroundedDatetime;
|
|
|
|
namespace libtextclassifier3 {
|
|
|
|
GrammarDatetimeParser::GrammarDatetimeParser(
|
|
const grammar::Analyzer& analyzer,
|
|
const DatetimeGrounder& datetime_grounder,
|
|
const float target_classification_score, const float priority_score)
|
|
: analyzer_(analyzer),
|
|
datetime_grounder_(datetime_grounder),
|
|
target_classification_score_(target_classification_score),
|
|
priority_score_(priority_score) {}
|
|
|
|
StatusOr<std::vector<DatetimeParseResultSpan>> GrammarDatetimeParser::Parse(
|
|
const std::string& input, const int64 reference_time_ms_utc,
|
|
const std::string& reference_timezone, const LocaleList& locale_list,
|
|
ModeFlag mode, AnnotationUsecase annotation_usecase,
|
|
bool anchor_start_end) const {
|
|
return Parse(UTF8ToUnicodeText(input, /*do_copy=*/false),
|
|
reference_time_ms_utc, reference_timezone, locale_list, mode,
|
|
annotation_usecase, anchor_start_end);
|
|
}
|
|
|
|
StatusOr<std::vector<DatetimeParseResultSpan>> GrammarDatetimeParser::Parse(
|
|
const UnicodeText& input, const int64 reference_time_ms_utc,
|
|
const std::string& reference_timezone, const LocaleList& locale_list,
|
|
ModeFlag mode, AnnotationUsecase annotation_usecase,
|
|
bool anchor_start_end) const {
|
|
std::vector<DatetimeParseResultSpan> results;
|
|
UnsafeArena arena(/*block_size=*/16 << 10);
|
|
std::vector<Locale> locales = locale_list.GetLocales();
|
|
// If the locale list is empty then datetime regex expression will still
|
|
// execute but in grammar based parser the rules are associated with local
|
|
// and engine will not run if the locale list is empty. In an unlikely
|
|
// scenario when locale is not mentioned fallback to en-*.
|
|
if (locales.empty()) {
|
|
locales.emplace_back(Locale::FromBCP47("en"));
|
|
}
|
|
TC3_ASSIGN_OR_RETURN(
|
|
const std::vector<EvaluatedDerivation> evaluated_derivations,
|
|
analyzer_.Parse(input, locales, &arena,
|
|
/*deduplicate_derivations=*/false));
|
|
|
|
std::vector<EvaluatedDerivation> valid_evaluated_derivations;
|
|
for (const EvaluatedDerivation& evaluated_derivation :
|
|
evaluated_derivations) {
|
|
if (evaluated_derivation.value) {
|
|
if (evaluated_derivation.value->Has<flatbuffers::Table>()) {
|
|
const UngroundedDatetime* ungrounded_datetime =
|
|
evaluated_derivation.value->Table<UngroundedDatetime>();
|
|
if (datetime_grounder_.IsValidUngroundedDatetime(ungrounded_datetime)) {
|
|
valid_evaluated_derivations.emplace_back(evaluated_derivation);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
valid_evaluated_derivations =
|
|
grammar::DeduplicateDerivations(valid_evaluated_derivations);
|
|
for (const EvaluatedDerivation& evaluated_derivation :
|
|
valid_evaluated_derivations) {
|
|
if (evaluated_derivation.value) {
|
|
if (evaluated_derivation.value->Has<flatbuffers::Table>()) {
|
|
const UngroundedDatetime* ungrounded_datetime =
|
|
evaluated_derivation.value->Table<UngroundedDatetime>();
|
|
if ((ungrounded_datetime->annotation_usecases() &
|
|
(1 << annotation_usecase)) == 0) {
|
|
continue;
|
|
}
|
|
const StatusOr<std::vector<DatetimeParseResult>>&
|
|
datetime_parse_results = datetime_grounder_.Ground(
|
|
reference_time_ms_utc, reference_timezone,
|
|
locale_list.GetReferenceLocale(), ungrounded_datetime);
|
|
TC3_ASSIGN_OR_RETURN(
|
|
const std::vector<DatetimeParseResult>& parse_datetime,
|
|
datetime_parse_results);
|
|
DatetimeParseResultSpan datetime_parse_result_span;
|
|
datetime_parse_result_span.target_classification_score =
|
|
target_classification_score_;
|
|
datetime_parse_result_span.priority_score = priority_score_;
|
|
datetime_parse_result_span.data.reserve(parse_datetime.size());
|
|
datetime_parse_result_span.data.insert(
|
|
datetime_parse_result_span.data.end(), parse_datetime.begin(),
|
|
parse_datetime.end());
|
|
datetime_parse_result_span.span =
|
|
evaluated_derivation.parse_tree->codepoint_span;
|
|
|
|
results.emplace_back(datetime_parse_result_span);
|
|
}
|
|
}
|
|
}
|
|
return results;
|
|
}
|
|
} // namespace libtextclassifier3
|