/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LIBTEXTCLASSIFIER_UTILS_BERT_TOKENIZER_H_
#define LIBTEXTCLASSIFIER_UTILS_BERT_TOKENIZER_H_

#include <fstream>
#include <string>
#include <vector>

#include "utils/wordpiece_tokenizer.h"
#include "absl/container/flat_hash_map.h"
#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
#include "tensorflow_lite_support/cc/utils/common_utils.h"

namespace libtextclassifier3 {

using ::tflite::support::text::tokenizer::TokenizerResult;
using ::tflite::support::utils::LoadVocabFromBuffer;
using ::tflite::support::utils::LoadVocabFromFile;

constexpr int kDefaultMaxBytesPerToken = 100;
constexpr int kDefaultMaxCharsPerSubToken = 100;
constexpr char kDefaultSuffixIndicator[] = "##";
constexpr bool kDefaultUseUnknownToken = true;
constexpr char kDefaultUnknownToken[] = "[UNK]";
constexpr bool kDefaultSplitUnknownChars = false;

// Result of wordpiece tokenization, including subwords and their byte offsets
// in the input.
// Example:
//   input:           tokenize me please
//   subwords:        token ##ize me plea ##se
//   wp_begin_offset: [0, 5, 9, 12, 16]
//   wp_end_offset:   [5, 8, 11, 16, 18]
//   row_lengths:     [2, 1, 2]
struct WordpieceTokenizerResult
    : tflite::support::text::tokenizer::TokenizerResult {
  std::vector<int> wp_begin_offset;
  std::vector<int> wp_end_offset;
  std::vector<int> row_lengths;
};

// Options to create a BertTokenizer.
struct BertTokenizerOptions {
  int max_bytes_per_token = kDefaultMaxBytesPerToken;
  int max_chars_per_subtoken = kDefaultMaxCharsPerSubToken;
  std::string suffix_indicator = kDefaultSuffixIndicator;
  bool use_unknown_token = kDefaultUseUnknownToken;
  std::string unknown_token = kDefaultUnknownToken;
  bool split_unknown_chars = kDefaultSplitUnknownChars;
};

// A flat-hash-map based implementation of WordpieceVocab, used in
// BertTokenizer to invoke tensorflow::text::WordpieceTokenize.
class FlatHashMapBackedWordpiece : public WordpieceVocab {
 public:
  explicit FlatHashMapBackedWordpiece(const std::vector<std::string>& vocab);

  LookupStatus Contains(absl::string_view key, bool* value) const override;
  bool LookupId(absl::string_view key, int* result) const;
  bool LookupWord(int vocab_id, absl::string_view* result) const;
  int VocabularySize() const { return vocab_.size(); }

 private:
  // All words, indexed by their position in the vocabulary file.
  std::vector<std::string> vocab_;
  absl::flat_hash_map<absl::string_view, int> index_map_;
};

// Wordpiece tokenizer for BERT models. Initialized with a vocab file, vector,
// or buffer.
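//
// Illustrative usage sketch (not part of the original header; the vocab path
// and the option override are hypothetical):
//
//   BertTokenizerOptions options;
//   options.max_bytes_per_token = 128;  // hypothetical override
//   BertTokenizer tokenizer("/path/to/vocab.txt", options);
//   TokenizerResult result = tokenizer.Tokenize("tokenize me please");
//   // result.subwords would hold {"token", "##ize", "me", "plea", "##se"},
//   // assuming those wordpieces are present in the vocab.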
class BertTokenizer : public tflite::support::text::tokenizer::Tokenizer {
 public:
  // Initialize the tokenizer from a vocab vector and tokenizer configs.
  explicit BertTokenizer(const std::vector<std::string>& vocab,
                         const BertTokenizerOptions& options = {})
      : vocab_{FlatHashMapBackedWordpiece(vocab)}, options_{options} {}

  // Initialize the tokenizer from a file path to the vocab and tokenizer
  // configs.
  explicit BertTokenizer(const std::string& path_to_vocab,
                         const BertTokenizerOptions& options = {})
      : BertTokenizer(LoadVocabFromFile(path_to_vocab), options) {}

  // Initialize the tokenizer from a vocab buffer and its size, plus tokenizer
  // configs.
  BertTokenizer(const char* vocab_buffer_data, size_t vocab_buffer_size,
                const BertTokenizerOptions& options = {})
      : BertTokenizer(LoadVocabFromBuffer(vocab_buffer_data, vocab_buffer_size),
                      options) {}

  // Perform tokenization: first tokenize the input, then find the subwords.
  // Returns tokenized results containing the subwords.
  TokenizerResult Tokenize(const std::string& input) override;

  // Perform tokenization on a single token; returns tokenized results
  // containing the subwords.
  TokenizerResult TokenizeSingleToken(const std::string& token);

  // Perform tokenization on pre-tokenized input; returns tokenized results
  // containing the subwords.
  TokenizerResult Tokenize(const std::vector<std::string>& tokens);

  // Check if a certain key is included in the vocab.
  LookupStatus Contains(const absl::string_view key, bool* value) const {
    return vocab_.Contains(key, value);
  }

  // Find the id of a wordpiece.
  bool LookupId(absl::string_view key, int* result) const override {
    return vocab_.LookupId(key, result);
  }

  // Find the wordpiece from an id.
  bool LookupWord(int vocab_id, absl::string_view* result) const override {
    return vocab_.LookupWord(vocab_id, result);
  }

  int VocabularySize() const { return vocab_.VocabularySize(); }

  // Split the input into tokens (e.g. on whitespace and punctuation) before
  // wordpiece tokenization.
  static std::vector<std::string> PreTokenize(const absl::string_view input);

 private:
  FlatHashMapBackedWordpiece vocab_;
  BertTokenizerOptions options_;
};

}  // namespace libtextclassifier3

#endif  // LIBTEXTCLASSIFIER_UTILS_BERT_TOKENIZER_H_