You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

442 lines
17 KiB

/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "utils/grammar/parsing/matcher.h"
#include <string>
#include <vector>
#include "utils/base/arena.h"
#include "utils/grammar/rules_generated.h"
#include "utils/grammar/types.h"
#include "utils/grammar/utils/rules.h"
#include "utils/strings/append.h"
#include "utils/utf8/unilib.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
namespace libtextclassifier3::grammar {
namespace {
using ::testing::DescribeMatcher;
using ::testing::ElementsAre;
using ::testing::ExplainMatchResult;
using ::testing::IsEmpty;
struct TestMatchResult {
CodepointSpan codepoint_span;
std::string terminal;
std::string nonterminal;
int rule_id;
friend std::ostream& operator<<(std::ostream& os,
const TestMatchResult& match) {
return os << "Result(rule_id=" << match.rule_id
<< ", begin=" << match.codepoint_span.first
<< ", end=" << match.codepoint_span.second
<< ", terminal=" << match.terminal
<< ", nonterminal=" << match.nonterminal << ")";
}
};
MATCHER_P3(IsTerminal, begin, end, terminal,
"is terminal with begin that " +
DescribeMatcher<int>(begin, negation) + ", end that " +
DescribeMatcher<int>(end, negation) + ", value that " +
DescribeMatcher<std::string>(terminal, negation)) {
return ExplainMatchResult(CodepointSpan(begin, end), arg.codepoint_span,
result_listener) &&
ExplainMatchResult(terminal, arg.terminal, result_listener);
}
MATCHER_P3(IsNonterminal, begin, end, name,
"is nonterminal with begin that " +
DescribeMatcher<int>(begin, negation) + ", end that " +
DescribeMatcher<int>(end, negation) + ", name that " +
DescribeMatcher<std::string>(name, negation)) {
return ExplainMatchResult(CodepointSpan(begin, end), arg.codepoint_span,
result_listener) &&
ExplainMatchResult(name, arg.nonterminal, result_listener);
}
MATCHER_P4(IsDerivation, begin, end, name, rule_id,
"is derivation of rule that " +
DescribeMatcher<int>(rule_id, negation) + ", begin that " +
DescribeMatcher<int>(begin, negation) + ", end that " +
DescribeMatcher<int>(end, negation) + ", name that " +
DescribeMatcher<std::string>(name, negation)) {
return ExplainMatchResult(IsNonterminal(begin, end, name), arg,
result_listener) &&
ExplainMatchResult(rule_id, arg.rule_id, result_listener);
}
// Superclass of all tests.
class MatcherTest : public testing::Test {
protected:
MatcherTest()
: INIT_UNILIB_FOR_TESTING(unilib_), arena_(/*block_size=*/16 << 10) {}
std::string GetNonterminalName(
const RulesSet_::DebugInformation* debug_information,
const Nonterm nonterminal) const {
if (const RulesSet_::DebugInformation_::NonterminalNamesEntry* entry =
debug_information->nonterminal_names()->LookupByKey(nonterminal)) {
return entry->value()->str();
}
// Unnamed Nonterm.
return "()";
}
std::vector<TestMatchResult> GetMatchResults(
const Chart<>& chart,
const RulesSet_::DebugInformation* debug_information) {
std::vector<TestMatchResult> result;
for (const Derivation& derivation : chart.derivations()) {
result.emplace_back();
result.back().rule_id = derivation.rule_id;
result.back().codepoint_span = derivation.parse_tree->codepoint_span;
result.back().nonterminal =
GetNonterminalName(debug_information, derivation.parse_tree->lhs);
if (derivation.parse_tree->IsTerminalRule()) {
result.back().terminal = derivation.parse_tree->terminal;
}
}
return result;
}
UniLib unilib_;
UnsafeArena arena_;
};
TEST_F(MatcherTest, HandlesBasicOperations) {
// Create an example grammar.
grammar::LocaleShardMap locale_shard_map =
grammar::LocaleShardMap::CreateLocaleShardMap({""});
Rules rules(locale_shard_map);
rules.Add("<test>", {"the", "quick", "brown", "fox"},
static_cast<CallbackId>(DefaultCallback::kRootRule));
rules.Add("<action>", {"<test>"},
static_cast<CallbackId>(DefaultCallback::kRootRule));
const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer(
/*include_debug_information=*/true);
const RulesSet* rules_set =
flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
Matcher matcher(&unilib_, rules_set, &arena_);
matcher.AddTerminal(0, 1, "the");
matcher.AddTerminal(1, 2, "quick");
matcher.AddTerminal(2, 3, "brown");
matcher.AddTerminal(3, 4, "fox");
EXPECT_THAT(GetMatchResults(matcher.chart(), rules_set->debug_information()),
ElementsAre(IsNonterminal(0, 4, "<test>"),
IsNonterminal(0, 4, "<action>")));
}
std::string CreateTestGrammar() {
// Create an example grammar.
grammar::LocaleShardMap locale_shard_map =
grammar::LocaleShardMap::CreateLocaleShardMap({""});
Rules rules(locale_shard_map);
// Callbacks on terminal rules.
rules.Add("<output_5>", {"quick"},
static_cast<CallbackId>(DefaultCallback::kRootRule), 6);
rules.Add("<output_0>", {"the"},
static_cast<CallbackId>(DefaultCallback::kRootRule), 1);
// Callbacks on non-terminal rules.
rules.Add("<output_1>", {"the", "quick", "brown", "fox"},
static_cast<CallbackId>(DefaultCallback::kRootRule), 2);
rules.Add("<output_2>", {"the", "quick"},
static_cast<CallbackId>(DefaultCallback::kRootRule), 3);
rules.Add("<output_3>", {"brown", "fox"},
static_cast<CallbackId>(DefaultCallback::kRootRule), 4);
// Now a complex thing: "the* brown fox".
rules.Add("<thestarbrownfox>", {"brown", "fox"},
static_cast<CallbackId>(DefaultCallback::kRootRule), 5);
rules.Add("<thestarbrownfox>", {"the", "<thestarbrownfox>"},
static_cast<CallbackId>(DefaultCallback::kRootRule), 5);
return rules.Finalize().SerializeAsFlatbuffer(
/*include_debug_information=*/true);
}
Nonterm FindNontermForName(const RulesSet* rules,
const std::string& nonterminal_name) {
for (const RulesSet_::DebugInformation_::NonterminalNamesEntry* entry :
*rules->debug_information()->nonterminal_names()) {
if (entry->value()->str() == nonterminal_name) {
return entry->key();
}
}
return kUnassignedNonterm;
}
TEST_F(MatcherTest, HandlesDerivationsOfRules) {
const std::string rules_buffer = CreateTestGrammar();
const RulesSet* rules_set =
flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
Matcher matcher(&unilib_, rules_set, &arena_);
matcher.AddTerminal(0, 1, "the");
matcher.AddTerminal(1, 2, "quick");
matcher.AddTerminal(2, 3, "brown");
matcher.AddTerminal(3, 4, "fox");
matcher.AddTerminal(3, 5, "fox");
matcher.AddTerminal(4, 6, "fox"); // Not adjacent to "brown".
EXPECT_THAT(GetMatchResults(matcher.chart(), rules_set->debug_information()),
ElementsAre(
// the
IsDerivation(0, 1, "<output_0>", 1),
// quick
IsDerivation(1, 2, "<output_5>", 6),
IsDerivation(0, 2, "<output_2>", 3),
// brown
// fox
IsDerivation(0, 4, "<output_1>", 2),
IsDerivation(2, 4, "<output_3>", 4),
IsDerivation(2, 4, "<thestarbrownfox>", 5),
// fox
IsDerivation(0, 5, "<output_1>", 2),
IsDerivation(2, 5, "<output_3>", 4),
IsDerivation(2, 5, "<thestarbrownfox>", 5)));
}
TEST_F(MatcherTest, HandlesRecursiveRules) {
const std::string rules_buffer = CreateTestGrammar();
const RulesSet* rules_set =
flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
Matcher matcher(&unilib_, rules_set, &arena_);
matcher.AddTerminal(0, 1, "the");
matcher.AddTerminal(1, 2, "the");
matcher.AddTerminal(2, 4, "the");
matcher.AddTerminal(3, 4, "the");
matcher.AddTerminal(4, 5, "brown");
matcher.AddTerminal(5, 6, "fox"); // Generates 5 of <thestarbrownfox>
EXPECT_THAT(GetMatchResults(matcher.chart(), rules_set->debug_information()),
ElementsAre(IsTerminal(0, 1, "the"), IsTerminal(1, 2, "the"),
IsTerminal(2, 4, "the"), IsTerminal(3, 4, "the"),
IsNonterminal(4, 6, "<output_3>"),
IsNonterminal(4, 6, "<thestarbrownfox>"),
IsNonterminal(3, 6, "<thestarbrownfox>"),
IsNonterminal(2, 6, "<thestarbrownfox>"),
IsNonterminal(1, 6, "<thestarbrownfox>"),
IsNonterminal(0, 6, "<thestarbrownfox>")));
}
TEST_F(MatcherTest, HandlesManualAddParseTreeCalls) {
const std::string rules_buffer = CreateTestGrammar();
const RulesSet* rules_set =
flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
Matcher matcher(&unilib_, rules_set, &arena_);
// Test having the lexer call AddParseTree() instead of AddTerminal()
matcher.AddTerminal(-4, 37, "the");
matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
FindNontermForName(rules_set, "<thestarbrownfox>"), CodepointSpan{37, 42},
/*match_offset=*/37, ParseTree::Type::kDefault));
EXPECT_THAT(GetMatchResults(matcher.chart(), rules_set->debug_information()),
ElementsAre(IsTerminal(-4, 37, "the"),
IsNonterminal(-4, 42, "<thestarbrownfox>")));
}
TEST_F(MatcherTest, HandlesOptionalRuleElements) {
grammar::LocaleShardMap locale_shard_map =
grammar::LocaleShardMap::CreateLocaleShardMap({""});
Rules rules(locale_shard_map);
rules.Add("<output_0>", {"a?", "b?", "c?", "d?", "e"},
static_cast<CallbackId>(DefaultCallback::kRootRule));
rules.Add("<output_1>", {"a", "b?", "c", "d?", "e"},
static_cast<CallbackId>(DefaultCallback::kRootRule));
rules.Add("<output_2>", {"a", "b?", "c", "d", "e?"},
static_cast<CallbackId>(DefaultCallback::kRootRule));
const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer(
/*include_debug_information=*/true);
const RulesSet* rules_set =
flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
Matcher matcher(&unilib_, rules_set, &arena_);
// Run the matcher on "a b c d e".
matcher.AddTerminal(0, 1, "a");
matcher.AddTerminal(1, 2, "b");
matcher.AddTerminal(2, 3, "c");
matcher.AddTerminal(3, 4, "d");
matcher.AddTerminal(4, 5, "e");
EXPECT_THAT(
GetMatchResults(matcher.chart(), rules_set->debug_information()),
ElementsAre(
IsNonterminal(0, 4, "<output_2>"), IsTerminal(4, 5, "e"),
IsNonterminal(0, 5, "<output_0>"), IsNonterminal(0, 5, "<output_1>"),
IsNonterminal(0, 5, "<output_2>"), IsNonterminal(1, 5, "<output_0>"),
IsNonterminal(2, 5, "<output_0>"),
IsNonterminal(3, 5, "<output_0>")));
}
TEST_F(MatcherTest, HandlesWhitespaceGapLimits) {
grammar::LocaleShardMap locale_shard_map =
grammar::LocaleShardMap::CreateLocaleShardMap({""});
Rules rules(locale_shard_map);
rules.Add("<iata>", {"lx"});
rules.Add("<iata>", {"aa"});
// Require no whitespace between code and flight number.
rules.Add("<flight_number>", {"<iata>", "<4_digits>"},
/*callback=*/static_cast<CallbackId>(DefaultCallback::kRootRule), 0,
/*max_whitespace_gap=*/0);
const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer(
/*include_debug_information=*/true);
const RulesSet* rules_set =
flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
// Check that the grammar triggers on LX1138.
{
Matcher matcher(&unilib_, rules_set, &arena_);
matcher.AddTerminal(0, 2, "LX");
matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
rules_set->nonterminals()->n_digits_nt()->Get(4 - 1),
CodepointSpan{2, 6}, /*match_offset=*/2, ParseTree::Type::kDefault));
EXPECT_THAT(
GetMatchResults(matcher.chart(), rules_set->debug_information()),
ElementsAre(IsNonterminal(0, 6, "<flight_number>")));
}
// Check that the grammar doesn't trigger on LX 1138.
{
Matcher matcher(&unilib_, rules_set, &arena_);
matcher.AddTerminal(6, 8, "LX");
matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
rules_set->nonterminals()->n_digits_nt()->Get(4 - 1),
CodepointSpan{9, 13}, /*match_offset=*/8, ParseTree::Type::kDefault));
EXPECT_THAT(
GetMatchResults(matcher.chart(), rules_set->debug_information()),
IsEmpty());
}
}
TEST_F(MatcherTest, HandlesCaseSensitiveTerminals) {
grammar::LocaleShardMap locale_shard_map =
grammar::LocaleShardMap::CreateLocaleShardMap({""});
Rules rules(locale_shard_map);
rules.Add("<iata>", {"LX"}, /*callback=*/kNoCallback, 0,
/*max_whitespace_gap*/ -1, /*case_sensitive=*/true);
rules.Add("<iata>", {"AA"}, /*callback=*/kNoCallback, 0,
/*max_whitespace_gap*/ -1, /*case_sensitive=*/true);
rules.Add("<iata>", {"dl"}, /*callback=*/kNoCallback, 0,
/*max_whitespace_gap*/ -1, /*case_sensitive=*/false);
// Require no whitespace between code and flight number.
rules.Add("<flight_number>", {"<iata>", "<4_digits>"},
/*callback=*/static_cast<CallbackId>(DefaultCallback::kRootRule), 0,
/*max_whitespace_gap=*/0);
const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer(
/*include_debug_information=*/true);
const RulesSet* rules_set =
flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
// Check that the grammar triggers on LX1138.
{
Matcher matcher(&unilib_, rules_set, &arena_);
matcher.AddTerminal(0, 2, "LX");
matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
rules_set->nonterminals()->n_digits_nt()->Get(4 - 1),
CodepointSpan{2, 6}, /*match_offset=*/2, ParseTree::Type::kDefault));
EXPECT_THAT(
GetMatchResults(matcher.chart(), rules_set->debug_information()),
ElementsAre(IsNonterminal(0, 6, "<flight_number>")));
}
// Check that the grammar doesn't trigger on lx1138.
{
Matcher matcher(&unilib_, rules_set, &arena_);
matcher.AddTerminal(6, 8, "lx");
matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
rules_set->nonterminals()->n_digits_nt()->Get(4 - 1),
CodepointSpan{8, 12}, /*match_offset=*/8, ParseTree::Type::kDefault));
EXPECT_THAT(matcher.chart().derivations(), IsEmpty());
}
// Check that the grammar does trigger on dl1138.
{
Matcher matcher(&unilib_, rules_set, &arena_);
matcher.AddTerminal(12, 14, "dl");
matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
rules_set->nonterminals()->n_digits_nt()->Get(4 - 1),
CodepointSpan{14, 18}, /*match_offset=*/14, ParseTree::Type::kDefault));
EXPECT_THAT(
GetMatchResults(matcher.chart(), rules_set->debug_information()),
ElementsAre(IsNonterminal(12, 18, "<flight_number>")));
}
}
TEST_F(MatcherTest, HandlesExclusions) {
grammar::LocaleShardMap locale_shard_map =
grammar::LocaleShardMap::CreateLocaleShardMap({""});
Rules rules(locale_shard_map);
rules.Add("<all_zeros>", {"0000"});
rules.AddWithExclusion("<flight_code>", {"<4_digits>"},
/*excluded_nonterminal=*/"<all_zeros>");
rules.Add("<iata>", {"lx"});
rules.Add("<iata>", {"aa"});
rules.Add("<iata>", {"dl"});
// Require no whitespace between code and flight number.
rules.Add("<flight_number>", {"<iata>", "<flight_code>"},
static_cast<CallbackId>(DefaultCallback::kRootRule));
const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer(
/*include_debug_information=*/true);
const RulesSet* rules_set =
flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
// Check that the grammar triggers on LX1138.
{
Matcher matcher(&unilib_, rules_set, &arena_);
matcher.AddTerminal(0, 2, "LX");
matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
rules_set->nonterminals()->n_digits_nt()->Get(4 - 1),
CodepointSpan{2, 6}, /*match_offset=*/2, ParseTree::Type::kDefault));
matcher.Finish();
EXPECT_THAT(
GetMatchResults(matcher.chart(), rules_set->debug_information()),
ElementsAre(IsNonterminal(0, 6, "<flight_number>")));
}
// Check that the grammar doesn't trigger on LX0000.
{
Matcher matcher(&unilib_, rules_set, &arena_);
matcher.AddTerminal(6, 8, "LX");
matcher.AddTerminal(8, 12, "0000");
matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
rules_set->nonterminals()->n_digits_nt()->Get(4 - 1),
CodepointSpan{8, 12}, /*match_offset=*/8, ParseTree::Type::kDefault));
matcher.Finish();
EXPECT_THAT(matcher.chart().derivations(), IsEmpty());
}
}
} // namespace
} // namespace libtextclassifier3::grammar