You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
71 lines
2.6 KiB
71 lines
2.6 KiB
/*
|
|
* Copyright (C) 2018 The Android Open Source Project
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_SKIPGRAM_FINDER_H_
|
|
#define LIBTEXTCLASSIFIER_UTILS_TFLITE_SKIPGRAM_FINDER_H_
|
|
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
#include "absl/container/flat_hash_map.h"
|
|
#include "absl/container/flat_hash_set.h"
|
|
#include "absl/strings/string_view.h"
|
|
#include "tensorflow/lite/string_util.h"
|
|
|
|
namespace libtextclassifier3 {
|
|
|
|
// SkipgramFinder finds skipgrams in strings.
|
|
//
|
|
// To use: First, add skipgrams using AddSkipgram() - each skipgram is
|
|
// associated with some category. Then, call FindSkipgrams() on a string,
|
|
// which will return the set of categories of the skipgrams in the string.
|
|
//
|
|
// Both the skipgrams and the input strings will be tokenzied by splitting
|
|
// on spaces. Additionally, the tokens will be lowercased and have any
|
|
// trailing punctuation removed.
|
|
class SkipgramFinder {
|
|
public:
|
|
explicit SkipgramFinder(int max_skip_size) : max_skip_size_(max_skip_size) {}
|
|
|
|
// Adds a skipgram that SkipgramFinder should look for in input strings.
|
|
// Tokens may use the regex '.*' as a suffix.
|
|
void AddSkipgram(const std::string& skipgram, int category);
|
|
|
|
// Find all of the skipgrams in `input`, and return their categories.
|
|
absl::flat_hash_set<int> FindSkipgrams(const std::string& input) const;
|
|
|
|
// Find all of the skipgrams in `tokens`, and return their categories.
|
|
absl::flat_hash_set<int> FindSkipgrams(
|
|
const std::vector<absl::string_view>& tokens) const;
|
|
absl::flat_hash_set<int> FindSkipgrams(
|
|
const std::vector<::tflite::StringRef>& tokens) const;
|
|
|
|
private:
|
|
struct TrieNode {
|
|
absl::flat_hash_set<int> categories;
|
|
// Maps tokens to the next node in the trie.
|
|
absl::flat_hash_map<std::string, TrieNode> token_to_node;
|
|
// Maps token prefixes (<prefix>.*) to the next node in the trie.
|
|
absl::flat_hash_map<std::string, TrieNode> prefix_to_node;
|
|
};
|
|
|
|
TrieNode skipgram_trie_;
|
|
int max_skip_size_;
|
|
};
|
|
|
|
} // namespace libtextclassifier3
|
|
#endif // LIBTEXTCLASSIFIER_UTILS_TFLITE_SKIPGRAM_FINDER_H_
|