/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "utils/sentencepiece/normalizer.h"
|
|
|
|
#include "utils/base/logging.h"
|
|
#include "utils/strings/utf8.h"
|
|
|
|

namespace libtextclassifier3 {

bool SentencePieceNormalizer::Normalize(StringPiece input,
                                        std::string* normalized_input) const {
  // Ignores leading whitespace.
  if (remove_extra_whitespaces_) {
    while (!input.empty()) {
      std::pair<StringPiece, int> suffix_and_length;
      if (!NormalizePrefix(input, &suffix_and_length)) {
        TC3_LOG(ERROR) << "Couldn't find match in normalization table.";
        return false;
      }
      if (suffix_and_length.second <= 0) {
        TC3_LOG(ERROR) << "Consumed string is empty.";
        return false;
      }
      if (suffix_and_length.first.size() != 1 ||
          suffix_and_length.first[0] != ' ') {
        break;
      }
      input.RemovePrefix(suffix_and_length.second);
    }
  }

  if (input.empty()) {
    *normalized_input = "";
    return true;
  }

  // Reserves the output buffer to avoid re-allocations.
  const int kReservedSize = input.size() * 3;
  normalized_input->reserve(kReservedSize);

  // Replaces whitespace with U+2581 (LOWER ONE EIGHTH BLOCK)
  // if escape_whitespaces() is set (default = true).
  const StringPiece kSpaceSymbol = "\xe2\x96\x81";

  // Adds a space symbol as a prefix (default is true).
  // With this prefix, "world" and "hello world" are converted into
  // "_world" and "_hello_world", which helps the trainer extract
  // "_world" as one symbol.
  if (add_dummy_prefix_) {
    if (escape_whitespaces_) {
      normalized_input->append(kSpaceSymbol.data(), kSpaceSymbol.size());
    } else {
      normalized_input->append(" ");
    }
  }

  bool is_prev_space = remove_extra_whitespaces_;
  while (!input.empty()) {
    std::pair<StringPiece, int> p;
    if (!NormalizePrefix(input, &p)) {
      TC3_LOG(ERROR) << "Couldn't normalize string.";
      return false;
    }
    if (p.second <= 0) {
      TC3_LOG(ERROR) << "Consumed string is empty.";
      return false;
    }

    StringPiece sp = p.first;

    // Removes leading spaces in the sentence piece
    // if the previous sentence piece ends with whitespace.
    while (is_prev_space && ConsumePrefix(&sp, " ")) {
    }

    if (!sp.empty()) {
      const char* data = sp.data();
      for (int n = 0; n < sp.size(); ++n) {
        if (escape_whitespaces_ && data[n] == ' ') {
          normalized_input->append(kSpaceSymbol.data(), kSpaceSymbol.size());
        } else {
          *normalized_input += data[n];
        }
      }
      // Checks whether the last character of sp is whitespace.
      is_prev_space = EndsWith(sp, " ");
    }
    input.RemovePrefix(p.second);
    is_prev_space = is_prev_space && remove_extra_whitespaces_;
  }

  // Ignores trailing whitespace.
  if (remove_extra_whitespaces_) {
    const StringPiece space = escape_whitespaces_ ? kSpaceSymbol : " ";
    while (EndsWith(*normalized_input, space)) {
      const int length = normalized_input->size() - space.size();
      normalized_input->resize(length);
    }
  }
  return true;
}
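
// Illustrative usage sketch, assuming a hypothetical SentencePieceNormalizer
// instance `normalizer` whose charsmap trie maps each character to itself and
// whose options are the defaults (add_dummy_prefix, escape_whitespaces and
// remove_extra_whitespaces all enabled). Normalizing "hello  world" then adds
// the dummy prefix, collapses the run of spaces and escapes whitespace to
// U+2581:
//
//   std::string normalized;
//   if (normalizer.Normalize("hello  world", &normalized)) {
//     // normalized == "\xe2\x96\x81hello\xe2\x96\x81world", i.e. "▁hello▁world"
//   }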

bool SentencePieceNormalizer::NormalizePrefix(
    StringPiece input, std::pair<StringPiece, int>* prefix) const {
  if (input.empty()) return true;
  StringSet::Match match;
  if (!charsmap_trie_.LongestPrefixMatch(input, &match)) {
    TC3_LOG(ERROR) << "Couldn't find match in normalization table.";
    return false;
  }
  const bool no_match = match.match_length <= 0;
  if (no_match) {
    int char_length;
    if (!IsValidChar(input.data(), input.size(), &char_length)) {
      // Found malformed UTF-8.
      // The rune is set to 0xFFFD (REPLACEMENT CHARACTER),
      // which is a valid Unicode character of three bytes in UTF-8,
      // but here we only consume one byte.
      static const char kReplacementChar[] = "\xEF\xBF\xBD";
      prefix->first = StringPiece(kReplacementChar, 3);
      prefix->second = 1;  // Consumes 1 byte, but emits 0xFFFD.
    } else {
      prefix->first = StringPiece(input.data(), char_length);
      prefix->second = char_length;
    }
  } else {
    if (match.id < 0 || match.id >= charsmap_normalized_.size()) {
      TC3_LOG(ERROR) << "Invalid entry in normalization table.";
      return false;
    }
    prefix->first = StringPiece(&charsmap_normalized_.data()[match.id]);
    prefix->second = match.match_length;
  }
  return true;
}
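
// Sketch of how NormalizePrefix is consumed by Normalize above, with a
// hypothetical table entry: when the charsmap trie matches a prefix of `input`
// and maps it to a normalized string stored in charsmap_normalized_,
// prefix->first points at that replacement and prefix->second holds the number
// of input bytes matched. A simplified consumption loop (Normalize above
// additionally handles whitespace escaping and collapsing) looks like:
//
//   std::pair<StringPiece, int> prefix;
//   if (NormalizePrefix(input, &prefix)) {
//     normalized_input->append(prefix.first.data(), prefix.first.size());
//     input.RemovePrefix(prefix.second);
//   }
//
// When no table entry matches, a valid UTF-8 character is passed through
// unchanged; a malformed byte is replaced by U+FFFD while consuming a single
// byte of input.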

}  // namespace libtextclassifier3