You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
131 lines
4.9 KiB
131 lines
4.9 KiB
/*
|
|
* Copyright (C) 2018 The Android Open Source Project
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include "annotator/vocab/vocab-annotator-impl.h"
|
|
|
|
#include "annotator/feature-processor.h"
|
|
#include "annotator/model_generated.h"
|
|
#include "utils/base/logging.h"
|
|
#include "utils/optional.h"
|
|
#include "utils/strings/numbers.h"
|
|
|
|
namespace libtextclassifier3 {
|
|
|
|
VocabAnnotator::VocabAnnotator(
|
|
std::unique_ptr<VocabLevelTable> vocab_level_table,
|
|
const std::vector<Locale>& triggering_locales,
|
|
const FeatureProcessor& feature_processor, const UniLib& unilib,
|
|
const VocabModel* model)
|
|
: vocab_level_table_(std::move(vocab_level_table)),
|
|
triggering_locales_(triggering_locales),
|
|
feature_processor_(feature_processor),
|
|
unilib_(unilib),
|
|
model_(model) {}
|
|
|
|
std::unique_ptr<VocabAnnotator> VocabAnnotator::Create(
|
|
const VocabModel* model, const FeatureProcessor& feature_processor,
|
|
const UniLib& unilib) {
|
|
std::unique_ptr<VocabLevelTable> vocab_lebel_table =
|
|
VocabLevelTable::Create(model);
|
|
if (vocab_lebel_table == nullptr) {
|
|
TC3_LOG(ERROR) << "Failed to create vocab level table.";
|
|
return nullptr;
|
|
}
|
|
std::vector<Locale> triggering_locales;
|
|
if (model->triggering_locales() &&
|
|
!ParseLocales(model->triggering_locales()->c_str(),
|
|
&triggering_locales)) {
|
|
TC3_LOG(ERROR) << "Could not parse model supported locales.";
|
|
return nullptr;
|
|
}
|
|
|
|
return std::unique_ptr<VocabAnnotator>(
|
|
new VocabAnnotator(std::move(vocab_lebel_table), triggering_locales,
|
|
feature_processor, unilib, model));
|
|
}
|
|
|
|
bool VocabAnnotator::Annotate(
|
|
const UnicodeText& context,
|
|
const std::vector<Locale> detected_text_language_tags,
|
|
bool trigger_on_beginner_words, std::vector<AnnotatedSpan>* results) const {
|
|
std::vector<Token> tokens = feature_processor_.Tokenize(context);
|
|
for (const Token& token : tokens) {
|
|
ClassificationResult classification_result;
|
|
CodepointSpan stripped_span;
|
|
bool found = ClassifyTextInternal(
|
|
context, {token.start, token.end}, detected_text_language_tags,
|
|
trigger_on_beginner_words, &classification_result, &stripped_span);
|
|
if (found) {
|
|
results->push_back(AnnotatedSpan{stripped_span, {classification_result}});
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
bool VocabAnnotator::ClassifyText(
|
|
const UnicodeText& context, CodepointSpan click,
|
|
const std::vector<Locale> detected_text_language_tags,
|
|
bool trigger_on_beginner_words, ClassificationResult* result) const {
|
|
CodepointSpan stripped_span;
|
|
return ClassifyTextInternal(context, click, detected_text_language_tags,
|
|
trigger_on_beginner_words, result,
|
|
&stripped_span);
|
|
}
|
|
|
|
bool VocabAnnotator::ClassifyTextInternal(
|
|
const UnicodeText& context, const CodepointSpan click,
|
|
const std::vector<Locale> detected_text_language_tags,
|
|
bool trigger_on_beginner_words, ClassificationResult* classification_result,
|
|
CodepointSpan* classified_span) const {
|
|
if (vocab_level_table_ == nullptr) {
|
|
return false;
|
|
}
|
|
|
|
if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
|
|
triggering_locales_,
|
|
/*default_value=*/false)) {
|
|
return false;
|
|
}
|
|
const CodepointSpan stripped_span =
|
|
feature_processor_.StripBoundaryCodepoints(context,
|
|
{click.first, click.second});
|
|
const UnicodeText stripped_token = UnicodeText::Substring(
|
|
context, stripped_span.first, stripped_span.second, /*do_copy=*/false);
|
|
const std::string lower_token =
|
|
unilib_.ToLowerText(stripped_token).ToUTF8String();
|
|
|
|
const Optional<LookupResult> result = vocab_level_table_->Lookup(lower_token);
|
|
if (!result.has_value()) {
|
|
return false;
|
|
}
|
|
if (result.value().do_not_trigger_in_upper_case &&
|
|
unilib_.IsUpper(*stripped_token.begin())) {
|
|
TC3_VLOG(INFO) << "Not trigger define: proper noun in upper case.";
|
|
return false;
|
|
}
|
|
if (result.value().beginner_level && !trigger_on_beginner_words) {
|
|
TC3_VLOG(INFO) << "Not trigger define: for beginner only.";
|
|
return false;
|
|
}
|
|
*classification_result =
|
|
ClassificationResult("dictionary", model_->target_classification_score(),
|
|
model_->priority_score());
|
|
*classified_span = stripped_span;
|
|
|
|
return true;
|
|
}
|
|
} // namespace libtextclassifier3
|