/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "utils/tflite/string_projection.h"
|
|
|
|
#include <string>
|
|
#include <unordered_map>
|
|
|
|
#include "utils/strings/utf8.h"
|
|
#include "utils/tflite/string_projection_base.h"
|
|
#include "utils/utf8/unilib-common.h"
|
|
#include "absl/container/flat_hash_set.h"
|
|
#include "absl/strings/match.h"
|
|
#include "flatbuffers/flexbuffers.h"
|
|
#include "tensorflow/lite/context.h"
|
|
#include "tensorflow/lite/string_util.h"
|
|
|
|
namespace tflite {
namespace ops {
namespace custom {

namespace libtextclassifier3 {
namespace string_projection {
namespace {

const char kStartToken[] = "<S>";
const char kEndToken[] = "<E>";
const char kEmptyToken[] = "<S> <E>";
constexpr size_t kEntireString = SIZE_MAX;
constexpr size_t kAllTokens = SIZE_MAX;
constexpr int kInvalid = -1;

constexpr char kApostrophe = '\'';
constexpr char kSpace = ' ';
constexpr char kComma = ',';
constexpr char kDot = '.';
// Returns true if the given text contains a digit.
bool IsDigitString(const std::string& text) {
  for (size_t i = 0; i < text.length();) {
    // Advance one whole UTF-8 character at a time, decoding at offset |i|.
    const int bytes_read =
        ::libtextclassifier3::GetNumBytesForUTF8Char(text.data() + i);
    if (bytes_read <= 0 || bytes_read > text.length() - i) {
      break;
    }
    const char32_t rune =
        ::libtextclassifier3::ValidCharToRune(text.data() + i);
    if (::libtextclassifier3::IsDigit(rune)) return true;
    i += bytes_read;
  }
  return false;
}
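// For illustration (not part of the original source):
//   IsDigitString("abc3") == true, IsDigitString("abc") == false.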
// Gets the string containing |num_chars| characters from |start| position.
std::string GetCharToken(const std::vector<std::string>& char_tokens, int start,
                         int num_chars) {
  std::string char_token = "";
  if (start + num_chars <= char_tokens.size()) {
    for (int i = 0; i < num_chars; ++i) {
      char_token.append(char_tokens[start + i]);
    }
  }
  return char_token;
}
// Counts how many times |pattern| appears consecutively, starting at |start|.
int GetNumPattern(const std::vector<std::string>& char_tokens, size_t start,
                  size_t num_chars, const std::string& pattern) {
  int count = 0;
  for (int i = start; i < char_tokens.size(); i += num_chars) {
    std::string cur_pattern = GetCharToken(char_tokens, i, num_chars);
    if (pattern == cur_pattern) {
      ++count;
    } else {
      break;
    }
  }
  return count;
}
// Returns the index of the next space in [from, length), or kInvalid (which
// wraps to SIZE_MAX when stored in the size_t return value) if there is none.
inline size_t FindNextSpace(const char* input_ptr, size_t from, size_t length) {
  size_t space_index;
  for (space_index = from; space_index < length; space_index++) {
    if (input_ptr[space_index] == kSpace) {
      break;
    }
  }
  return space_index == length ? kInvalid : space_index;
}
template <typename T>
void SplitByCharInternal(std::vector<T>* tokens, const char* input_ptr,
                         size_t len, size_t max_tokens) {
  for (size_t i = 0; i < len;) {
    auto bytes_read =
        ::libtextclassifier3::GetNumBytesForUTF8Char(input_ptr + i);
    if (bytes_read <= 0 || bytes_read > len - i) break;
    tokens->emplace_back(input_ptr + i, bytes_read);
    if (max_tokens != kInvalid && tokens->size() == max_tokens) {
      break;
    }
    i += bytes_read;
  }
}

std::vector<std::string> SplitByChar(const char* input_ptr, size_t len,
                                     size_t max_tokens) {
  std::vector<std::string> tokens;
  SplitByCharInternal(&tokens, input_ptr, len, max_tokens);
  return tokens;
}
std::string ContractToken(const char* input_ptr, size_t len, size_t num_chars) {
  // This function contracts patterns whose length is |num_chars| and which
  // appear more than twice. So if the input is shorter than 3 * |num_chars|,
  // do not apply any contraction.
  if (len < 3 * num_chars) {
    // Construct with an explicit length so the input need not be
    // null-terminated.
    return std::string(input_ptr, len);
  }
  std::vector<std::string> char_tokens = SplitByChar(input_ptr, len, len);

  std::string token;
  token.reserve(len);
  for (int i = 0; i < char_tokens.size();) {
    std::string cur_pattern = GetCharToken(char_tokens, i, num_chars);

    // Count how many times this pattern appears.
    int num_cur_patterns = 0;
    if (!absl::StrContains(cur_pattern, " ") && !IsDigitString(cur_pattern)) {
      num_cur_patterns =
          GetNumPattern(char_tokens, i + num_chars, num_chars, cur_pattern);
    }

    if (num_cur_patterns >= 2) {
      // If this pattern is repeated, store it only twice.
      token.append(cur_pattern);
      token.append(cur_pattern);
      i += (num_cur_patterns + 1) * num_chars;
    } else {
      token.append(char_tokens[i]);
      ++i;
    }
  }

  return token;
}
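// For illustration (not part of the original source):
//   ContractToken("soooo", 5, 1) == "soo"    (run of one char kept twice)
//   ContractToken("hahaha", 6, 2) == "haha"  (2-char pattern kept twice)
// A pattern must appear at least three times to be contracted; patterns
// containing spaces or digits are left untouched.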
template <typename T>
void SplitBySpaceInternal(std::vector<T>* tokens, const char* input_ptr,
                          size_t len, size_t max_input, size_t max_tokens) {
  size_t last_index =
      max_input == kEntireString ? len : (len < max_input ? len : max_input);
  size_t start = 0;
  // Skip leading spaces.
  while (start < last_index && input_ptr[start] == kSpace) {
    start++;
  }
  auto end = FindNextSpace(input_ptr, start, last_index);
  while (end != kInvalid &&
         (max_tokens == kAllTokens || tokens->size() < max_tokens - 1)) {
    auto length = end - start;
    if (length > 0) {
      tokens->emplace_back(input_ptr + start, length);
    }

    start = end + 1;
    end = FindNextSpace(input_ptr, start, last_index);
  }
  auto length = end == kInvalid ? (last_index - start) : (end - start);
  if (length > 0) {
    tokens->emplace_back(input_ptr + start, length);
  }
}
std::vector<std::string> SplitBySpace(const char* input_ptr, size_t len,
                                      size_t max_input, size_t max_tokens) {
  std::vector<std::string> tokens;
  SplitBySpaceInternal(&tokens, input_ptr, len, max_input, max_tokens);
  return tokens;
}
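// For illustration (not part of the original source): with max_input ==
// kEntireString and max_tokens == kAllTokens,
//   SplitBySpace("hi there", 8, ...) yields {"hi", "there"}.
// Leading spaces are skipped and empty tokens from consecutive spaces are
// dropped.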
bool prepend_separator(char separator) { return separator == kApostrophe; }

bool is_numeric(char c) { return c >= '0' && c <= '9'; }
class ProjectionNormalizer {
 public:
  explicit ProjectionNormalizer(const std::string& separators,
                                bool normalize_repetition = false) {
    InitializeSeparators(separators);
    normalize_repetition_ = normalize_repetition;
  }

  // Normalizes runs of repeated characters (except digits) that appear more
  // than twice in a row within a word, and adds spaces around separators.
  std::string Normalize(const std::string& input, size_t max_input = 300) {
    return Normalize(input.data(), input.length(), max_input);
  }
  std::string Normalize(const char* input_ptr, size_t len,
                        size_t max_input = 300) {
    std::string normalized(input_ptr, std::min(len, max_input));

    if (normalize_repetition_) {
      // Contract runs of a single repeated char (e.g. soooo => soo).
      normalized = ContractToken(normalized.data(), normalized.length(), 1);

      // Contract repeated 2-char patterns (e.g. hahaha => haha,
      // xhahaha => xhaha, xyhahaha => xyhaha).
      normalized = ContractToken(normalized.data(), normalized.length(), 2);

      // Contract repeated 3-char patterns
      // (e.g. wowwowwow => wowwow, abcdbcdbcd => abcdbcd).
      normalized = ContractToken(normalized.data(), normalized.length(), 3);
    }

    if (!separators_.empty()) {
      // Add spaces around separators_.
      normalized = NormalizeInternal(normalized.data(), normalized.length());
    }
    return normalized;
  }

 private:
  // Parses and extracts supported separators.
  void InitializeSeparators(const std::string& separators) {
    for (int i = 0; i < separators.length(); ++i) {
      if (separators[i] != ' ') {
        separators_.insert(separators[i]);
      }
    }
  }

  // Inserts spaces around separator characters.
  std::string NormalizeInternal(const char* input_ptr, size_t len) {
    std::string normalized;
    normalized.reserve(len * 2);
    for (int i = 0; i < len; ++i) {
      char c = input_ptr[i];
      bool matched_separator = separators_.find(c) != separators_.end();
      if (matched_separator) {
        if (i > 0 && input_ptr[i - 1] != ' ' && normalized.back() != ' ') {
          normalized.append(" ");
        }
      }
      normalized.append(1, c);
      if (matched_separator) {
        if (i + 1 < len && input_ptr[i + 1] != ' ' && c != '\'') {
          normalized.append(" ");
        }
      }
    }
    return normalized;
  }

  absl::flat_hash_set<char> separators_;
  bool normalize_repetition_;
};
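// For illustration (not part of the original source): with separators ",."
// and normalize_repetition == true,
//   ProjectionNormalizer(",.", true).Normalize("heyyyy, you")
// first contracts "yyyy" to "yy" and then pads the comma with a leading
// space, producing "heyy , you".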
class ProjectionTokenizer {
 public:
  explicit ProjectionTokenizer(const std::string& separators) {
    InitializeSeparators(separators);
  }

  // Tokenizes the input by separators_. Limits the result to max_tokens when
  // it is not kAllTokens.
  std::vector<std::string> Tokenize(const std::string& input, size_t max_input,
                                    size_t max_tokens) const {
    return Tokenize(input.c_str(), input.size(), max_input, max_tokens);
  }

  std::vector<std::string> Tokenize(const char* input_ptr, size_t len,
                                    size_t max_input, size_t max_tokens) const {
    // If no separators are given, tokenize the input on spaces.
    if (separators_.empty()) {
      return SplitBySpace(input_ptr, len, max_input, max_tokens);
    }

    std::vector<std::string> tokens;
    size_t last_index =
        max_input == kEntireString ? len : (len < max_input ? len : max_input);
    size_t start = 0;
    // Skip leading spaces.
    while (start < last_index && input_ptr[start] == kSpace) {
      start++;
    }
    auto end = FindNextSeparator(input_ptr, start, last_index);

    while (end != kInvalid &&
           (max_tokens == kAllTokens || tokens.size() < max_tokens - 1)) {
      auto length = end - start;
      if (length > 0) tokens.emplace_back(input_ptr + start, length);

      // Add the separator (except space and apostrophe) as a token.
      char separator = input_ptr[end];
      if (separator != kSpace && separator != kApostrophe) {
        tokens.emplace_back(input_ptr + end, 1);
      }

      start = end + (prepend_separator(separator) ? 0 : 1);
      end = FindNextSeparator(input_ptr, end + 1, last_index);
    }
    auto length = end == kInvalid ? (last_index - start) : (end - start);
    if (length > 0) tokens.emplace_back(input_ptr + start, length);
    return tokens;
  }

 private:
  // Parses and extracts supported separators.
  void InitializeSeparators(const std::string& separators) {
    for (int i = 0; i < separators.length(); ++i) {
      separators_.insert(separators[i]);
    }
  }

  // Starting from input_ptr[from], searches for the next occurrence of a
  // separator in separators_. Does not search beyond input_ptr[length]
  // (non-inclusive). Returns kInvalid if none is found.
  size_t FindNextSeparator(const char* input_ptr, size_t from,
                           size_t length) const {
    auto index = from;
    while (index < length) {
      char c = input_ptr[index];
      // Do not break a number (e.g. "10,000", "0.23").
      if (c == kComma || c == kDot) {
        if (index + 1 < length && is_numeric(input_ptr[index + 1])) {
          c = input_ptr[++index];
        }
      }
      if (separators_.find(c) != separators_.end()) {
        break;
      }
      ++index;
    }
    return index == length ? kInvalid : index;
  }

  absl::flat_hash_set<char> separators_;
};
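// For illustration (not part of the original source): with separators " ?",
//   ProjectionTokenizer(" ?").Tokenize("hi there?", 9, kEntireString,
//                                      kAllTokens)
// yields {"hi", "there", "?"}: the space splits tokens silently, while the
// question mark both splits and is emitted as its own token.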
inline void StripTrailingAsciiPunctuation(std::string* str) {
  // Cast to unsigned char before calling ispunct to avoid undefined behavior
  // on bytes with the high bit set.
  auto it = std::find_if_not(str->rbegin(), str->rend(), [](unsigned char c) {
    return std::ispunct(c);
  });
  str->erase(str->rend() - it);
}
std::string PreProcessString(const char* str, int len,
                             const bool remove_punctuation) {
  std::string output_str(str, len);
  std::transform(output_str.begin(), output_str.end(), output_str.begin(),
                 [](unsigned char c) {
                   return static_cast<char>(std::tolower(c));
                 });

  // Remove trailing punctuation.
  if (remove_punctuation) {
    StripTrailingAsciiPunctuation(&output_str);
  }

  // Restore the original bytes if stripping removed everything.
  if (output_str.empty()) {
    output_str.assign(str, len);
  }
  return output_str;
}
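// For illustration (not part of the original source):
//   PreProcessString("Hello!!!", 8, /*remove_punctuation=*/true) == "hello"
//   PreProcessString("!!!", 3, /*remove_punctuation=*/true) == "!!!"
// (the original token is kept when stripping would leave nothing).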
bool ShouldIncludeCurrentNgram(const SkipGramParams& params, int size) {
  if (size <= 0) {
    return false;
  }
  if (params.include_all_ngrams) {
    return size <= params.ngram_size;
  } else {
    return size == params.ngram_size;
  }
}
bool ShouldStepInRecursion(const std::vector<int>& stack, int stack_idx,
                           int num_words, const SkipGramParams& params) {
  // Check that the current stack depth and the next word to enumerate are
  // within the valid range.
  if (stack_idx < params.ngram_size && stack[stack_idx] + 1 < num_words) {
    // If the stack is empty, step in for the first word enumeration.
    if (stack_idx == 0) {
      return true;
    }
    // Check that the next word to enumerate is within max_skip_size of the
    // previous one.
    // NOTE: equivalent to
    //   next_word_idx = stack[stack_idx] + 1
    //   next_word_idx - stack[stack_idx - 1] <= max_skip_size + 1
    if (stack[stack_idx] - stack[stack_idx - 1] <= params.max_skip_size) {
      return true;
    }
  }
  return false;
}
// Joins the tokens selected by stack[0..stack_idx) with single spaces.
std::string JoinTokensBySpace(const std::vector<int>& stack, int stack_idx,
                              const std::vector<std::string>& tokens) {
  int len = 0;
  for (int i = 0; i < stack_idx; i++) {
    len += tokens[stack[i]].size();
  }
  len += stack_idx - 1;

  std::string res;
  res.reserve(len);
  res.append(tokens[stack[0]]);
  for (int i = 1; i < stack_idx; i++) {
    res.append(" ");
    res.append(tokens[stack[i]]);
  }

  return res;
}
std::unordered_map<std::string, int> ExtractSkipGramsImpl(
    const std::vector<std::string>& tokens, const SkipGramParams& params) {
  // Ignore positional tokens.
  static auto* blacklist = new std::unordered_set<std::string>({
      kStartToken,
      kEndToken,
      kEmptyToken,
  });

  std::unordered_map<std::string, int> res;

  // The stack stores the indices of the words used to generate an ngram.
  // The number of filled slots is the size of the ngram.
  std::vector<int> stack(params.ngram_size + 1, 0);
  // Stack index that indicates which depth the recursion is operating at.
  int stack_idx = 1;
  int num_words = tokens.size();

  while (stack_idx >= 0) {
    if (ShouldStepInRecursion(stack, stack_idx, num_words, params)) {
      // When the current depth can be filled with a new word and the new word
      // is within the max skip range, push the word onto the stack and
      // recurse into the next depth.
      stack[stack_idx]++;
      stack_idx++;
      stack[stack_idx] = stack[stack_idx - 1];
    } else {
      if (ShouldIncludeCurrentNgram(params, stack_idx)) {
        // Add the ngram to the output when the stack holds enough words to
        // generate it.
        std::string ngram = JoinTokensBySpace(stack, stack_idx, tokens);
        if (blacklist->find(ngram) == blacklist->end()) {
          res[ngram] = stack_idx;
        }
      }
      // When the current depth cannot be filled with a valid new word, step
      // back to the previous depth and advance to the next possible word.
      stack_idx--;
    }
  }

  return res;
}
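// For illustration (not part of the original source): with tokens
// {"a", "b", "c"}, ngram_size == 2, max_skip_size == 0 and
// include_all_ngrams == true, ExtractSkipGramsImpl returns
//   {"a": 1, "b": 1, "c": 1, "a b": 2, "b c": 2}
// i.e. every unigram and every contiguous bigram, each mapped to its size.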
std::unordered_map<std::string, int> ExtractSkipGrams(
    const std::string& input, ProjectionTokenizer* tokenizer,
    ProjectionNormalizer* normalizer, const SkipGramParams& params) {
  // Normalize the input.
  const std::string& normalized =
      normalizer == nullptr
          ? input
          : normalizer->Normalize(input, params.max_input_chars);

  // Split the sentence into words.
  std::vector<std::string> tokens;
  if (params.char_level) {
    tokens = SplitByChar(normalized.data(), normalized.size(),
                         params.max_input_chars);
  } else {
    tokens = tokenizer->Tokenize(normalized.data(), normalized.size(),
                                 params.max_input_chars, kAllTokens);
  }

  // Preprocess the tokens.
  for (int i = 0; i < tokens.size(); ++i) {
    if (params.preprocess) {
      tokens[i] = PreProcessString(tokens[i].data(), tokens[i].size(),
                                   params.remove_punctuation);
    }
  }

  tokens.insert(tokens.begin(), kStartToken);
  tokens.insert(tokens.end(), kEndToken);

  return ExtractSkipGramsImpl(tokens, params);
}

}  // namespace
// Generates LSH projections for input strings. This uses the framework in
// `string_projection_base.h`; the implementation detail here is that the
// input is a string tensor of messages and the op performs tokenization.
//
// Input:
//   tensor[0]: Input message, string[...]
//
// Additional attributes:
//   max_input_chars: int[]
//     maximum number of input characters to use from each message.
//   token_separators: string[]
//     the list of separators used to tokenize the input.
//   normalize_repetition: bool[]
//     if true, contract repeated characters in tokens ('loool' -> 'lool').

static const int kInputMessage = 0;
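// For illustration (not part of the original source), the custom options for
// this op could be encoded with the flexbuffers API roughly as follows; the
// attribute values shown are arbitrary:
//
//   flexbuffers::Builder fbb;
//   fbb.Map([&]() {
//     fbb.String("token_separators", ".,!?");
//     fbb.Bool("normalize_repetition", true);
//     fbb.Int("max_input_chars", 200);
//   });
//   fbb.Finish();
//   // fbb.GetBuffer() is what reaches Init() below as `buffer`.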
class StringProjectionOp : public StringProjectionOpBase {
 public:
  explicit StringProjectionOp(const flexbuffers::Map& custom_options)
      : StringProjectionOpBase(custom_options),
        projection_normalizer_(
            custom_options["token_separators"].AsString().str(),
            custom_options["normalize_repetition"].AsBool()),
        projection_tokenizer_(" ") {
    if (custom_options["max_input_chars"].IsInt()) {
      skip_gram_params().max_input_chars =
          custom_options["max_input_chars"].AsInt32();
    }
  }

  TfLiteStatus InitializeInput(TfLiteContext* context,
                               TfLiteNode* node) override {
    input_ = &context->tensors[node->inputs->data[kInputMessage]];
    return kTfLiteOk;
  }

  std::unordered_map<std::string, int> ExtractSkipGrams(int i) override {
    StringRef input = GetString(input_, i);
    return ::tflite::ops::custom::libtextclassifier3::string_projection::
        ExtractSkipGrams({input.str, static_cast<size_t>(input.len)},
                         &projection_tokenizer_, &projection_normalizer_,
                         skip_gram_params());
  }

  void FinalizeInput() override { input_ = nullptr; }

  TfLiteIntArray* GetInputShape(TfLiteContext* context,
                                TfLiteNode* node) override {
    return context->tensors[node->inputs->data[kInputMessage]].dims;
  }

 private:
  ProjectionNormalizer projection_normalizer_;
  ProjectionTokenizer projection_tokenizer_;

  TfLiteTensor* input_;
};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
  return new StringProjectionOp(flexbuffers::GetRoot(buffer_t, length).AsMap());
}
}  // namespace string_projection

// This op converts a list of strings to integers via LSH projections.
TfLiteRegistration* Register_STRING_PROJECTION() {
  static TfLiteRegistration r = {libtextclassifier3::string_projection::Init,
                                 libtextclassifier3::string_projection::Free,
                                 libtextclassifier3::string_projection::Resize,
                                 libtextclassifier3::string_projection::Eval};
  return &r;
}
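// For illustration (not part of the original source): a client would
// typically register this custom op with an op resolver before building the
// interpreter. The op name string here is an assumption:
//
//   tflite::MutableOpResolver resolver;
//   resolver.AddCustom("STRING_PROJECTION",
//                      tflite::ops::custom::Register_STRING_PROJECTION());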
}  // namespace libtextclassifier3
}  // namespace custom
}  // namespace ops
}  // namespace tflite