You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
148 lines
6.6 KiB
148 lines
6.6 KiB
/*
|
|
* Copyright (C) 2018 The Android Open Source Project
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_POD_NER_UTILS_H_
|
|
#define LIBTEXTCLASSIFIER_ANNOTATOR_POD_NER_UTILS_H_
|
|
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
#include "annotator/model_generated.h"
|
|
#include "annotator/types.h"
|
|
#include "absl/strings/string_view.h"
|
|
|
|
namespace libtextclassifier3 {
|
|
// Converts saft labels like /saft/person to collection name 'person'.
|
|
std::string SaftLabelToCollection(absl::string_view saft_label);
|
|
|
|
// A half-open range [begin, end) of wordpiece indices: the beginning index is
// inclusive, the end index is exclusive.
struct WordpieceSpan {
  // Creates the empty span [0, 0).
  WordpieceSpan() : begin(0), end(0) {}
  WordpieceSpan(int begin, int end) : begin(begin), end(end) {}

  int begin;
  int end;

  // Two spans compare equal when both of their bounds match.
  bool operator==(const WordpieceSpan &other) const {
    return begin == other.begin && end == other.end;
  }

  // Returns the number of wordpieces covered by the span (end - begin).
  // Const-qualified so it can be called on const spans and through
  // `const WordpieceSpan &` parameters.
  int length() const { return end - begin; }
};
|
|
|
|
namespace internal {
|
|
// Finds the wordpiece window arond the given span_of_interest. If the number
|
|
// of wordpieces in this window is smaller than max_num_wordpieces_in_window
|
|
// it is expanded around the span of interest.
|
|
WordpieceSpan FindWordpiecesWindowAroundSpan(
|
|
const CodepointSpan &span_of_interest, const std::vector<Token> &tokens,
|
|
const std::vector<int32_t> &word_starts, int num_wordpieces,
|
|
int max_num_wordpieces_in_window);
|
|
// Expands the given wordpiece window around the given window to the be
|
|
// maximal possible while making sure it includes only full tokens.
|
|
WordpieceSpan ExpandWindowAndAlign(int max_num_wordpieces_in_window,
|
|
int num_wordpieces,
|
|
WordpieceSpan wordpiece_span_to_expand);
|
|
// Returns the index of the last token which ends before wordpiece_end.
int FindLastFullTokenIndex(const std::vector<int32_t> &word_starts,
                           int num_wordpieces, int wordpiece_end);
|
|
// Returns the index of the token which includes first_wordpiece_index.
int FindFirstFullTokenIndex(const std::vector<int32_t> &word_starts,
                            int first_wordpiece_index);
|
|
// Given wordpiece_span, and max_num_wordpieces, finds:
|
|
// 1. The first token which includes wordpiece_span.begin.
|
|
// 2. The length of tokens sequence which starts from this token and:
|
|
// a. Its last token's last wordpiece index ends before wordpiece_span.end.
|
|
// b. Its overall number of wordpieces is at most max_num_wordpieces.
|
|
// Returns the updated wordpiece_span: begin and end wordpieces of this token
|
|
// sequence.
|
|
WordpieceSpan FindFullTokensSpanInWindow(
|
|
const std::vector<int32_t> &word_starts,
|
|
const WordpieceSpan &wordpiece_span, int max_num_wordpieces,
|
|
int num_wordpieces, int *first_token_index, int *num_tokens);
|
|
|
|
} // namespace internal
|
|
// Converts sequence of IOB tags to AnnotatedSpans. Ignores illegal sequences.
|
|
// Setting label_filter can also help ignore certain label tags like "NAM" or
|
|
// "NOM".
|
|
// The inside tag can be ignored when setting relaxed_inside_label_matching,
|
|
// e.g. B-NAM-location, I-NAM-other, E-NAM-location would be considered a valid
|
|
// sequence.
|
|
// The label category matching can be ignored when setting
|
|
// relaxed_label_category_matching. The matching will only operate at the entity
|
|
// level, e.g. B-NAM-location, E-NOM-location would be considered a valid
|
|
// sequence.
|
|
bool ConvertTagsToAnnotatedSpans(const VectorSpan<Token> &tokens,
|
|
const std::vector<std::string> &tags,
|
|
const std::vector<std::string> &label_filter,
|
|
bool relaxed_inside_label_matching,
|
|
bool relaxed_label_category_matching,
|
|
float priority_score,
|
|
std::vector<AnnotatedSpan> *results);
|
|
|
|
// Like the previous function but instead of getting the tags as strings
|
|
// the input is PodNerModel_::LabelT along with the collections vector which
|
|
// hold the collection name and priorities. e.g. a tag was "B-NAM-location" and
|
|
// the priority_score was 1.0 it would be Label(BoiseType_BEGIN,
|
|
// MentionType_NAM, 1) and collections={{"xxx", 1., 1.},
|
|
// {"location", 1., 1.}, {"yyy", 1., 1.}, ...}.
|
|
bool ConvertTagsToAnnotatedSpans(
|
|
const VectorSpan<Token> &tokens,
|
|
const std::vector<PodNerModel_::LabelT> &labels,
|
|
const std::vector<PodNerModel_::CollectionT> &collections,
|
|
const std::vector<PodNerModel_::Label_::MentionType> &mention_filter,
|
|
bool relaxed_inside_label_matching, bool relaxed_mention_type_matching,
|
|
std::vector<AnnotatedSpan> *results);
|
|
|
|
// Merge two overlaping sequences of labels, the result is placed into the left
|
|
// sequence. In the overlapping part takes the labels from the left sequence on
|
|
// the first half and from the right on the second half.
|
|
bool MergeLabelsIntoLeftSequence(
|
|
const std::vector<PodNerModel_::LabelT> &labels_right,
|
|
int index_first_right_tag_in_left,
|
|
std::vector<PodNerModel_::LabelT> *labels_left);
|
|
|
|
// This class is used to slide over {wordpiece_indices, token_starts, tokens} in
|
|
// windows of at most max_num_wordpieces while assuring that each window
|
|
// contains only full tokens.
|
|
class WindowGenerator {
|
|
public:
|
|
WindowGenerator(const std::vector<int32_t> &wordpiece_indices,
|
|
const std::vector<int32_t> &token_starts,
|
|
const std::vector<Token> &tokens, int max_num_wordpieces,
|
|
int sliding_window_overlap,
|
|
const CodepointSpan &span_of_interest);
|
|
|
|
bool Next(VectorSpan<int32_t> *cur_wordpiece_indices,
|
|
VectorSpan<int32_t> *cur_token_starts,
|
|
VectorSpan<Token> *cur_tokens);
|
|
|
|
bool Done() const {
|
|
return previous_wordpiece_span_.end >= entire_wordpiece_span_.end;
|
|
}
|
|
|
|
private:
|
|
const std::vector<int32_t> *wordpiece_indices_;
|
|
const std::vector<int32_t> *token_starts_;
|
|
const std::vector<Token> *tokens_;
|
|
int max_num_effective_wordpieces_;
|
|
int sliding_window_num_wordpieces_overlap_;
|
|
WordpieceSpan entire_wordpiece_span_;
|
|
WordpieceSpan next_wordpiece_span_;
|
|
WordpieceSpan previous_wordpiece_span_;
|
|
};
|
|
|
|
} // namespace libtextclassifier3
|
|
|
|
#endif // LIBTEXTCLASSIFIER_ANNOTATOR_POD_NER_UTILS_H_
|