You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
437 lines
15 KiB
437 lines
15 KiB
/*
|
|
* Copyright (C) 2018 The Android Open Source Project
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include "annotator/pod_ner/utils.h"
|
|
|
|
#include <algorithm>
|
|
#include <iostream>
|
|
#include <unordered_map>
|
|
|
|
#include "annotator/model_generated.h"
|
|
#include "annotator/types.h"
|
|
#include "utils/base/logging.h"
|
|
#include "absl/strings/str_cat.h"
|
|
#include "absl/strings/str_split.h"
|
|
|
|
namespace libtextclassifier3 {
|
|
namespace {
|
|
|
|
// Reports whether `needle` compares equal to at least one element of
// `haystack`. An empty haystack never matches.
bool StrIsOneOf(const std::string &needle,
                const std::vector<std::string> &haystack) {
  for (const std::string &candidate : haystack) {
    if (candidate == needle) {
      return true;
    }
  }
  return false;
}
|
|
|
|
// Finds the wordpiece span of the tokens in the given span.
|
|
WordpieceSpan CodepointSpanToWordpieceSpan(
|
|
const CodepointSpan &span, const std::vector<Token> &tokens,
|
|
const std::vector<int32_t> &word_starts, int num_wordpieces) {
|
|
int span_first_wordpiece_index = 0;
|
|
int span_last_wordpiece_index = num_wordpieces;
|
|
for (int i = 0; i < tokens.size(); i++) {
|
|
if (tokens[i].start <= span.first && span.first < tokens[i].end) {
|
|
span_first_wordpiece_index = word_starts[i];
|
|
}
|
|
if (tokens[i].start <= span.second && span.second <= tokens[i].end) {
|
|
span_last_wordpiece_index =
|
|
(i + 1) < word_starts.size() ? word_starts[i + 1] : num_wordpieces;
|
|
break;
|
|
}
|
|
}
|
|
return WordpieceSpan(span_first_wordpiece_index, span_last_wordpiece_index);
|
|
}
|
|
|
|
} // namespace
|
|
|
|
// Returns the part of `saft_label` that follows its last '/'. A label without
// any '/' is returned unchanged.
std::string SaftLabelToCollection(absl::string_view saft_label) {
  const auto last_slash = saft_label.rfind('/');
  if (last_slash == absl::string_view::npos) {
    return std::string(saft_label);
  }
  return std::string(saft_label.substr(last_slash + 1));
}
|
|
|
|
namespace internal {
|
|
|
|
// Returns the index of the last token that lies entirely before
// `wordpiece_end`, given the first-wordpiece index of every token in
// `word_starts`.
//
// - Empty `word_starts` yields 0.
// - If the last token starts before `wordpiece_end` and `wordpiece_end` is at
//   or past `num_wordpieces`, the last token index is returned.
// - Otherwise the highest index i > 0 with word_starts[i] <= wordpiece_end is
//   located and i - 1 is returned; 0 if there is no such i.
int FindLastFullTokenIndex(const std::vector<int32_t> &word_starts,
                           int num_wordpieces, int wordpiece_end) {
  if (word_starts.empty()) {
    return 0;
  }

  const int last_token = static_cast<int>(word_starts.size()) - 1;
  const bool window_reaches_end = wordpiece_end >= num_wordpieces;
  if (window_reaches_end && word_starts.back() < wordpiece_end) {
    // Last token.
    return last_token;
  }

  // Walk backwards to the first token (other than token 0) that starts at or
  // before the window end; the token preceding it is the answer.
  int token = last_token;
  while (token > 0 && word_starts[token] > wordpiece_end) {
    --token;
  }
  return token > 0 ? token - 1 : 0;
}
|
|
|
|
// Returns the index of the token associated with `first_wordpiece_index`,
// given the first-wordpiece index of every token in `word_starts`.
//
// The first entry that is >= `first_wordpiece_index` is located. If it equals
// `first_wordpiece_index`, its index is returned; otherwise the index just
// before it (never below 0). When no entry qualifies, the last index is
// returned, or 0 for an empty vector.
int FindFirstFullTokenIndex(const std::vector<int32_t> &word_starts,
                            int first_wordpiece_index) {
  const auto hit = std::find_if(
      word_starts.begin(), word_starts.end(),
      [first_wordpiece_index](int32_t token_start) {
        return token_start >= first_wordpiece_index;
      });
  // Equals word_starts.size() when nothing was found.
  const int hit_index = static_cast<int>(hit - word_starts.begin());

  if (hit != word_starts.end() && *hit == first_wordpiece_index) {
    return hit_index;
  }
  return std::max(0, hit_index - 1);
}
|
|
|
|
// Grows `wordpiece_span_to_expand` into a window of up to
// `max_num_wordpieces_in_window` wordpieces.
//
// A span that is already at least that long is returned unchanged. Otherwise
// the window begins (max - length) / 2 wordpieces before the span (not below
// 0). If the window would then extend past `num_wordpieces`, its beginning is
// moved to num_wordpieces - max (not below 0). The window end is the beginning
// plus the maximum size, capped at `num_wordpieces`.
WordpieceSpan ExpandWindowAndAlign(int max_num_wordpieces_in_window,
                                   int num_wordpieces,
                                   WordpieceSpan wordpiece_span_to_expand) {
  const int span_length = wordpiece_span_to_expand.length();
  if (span_length >= max_num_wordpieces_in_window) {
    return wordpiece_span_to_expand;
  }

  const int left_padding = (max_num_wordpieces_in_window - span_length) / 2;
  int window_begin = std::max(0, wordpiece_span_to_expand.begin - left_padding);

  const bool overruns_text =
      (window_begin + max_num_wordpieces_in_window) > num_wordpieces;
  if (overruns_text) {
    window_begin = std::max(num_wordpieces - max_num_wordpieces_in_window, 0);
  }

  const int window_end =
      std::min(window_begin + max_num_wordpieces_in_window, num_wordpieces);
  return WordpieceSpan(window_begin, window_end);
}
|
|
|
|
// Computes the wordpiece window surrounding `span_of_interest`: the codepoint
// span is first converted to a wordpiece span, which is then expanded to at
// most `max_num_wordpieces_in_window` wordpieces by ExpandWindowAndAlign().
WordpieceSpan FindWordpiecesWindowAroundSpan(
    const CodepointSpan &span_of_interest, const std::vector<Token> &tokens,
    const std::vector<int32_t> &word_starts, int num_wordpieces,
    int max_num_wordpieces_in_window) {
  const WordpieceSpan span_in_wordpieces = CodepointSpanToWordpieceSpan(
      span_of_interest, tokens, word_starts, num_wordpieces);
  return ExpandWindowAndAlign(max_num_wordpieces_in_window, num_wordpieces,
                              span_in_wordpieces);
}
|
|
|
|
// Adjusts `wordpiece_span` so that it begins and ends on token boundaries.
//
// Outputs:
//   *first_token_index: result of FindFirstFullTokenIndex() for the span
//     beginning.
//   *num_tokens: number of tokens from the first to the last token, inclusive.
// Returns the wordpiece span running from the first wordpiece of the first
// token up to the first wordpiece of the token after the last one (or
// `num_wordpieces` when the last token is the final entry of `word_starts`).
//
// NOTE(review): word_starts is indexed without a size check, so it is
// presumably expected to be non-empty - confirm against callers.
WordpieceSpan FindFullTokensSpanInWindow(
    const std::vector<int32_t> &word_starts,
    const WordpieceSpan &wordpiece_span, int max_num_wordpieces,
    int num_wordpieces, int *first_token_index, int *num_tokens) {
  *first_token_index =
      internal::FindFirstFullTokenIndex(word_starts, wordpiece_span.begin);
  const int aligned_begin = word_starts[*first_token_index];

  // Need to update the last index in case the first moved backward.
  const int capped_end =
      std::min(wordpiece_span.end, aligned_begin + max_num_wordpieces);
  const int last_token_index = internal::FindLastFullTokenIndex(
      word_starts, num_wordpieces, capped_end);

  int aligned_end;
  if (last_token_index == (word_starts.size() - 1)) {
    aligned_end = num_wordpieces;
  } else {
    aligned_end = word_starts[last_token_index + 1];
  }

  *num_tokens = last_token_index - *first_token_index + 1;
  return WordpieceSpan(aligned_begin, aligned_end);
}
|
|
|
|
} // namespace internal
|
|
|
|
// Sets up window generation over the given wordpieces and tokens.
//
// Only the addresses of `wordpiece_indices`, `token_starts` and `tokens` are
// stored, so those vectors must stay alive while the generator is used.
// The overall wordpiece range to cover is the window around
// `span_of_interest` computed by FindWordpiecesWindowAroundSpan(); the first
// window to emit starts at its beginning and is at most `max_num_wordpieces`
// long.
WindowGenerator::WindowGenerator(const std::vector<int32_t> &wordpiece_indices,
                                 const std::vector<int32_t> &token_starts,
                                 const std::vector<Token> &tokens,
                                 int max_num_wordpieces,
                                 int sliding_window_overlap,
                                 const CodepointSpan &span_of_interest)
    : wordpiece_indices_(&wordpiece_indices),
      token_starts_(&token_starts),
      tokens_(&tokens),
      max_num_effective_wordpieces_(max_num_wordpieces),
      sliding_window_num_wordpieces_overlap_(sliding_window_overlap) {
  // No window has been emitted yet.
  previous_wordpiece_span_ = WordpieceSpan(-1, -1);

  entire_wordpiece_span_ = internal::FindWordpiecesWindowAroundSpan(
      span_of_interest, tokens, token_starts, wordpiece_indices.size(),
      max_num_wordpieces);

  const int first_window_begin = entire_wordpiece_span_.begin;
  const int first_window_end =
      std::min(first_window_begin + max_num_effective_wordpieces_,
               entire_wordpiece_span_.end);
  next_wordpiece_span_ = WordpieceSpan(first_window_begin, first_window_end);
}
|
|
|
|
// Produces the next window.
//
// Returns false when Done() reports completion, or when the token-aligned
// window fails to move forward relative to the previously emitted one. In the
// latter case `cur_token_starts` and `cur_tokens` have already been
// overwritten while `cur_wordpiece_indices` is left untouched.
//
// On success all three outputs describe the emitted window and the candidate
// span for the following call is prepared: it begins
// `sliding_window_num_wordpieces_overlap_` wordpieces before the end of the
// emitted window, but at least one wordpiece after its beginning.
bool WindowGenerator::Next(VectorSpan<int32_t> *cur_wordpiece_indices,
                           VectorSpan<int32_t> *cur_token_starts,
                           VectorSpan<Token> *cur_tokens) {
  if (Done()) {
    return false;
  }

  // Update the span to cover full tokens.
  int first_token = 0;
  int token_count = 0;
  next_wordpiece_span_ = internal::FindFullTokensSpanInWindow(
      *token_starts_, next_wordpiece_span_, max_num_effective_wordpieces_,
      wordpiece_indices_->size(), &first_token, &token_count);

  const auto token_starts_first = token_starts_->begin() + first_token;
  *cur_token_starts =
      VectorSpan<int32_t>(token_starts_first, token_starts_first + token_count);

  const auto tokens_first = tokens_->begin() + first_token;
  *cur_tokens = VectorSpan<Token>(tokens_first, tokens_first + token_count);

  // Handle the edge case where the tokens are composed of many wordpieces and
  // the window doesn't advance.
  const bool begin_advanced =
      next_wordpiece_span_.begin > previous_wordpiece_span_.begin;
  const bool end_advanced =
      next_wordpiece_span_.end > previous_wordpiece_span_.end;
  if (!(begin_advanced && end_advanced)) {
    return false;
  }
  previous_wordpiece_span_ = next_wordpiece_span_;

  // Prepare the candidate span for the following call.
  const int overlapped_begin =
      previous_wordpiece_span_.end - sliding_window_num_wordpieces_overlap_;
  const int minimal_begin = previous_wordpiece_span_.begin + 1;
  const int following_begin = std::max(overlapped_begin, minimal_begin);
  const int following_end =
      std::min(following_begin + max_num_effective_wordpieces_,
               entire_wordpiece_span_.end);
  next_wordpiece_span_ = WordpieceSpan(following_begin, following_end);

  const auto wordpieces_first =
      wordpiece_indices_->begin() + previous_wordpiece_span_.begin;
  *cur_wordpiece_indices = VectorSpan<int>(
      wordpieces_first, wordpieces_first + previous_wordpiece_span_.length());

  return true;
}
|
|
|
|
bool ConvertTagsToAnnotatedSpans(const VectorSpan<Token> &tokens,
|
|
const std::vector<std::string> &tags,
|
|
const std::vector<std::string> &label_filter,
|
|
bool relaxed_inside_label_matching,
|
|
bool relaxed_label_category_matching,
|
|
float priority_score,
|
|
std::vector<AnnotatedSpan> *results) {
|
|
AnnotatedSpan current_span;
|
|
std::string current_tag_type;
|
|
if (tags.size() > tokens.size()) {
|
|
return false;
|
|
}
|
|
for (int i = 0; i < tags.size(); i++) {
|
|
if (tags[i].empty()) {
|
|
return false;
|
|
}
|
|
|
|
std::vector<absl::string_view> tag_parts = absl::StrSplit(tags[i], '-');
|
|
TC3_CHECK_GT(tag_parts.size(), 0);
|
|
if (tag_parts[0].size() != 1) {
|
|
return false;
|
|
}
|
|
|
|
std::string tag_type = "";
|
|
if (tag_parts.size() > 2) {
|
|
// Skip if the current label doesn't match the filter.
|
|
if (!StrIsOneOf(std::string(tag_parts[1]), label_filter)) {
|
|
current_tag_type = "";
|
|
current_span = {};
|
|
continue;
|
|
}
|
|
|
|
// Relax the matching of the label category if specified.
|
|
tag_type = relaxed_label_category_matching
|
|
? std::string(tag_parts[2])
|
|
: absl::StrCat(tag_parts[1], "-", tag_parts[2]);
|
|
}
|
|
|
|
switch (tag_parts[0][0]) {
|
|
case 'S': {
|
|
if (tag_parts.size() != 3) {
|
|
return false;
|
|
}
|
|
|
|
current_span = {};
|
|
current_tag_type = "";
|
|
results->push_back(AnnotatedSpan{
|
|
{tokens[i].start, tokens[i].end},
|
|
{{/*arg_collection=*/SaftLabelToCollection(tag_parts[2]),
|
|
/*arg_score=*/1.0, priority_score}}});
|
|
break;
|
|
};
|
|
|
|
case 'B': {
|
|
if (tag_parts.size() != 3) {
|
|
return false;
|
|
}
|
|
current_tag_type = tag_type;
|
|
current_span = {};
|
|
current_span.classification.push_back(
|
|
{/*arg_collection=*/SaftLabelToCollection(tag_parts[2]),
|
|
/*arg_score=*/1.0, priority_score});
|
|
current_span.span.first = tokens[i].start;
|
|
break;
|
|
};
|
|
|
|
case 'I': {
|
|
if (tag_parts.size() != 3) {
|
|
return false;
|
|
}
|
|
if (!relaxed_inside_label_matching && current_tag_type != tag_type) {
|
|
current_tag_type = "";
|
|
current_span = {};
|
|
}
|
|
break;
|
|
}
|
|
|
|
case 'E': {
|
|
if (tag_parts.size() != 3) {
|
|
return false;
|
|
}
|
|
if (!current_tag_type.empty() && current_tag_type == tag_type) {
|
|
current_span.span.second = tokens[i].end;
|
|
results->push_back(current_span);
|
|
current_span = {};
|
|
current_tag_type = "";
|
|
}
|
|
break;
|
|
};
|
|
|
|
case 'O': {
|
|
current_tag_type = "";
|
|
current_span = {};
|
|
break;
|
|
};
|
|
|
|
default: {
|
|
TC3_LOG(ERROR) << "Unrecognized tag: " << tags[i];
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
using PodNerModel_::CollectionT;
|
|
using PodNerModel_::LabelT;
|
|
using PodNerModel_::Label_::BoiseType;
|
|
using PodNerModel_::Label_::MentionType;
|
|
|
|
bool ConvertTagsToAnnotatedSpans(const VectorSpan<Token> &tokens,
|
|
const std::vector<LabelT> &labels,
|
|
const std::vector<CollectionT> &collections,
|
|
const std::vector<MentionType> &mention_filter,
|
|
bool relaxed_inside_label_matching,
|
|
bool relaxed_mention_type_matching,
|
|
std::vector<AnnotatedSpan> *results) {
|
|
if (labels.size() > tokens.size()) {
|
|
return false;
|
|
}
|
|
|
|
AnnotatedSpan current_span;
|
|
std::string current_collection_name = "";
|
|
|
|
for (int i = 0; i < labels.size(); i++) {
|
|
const LabelT &label = labels[i];
|
|
|
|
if (label.collection_id < 0 || label.collection_id >= collections.size()) {
|
|
return false;
|
|
}
|
|
|
|
if (std::find(mention_filter.begin(), mention_filter.end(),
|
|
label.mention_type) == mention_filter.end()) {
|
|
// Skip if the current label doesn't match the filter.
|
|
current_span = {};
|
|
current_collection_name = "";
|
|
continue;
|
|
}
|
|
|
|
switch (label.boise_type) {
|
|
case BoiseType::BoiseType_SINGLE: {
|
|
current_span = {};
|
|
current_collection_name = "";
|
|
results->push_back(AnnotatedSpan{
|
|
{tokens[i].start, tokens[i].end},
|
|
{{/*arg_collection=*/collections[label.collection_id].name,
|
|
/*arg_score=*/1.0,
|
|
collections[label.collection_id].single_token_priority_score}}});
|
|
break;
|
|
};
|
|
|
|
case BoiseType::BoiseType_BEGIN: {
|
|
current_span = {};
|
|
current_span.classification.push_back(
|
|
{/*arg_collection=*/collections[label.collection_id].name,
|
|
/*arg_score=*/1.0,
|
|
collections[label.collection_id].multi_token_priority_score});
|
|
current_span.span.first = tokens[i].start;
|
|
current_collection_name = collections[label.collection_id].name;
|
|
break;
|
|
};
|
|
|
|
case BoiseType::BoiseType_INTERMEDIATE: {
|
|
if (current_collection_name.empty() ||
|
|
(!relaxed_mention_type_matching &&
|
|
labels[i - 1].mention_type != label.mention_type) ||
|
|
(!relaxed_inside_label_matching &&
|
|
labels[i - 1].collection_id != label.collection_id)) {
|
|
current_span = {};
|
|
current_collection_name = "";
|
|
}
|
|
break;
|
|
}
|
|
|
|
case BoiseType::BoiseType_END: {
|
|
if (!current_collection_name.empty() &&
|
|
current_collection_name == collections[label.collection_id].name &&
|
|
(relaxed_mention_type_matching ||
|
|
labels[i - 1].mention_type == label.mention_type)) {
|
|
current_span.span.second = tokens[i].end;
|
|
results->push_back(current_span);
|
|
}
|
|
current_span = {};
|
|
current_collection_name = "";
|
|
break;
|
|
};
|
|
|
|
case BoiseType::BoiseType_O: {
|
|
current_span = {};
|
|
current_collection_name = "";
|
|
break;
|
|
};
|
|
|
|
default: {
|
|
TC3_LOG(ERROR) << "Unrecognized tag: " << labels[i].boise_type;
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
bool MergeLabelsIntoLeftSequence(
|
|
const std::vector<PodNerModel_::LabelT> &labels_right,
|
|
int index_first_right_tag_in_left,
|
|
std::vector<PodNerModel_::LabelT> *labels_left) {
|
|
if (index_first_right_tag_in_left > labels_left->size()) {
|
|
return false;
|
|
}
|
|
|
|
int overlaping_from_left =
|
|
(labels_left->size() - index_first_right_tag_in_left) / 2;
|
|
|
|
labels_left->resize(index_first_right_tag_in_left + labels_right.size());
|
|
std::copy(labels_right.begin() + overlaping_from_left, labels_right.end(),
|
|
labels_left->begin() + index_first_right_tag_in_left +
|
|
overlaping_from_left);
|
|
return true;
|
|
}
|
|
|
|
} // namespace libtextclassifier3
|