You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
320 lines
14 KiB
320 lines
14 KiB
4 months ago
|
/*
|
||
|
* Copyright (C) 2018 The Android Open Source Project
|
||
|
*
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
*/
|
||
|
|
||
|
// Feature processing for FFModel (feed-forward SmartSelection model).
|
||
|
|
||
|
#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_
|
||
|
#define LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_
|
||
|
|
||
|
#include <map>
|
||
|
#include <memory>
|
||
|
#include <set>
|
||
|
#include <string>
|
||
|
#include <vector>
|
||
|
|
||
|
#include "annotator/cached-features.h"
|
||
|
#include "annotator/model_generated.h"
|
||
|
#include "annotator/types.h"
|
||
|
#include "utils/base/integral_types.h"
|
||
|
#include "utils/base/logging.h"
|
||
|
#include "utils/token-feature-extractor.h"
|
||
|
#include "utils/tokenizer.h"
|
||
|
#include "utils/utf8/unicodetext.h"
|
||
|
#include "utils/utf8/unilib.h"
|
||
|
|
||
|
namespace libtextclassifier3 {
|
||
|
|
||
|
constexpr int kInvalidLabel = -1;
|
||
|
|
||
|
namespace internal {
|
||
|
|
||
|
Tokenizer BuildTokenizer(const FeatureProcessorOptions* options,
|
||
|
const UniLib* unilib);
|
||
|
|
||
|
TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
|
||
|
const FeatureProcessorOptions* options);
|
||
|
|
||
|
// Splits tokens that contain the selection boundary inside them.
|
||
|
// E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com"
|
||
|
void SplitTokensOnSelectionBoundaries(const CodepointSpan& selection,
|
||
|
std::vector<Token>* tokens);
|
||
|
|
||
|
// Returns the index of token that corresponds to the codepoint span.
|
||
|
int CenterTokenFromClick(const CodepointSpan& span,
|
||
|
const std::vector<Token>& tokens);
|
||
|
|
||
|
// Returns the index of token that corresponds to the middle of the codepoint
|
||
|
// span.
|
||
|
int CenterTokenFromMiddleOfSelection(
|
||
|
const CodepointSpan& span, const std::vector<Token>& selectable_tokens);
|
||
|
|
||
|
// Strips the tokens from the tokens vector that are not used for feature
|
||
|
// extraction because they are out of scope, or pads them so that there is
|
||
|
// enough tokens in the required context_size for all inferences with a click
|
||
|
// in relative_click_span.
|
||
|
void StripOrPadTokens(const TokenSpan& relative_click_span, int context_size,
|
||
|
std::vector<Token>* tokens, int* click_pos);
|
||
|
|
||
|
} // namespace internal
|
||
|
|
||
|
// Converts a codepoint span to a token span in the given list of tokens.
|
||
|
// If snap_boundaries_to_containing_tokens is set to true, it is enough for a
|
||
|
// token to overlap with the codepoint range to be considered part of it.
|
||
|
// Otherwise it must be fully included in the range.
|
||
|
TokenSpan CodepointSpanToTokenSpan(
|
||
|
const std::vector<Token>& selectable_tokens,
|
||
|
const CodepointSpan& codepoint_span,
|
||
|
bool snap_boundaries_to_containing_tokens = false);
|
||
|
|
||
|
// Converts a token span to a codepoint span in the given list of tokens.
|
||
|
CodepointSpan TokenSpanToCodepointSpan(
|
||
|
const std::vector<Token>& selectable_tokens, const TokenSpan& token_span);
|
||
|
|
||
|
// Converts a codepoint span to a unicode text range, within the given unicode
|
||
|
// text.
|
||
|
// For an invalid span (with a negative index), returns (begin, begin). This
|
||
|
// means that it is safe to call this function before checking the validity of
|
||
|
// the span.
|
||
|
// The indices must fit within the unicode text.
|
||
|
// Note that the execution time is linear with respect to the codepoint indices.
|
||
|
// Calling this function repeatedly for spans on the same text might lead to
|
||
|
// inefficient code.
|
||
|
UnicodeTextRange CodepointSpanToUnicodeTextRange(
|
||
|
const UnicodeText& unicode_text, const CodepointSpan& span);
|
||
|
|
||
|
// Takes care of preparing features for the span prediction model.
|
||
|
class FeatureProcessor {
|
||
|
public:
|
||
|
// A cache mapping codepoint spans to embedded tokens features. An instance
|
||
|
// can be provided to multiple calls to ExtractFeatures() operating on the
|
||
|
// same context (the same codepoint spans corresponding to the same tokens),
|
||
|
// as an optimization. Note that the tokenizations do not have to be
|
||
|
// identical.
|
||
|
typedef std::map<CodepointSpan, std::vector<float>> EmbeddingCache;
|
||
|
|
||
|
explicit FeatureProcessor(const FeatureProcessorOptions* options,
|
||
|
const UniLib* unilib)
|
||
|
: feature_extractor_(internal::BuildTokenFeatureExtractorOptions(options),
|
||
|
unilib),
|
||
|
options_(options),
|
||
|
tokenizer_(internal::BuildTokenizer(options, unilib)) {
|
||
|
MakeLabelMaps();
|
||
|
if (options->supported_codepoint_ranges() != nullptr) {
|
||
|
SortCodepointRanges({options->supported_codepoint_ranges()->begin(),
|
||
|
options->supported_codepoint_ranges()->end()},
|
||
|
&supported_codepoint_ranges_);
|
||
|
}
|
||
|
PrepareIgnoredSpanBoundaryCodepoints();
|
||
|
}
|
||
|
|
||
|
// Tokenizes the input string using the selected tokenization method.
|
||
|
std::vector<Token> Tokenize(const std::string& text) const;
|
||
|
|
||
|
// Same as above but takes UnicodeText.
|
||
|
std::vector<Token> Tokenize(const UnicodeText& text_unicode) const;
|
||
|
|
||
|
// Converts a label into a token span.
|
||
|
bool LabelToTokenSpan(int label, TokenSpan* token_span) const;
|
||
|
|
||
|
// Gets the total number of selection labels.
|
||
|
int GetSelectionLabelCount() const { return label_to_selection_.size(); }
|
||
|
|
||
|
// Gets the string value for given collection label.
|
||
|
std::string LabelToCollection(int label) const;
|
||
|
|
||
|
// Gets the total number of collections of the model.
|
||
|
int NumCollections() const { return collection_to_label_.size(); }
|
||
|
|
||
|
// Gets the name of the default collection.
|
||
|
std::string GetDefaultCollection() const;
|
||
|
|
||
|
const FeatureProcessorOptions* GetOptions() const { return options_; }
|
||
|
|
||
|
// Retokenizes the context and input span, and finds the click position.
|
||
|
// Depending on the options, might modify tokens (split them or remove them).
|
||
|
void RetokenizeAndFindClick(const std::string& context,
|
||
|
const CodepointSpan& input_span,
|
||
|
bool only_use_line_with_click,
|
||
|
std::vector<Token>* tokens, int* click_pos) const;
|
||
|
|
||
|
// Same as above, but takes UnicodeText and iterators within it corresponding
|
||
|
// to input_span.
|
||
|
void RetokenizeAndFindClick(const UnicodeText& context_unicode,
|
||
|
const UnicodeText::const_iterator& span_begin,
|
||
|
const UnicodeText::const_iterator& span_end,
|
||
|
const CodepointSpan& input_span,
|
||
|
bool only_use_line_with_click,
|
||
|
std::vector<Token>* tokens, int* click_pos) const;
|
||
|
|
||
|
// Returns true if the token span has enough supported codepoints (as defined
|
||
|
// in the model config) or not and model should not run.
|
||
|
bool HasEnoughSupportedCodepoints(const std::vector<Token>& tokens,
|
||
|
const TokenSpan& token_span) const;
|
||
|
|
||
|
// Extracts features as a CachedFeatures object that can be used for repeated
|
||
|
// inference over token spans in the given context.
|
||
|
bool ExtractFeatures(const std::vector<Token>& tokens,
|
||
|
const TokenSpan& token_span,
|
||
|
const CodepointSpan& selection_span_for_feature,
|
||
|
const EmbeddingExecutor* embedding_executor,
|
||
|
EmbeddingCache* embedding_cache, int feature_vector_size,
|
||
|
std::unique_ptr<CachedFeatures>* cached_features) const;
|
||
|
|
||
|
// Fills selection_label_spans with CodepointSpans that correspond to the
|
||
|
// selection labels. The CodepointSpans are based on the codepoint ranges of
|
||
|
// given tokens.
|
||
|
bool SelectionLabelSpans(
|
||
|
VectorSpan<Token> tokens,
|
||
|
std::vector<CodepointSpan>* selection_label_spans) const;
|
||
|
|
||
|
// Fills selection_label_relative_token_spans with number of tokens left and
|
||
|
// right from the click.
|
||
|
bool SelectionLabelRelativeTokenSpans(
|
||
|
std::vector<TokenSpan>* selection_label_relative_token_spans) const;
|
||
|
|
||
|
int DenseFeaturesCount() const {
|
||
|
return feature_extractor_.DenseFeaturesCount();
|
||
|
}
|
||
|
|
||
|
int EmbeddingSize() const { return options_->embedding_size(); }
|
||
|
|
||
|
// Splits context to several segments.
|
||
|
std::vector<UnicodeTextRange> SplitContext(
|
||
|
const UnicodeText& context_unicode,
|
||
|
const bool use_pipe_character_for_newline) const;
|
||
|
|
||
|
// Strips boundary codepoints from the span in context and returns the new
|
||
|
// start and end indices. If the span comprises entirely of boundary
|
||
|
// codepoints, the first index of span is returned for both indices.
|
||
|
CodepointSpan StripBoundaryCodepoints(const std::string& context,
|
||
|
const CodepointSpan& span) const;
|
||
|
|
||
|
// Same as above but takes UnicodeText.
|
||
|
CodepointSpan StripBoundaryCodepoints(const UnicodeText& context_unicode,
|
||
|
const CodepointSpan& span) const;
|
||
|
|
||
|
// Same as above but takes a pair of iterators for the span, for efficiency.
|
||
|
CodepointSpan StripBoundaryCodepoints(
|
||
|
const UnicodeText::const_iterator& span_begin,
|
||
|
const UnicodeText::const_iterator& span_end,
|
||
|
const CodepointSpan& span) const;
|
||
|
|
||
|
// Same as above, but takes an optional buffer for saving the modified value.
|
||
|
// As an optimization, returns pointer to 'value' if nothing was stripped, or
|
||
|
// pointer to 'buffer' if something was stripped.
|
||
|
const std::string& StripBoundaryCodepoints(const std::string& value,
|
||
|
std::string* buffer) const;
|
||
|
|
||
|
protected:
|
||
|
// Returns the class id corresponding to the given string collection
|
||
|
// identifier. There is a catch-all class id that the function returns for
|
||
|
// unknown collections.
|
||
|
int CollectionToLabel(const std::string& collection) const;
|
||
|
|
||
|
// Prepares mapping from collection names to labels.
|
||
|
void MakeLabelMaps();
|
||
|
|
||
|
// Gets the number of spannable tokens for the model.
|
||
|
//
|
||
|
// Spannable tokens are those tokens of context, which the model predicts
|
||
|
// selection spans over (i.e., there is 1:1 correspondence between the output
|
||
|
// classes of the model and each of the spannable tokens).
|
||
|
int GetNumContextTokens() const { return options_->context_size() * 2 + 1; }
|
||
|
|
||
|
// Converts a label into a span of codepoint indices corresponding to it
|
||
|
// given output_tokens.
|
||
|
bool LabelToSpan(int label, const VectorSpan<Token>& output_tokens,
|
||
|
CodepointSpan* span) const;
|
||
|
|
||
|
// Converts a span to the corresponding label given output_tokens.
|
||
|
bool SpanToLabel(const CodepointSpan& span,
|
||
|
const std::vector<Token>& output_tokens, int* label) const;
|
||
|
|
||
|
// Converts a token span to the corresponding label.
|
||
|
int TokenSpanToLabel(const TokenSpan& token_span) const;
|
||
|
|
||
|
// Returns the ratio of supported codepoints to total number of codepoints in
|
||
|
// the given token span.
|
||
|
float SupportedCodepointsRatio(const TokenSpan& token_span,
|
||
|
const std::vector<Token>& tokens) const;
|
||
|
|
||
|
void PrepareIgnoredSpanBoundaryCodepoints();
|
||
|
|
||
|
// Counts the number of span boundary codepoints. If count_from_beginning is
|
||
|
// True, the counting will start at the span_start iterator (inclusive) and at
|
||
|
// maximum end at span_end (exclusive). If count_from_beginning is True, the
|
||
|
// counting will start from span_end (exclusive) and end at span_start
|
||
|
// (inclusive).
|
||
|
int CountIgnoredSpanBoundaryCodepoints(
|
||
|
const UnicodeText::const_iterator& span_start,
|
||
|
const UnicodeText::const_iterator& span_end,
|
||
|
bool count_from_beginning) const;
|
||
|
|
||
|
// Finds the center token index in tokens vector, using the method defined
|
||
|
// in options_.
|
||
|
int FindCenterToken(const CodepointSpan& span,
|
||
|
const std::vector<Token>& tokens) const;
|
||
|
|
||
|
// Removes all tokens from tokens that are not on a line (defined by calling
|
||
|
// SplitContext on the context) to which span points.
|
||
|
void StripTokensFromOtherLines(const std::string& context,
|
||
|
const CodepointSpan& span,
|
||
|
std::vector<Token>* tokens) const;
|
||
|
|
||
|
// Same as above but takes UnicodeText.
|
||
|
void StripTokensFromOtherLines(const UnicodeText& context_unicode,
|
||
|
const UnicodeText::const_iterator& span_begin,
|
||
|
const UnicodeText::const_iterator& span_end,
|
||
|
const CodepointSpan& span,
|
||
|
std::vector<Token>* tokens) const;
|
||
|
|
||
|
// Extracts the features of a token and appends them to the output vector.
|
||
|
// Uses the embedding cache to to avoid re-extracting the re-embedding the
|
||
|
// sparse features for the same token.
|
||
|
bool AppendTokenFeaturesWithCache(
|
||
|
const Token& token, const CodepointSpan& selection_span_for_feature,
|
||
|
const EmbeddingExecutor* embedding_executor,
|
||
|
EmbeddingCache* embedding_cache,
|
||
|
std::vector<float>* output_features) const;
|
||
|
|
||
|
protected:
|
||
|
const TokenFeatureExtractor feature_extractor_;
|
||
|
|
||
|
// Codepoint ranges that define what codepoints are supported by the model.
|
||
|
// NOTE: Must be sorted.
|
||
|
std::vector<CodepointRangeStruct> supported_codepoint_ranges_;
|
||
|
|
||
|
private:
|
||
|
// Set of codepoints that will be stripped from beginning and end of
|
||
|
// predicted spans.
|
||
|
std::unordered_set<int32> ignored_span_boundary_codepoints_;
|
||
|
|
||
|
const FeatureProcessorOptions* const options_;
|
||
|
|
||
|
// Mapping between token selection spans and labels ids.
|
||
|
std::map<TokenSpan, int> selection_to_label_;
|
||
|
std::vector<TokenSpan> label_to_selection_;
|
||
|
|
||
|
// Mapping between collections and labels.
|
||
|
std::map<std::string, int> collection_to_label_;
|
||
|
|
||
|
Tokenizer tokenizer_;
|
||
|
};
|
||
|
|
||
|
} // namespace libtextclassifier3
|
||
|
|
||
|
#endif // LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_
|