You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
321 lines
13 KiB
321 lines
13 KiB
/*
|
|
* Copyright (C) 2018 The Android Open Source Project
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include "utils/grammar/parsing/parser.h"
|
|
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
#include "utils/grammar/parsing/derivation.h"
|
|
#include "utils/grammar/rules_generated.h"
|
|
#include "utils/grammar/testing/utils.h"
|
|
#include "utils/grammar/types.h"
|
|
#include "utils/grammar/utils/ir.h"
|
|
#include "utils/grammar/utils/rules.h"
|
|
#include "utils/i18n/locale.h"
|
|
#include "utils/tokenizer.h"
|
|
#include "utils/utf8/unicodetext.h"
|
|
#include "utils/utf8/unilib.h"
|
|
#include "gmock/gmock.h"
|
|
#include "gtest/gtest.h"
|
|
|
|
namespace libtextclassifier3::grammar {
|
|
namespace {
|
|
|
|
using ::testing::ElementsAre;
|
|
using ::testing::IsEmpty;
|
|
|
|
class ParserTest : public GrammarTest {};
|
|
|
|
TEST_F(ParserTest, ParsesSimpleRules) {
  // Callback parameter identifying the root <date> rule in derivations.
  constexpr int kDate = 0;

  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Rules grammar_rules(shard_map);
  // Date components are fixed-length digit runs.
  grammar_rules.Add("<day>", {"<2_digits>"});
  grammar_rules.Add("<month>", {"<2_digits>"});
  grammar_rules.Add("<year>", {"<4_digits>"});
  // A date is written as year/month/day.
  grammar_rules.Add("<date>", {"<year>", "/", "<month>", "/", "<day>"},
                    static_cast<CallbackId>(DefaultCallback::kRootRule), kDate);

  const std::string serialized_rules =
      grammar_rules.Finalize().SerializeAsFlatbuffer();
  const RulesSet* rules_set =
      flatbuffers::GetRoot<RulesSet>(serialized_rules.data());
  Parser parser(unilib_.get(), rules_set);

  // "2020/05/08" occupies codepoints [7, 17) of the input.
  const auto event_text = TextContextForText("Event: 2020/05/08");
  EXPECT_THAT(ValidDeduplicatedDerivations(parser.Parse(event_text, &arena_)),
              ElementsAre(IsDerivation(kDate, 7, 17)));
}
|
|
|
|
TEST_F(ParserTest, HandlesEmptyInput) {
  // Callback parameter identifying the root <test> rule in derivations.
  constexpr int kTest = 0;

  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Rules grammar_rules(shard_map);
  grammar_rules.Add("<test>", {"test"},
                    static_cast<CallbackId>(DefaultCallback::kRootRule), kTest);

  const std::string serialized_rules =
      grammar_rules.Finalize().SerializeAsFlatbuffer();
  const RulesSet* rules_set =
      flatbuffers::GetRoot<RulesSet>(serialized_rules.data());
  Parser parser(unilib_.get(), rules_set);

  // Baseline: the grammar does fire on ordinary, non-empty input.
  const auto event_text = TextContextForText("Event: test");
  EXPECT_THAT(ValidDeduplicatedDerivations(parser.Parse(event_text, &arena_)),
              ElementsAre(IsDerivation(kTest, 7, 11)));

  // Empty and whitespace-only inputs must produce no derivations.
  for (const char* blank_input : {"", " "}) {
    EXPECT_THAT(ValidDeduplicatedDerivations(
                    parser.Parse(TextContextForText(blank_input), &arena_)),
                IsEmpty());
  }
}
|
|
|
|
TEST_F(ParserTest, HandlesUppercaseTokens) {
  // Callback parameter identifying the root <test> rule in derivations.
  constexpr int kScriptedReply = 0;

  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Rules grammar_rules(shard_map);
  // Optional "please", then "reply", then a token written in upper case.
  grammar_rules.Add("<test>", {"please?", "reply", "<uppercase_token>"},
                    static_cast<CallbackId>(DefaultCallback::kRootRule),
                    kScriptedReply);

  const std::string serialized_rules =
      grammar_rules.Finalize().SerializeAsFlatbuffer();
  const RulesSet* rules_set =
      flatbuffers::GetRoot<RulesSet>(serialized_rules.data());
  Parser parser(unilib_.get(), rules_set);

  // "STOP" is upper case, so "Reply STOP" (codepoints [0, 10)) matches.
  const auto uppercase_text = TextContextForText("Reply STOP to cancel.");
  EXPECT_THAT(
      ValidDeduplicatedDerivations(parser.Parse(uppercase_text, &arena_)),
      ElementsAre(IsDerivation(kScriptedReply, 0, 10)));

  // "stop" is lower case, so nothing matches.
  const auto lowercase_text = TextContextForText("Reply stop to cancel.");
  EXPECT_THAT(
      ValidDeduplicatedDerivations(parser.Parse(lowercase_text, &arena_)),
      IsEmpty());
}
|
|
|
|
TEST_F(ParserTest, HandlesAnchors) {
  // Callback parameter identifying the root <test> rule in derivations.
  constexpr int kScriptedReply = 0;

  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Rules grammar_rules(shard_map);
  // <^> and <$> anchor the rule to the start and the end of the input, so the
  // whole text must be "reply" followed by a single upper-case token.
  grammar_rules.Add("<test>", {"<^>", "reply", "<uppercase_token>", "<$>"},
                    static_cast<CallbackId>(DefaultCallback::kRootRule),
                    kScriptedReply);

  const std::string serialized_rules =
      grammar_rules.Finalize().SerializeAsFlatbuffer();
  const RulesSet* rules_set =
      flatbuffers::GetRoot<RulesSet>(serialized_rules.data());
  Parser parser(unilib_.get(), rules_set);

  // The input consists solely of the anchored pattern: match.
  const auto anchored_text = TextContextForText("Reply STOP");
  EXPECT_THAT(
      ValidDeduplicatedDerivations(parser.Parse(anchored_text, &arena_)),
      ElementsAre(IsDerivation(kScriptedReply, 0, 10)));

  // Text both precedes and follows the pattern: no match.
  const auto embedded_text = TextContextForText("Please reply STOP to cancel.");
  EXPECT_THAT(
      ValidDeduplicatedDerivations(parser.Parse(embedded_text, &arena_)),
      IsEmpty());
}
|
|
|
|
TEST_F(ParserTest, HandlesWordBreaks) {
  // Callback parameter identifying the root <flight> rule in derivations.
  constexpr int kFlight = 0;

  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Rules grammar_rules(shard_map);
  grammar_rules.Add("<carrier>", {"lx"});
  grammar_rules.Add("<carrier>", {"aa"});
  // Carrier and digits, followed by the word-break nonterminal.
  grammar_rules.Add("<flight>", {"<carrier>", "<digits>", "<\b>"},
                    static_cast<CallbackId>(DefaultCallback::kRootRule),
                    kFlight);

  const std::string serialized_rules =
      grammar_rules.Finalize().SerializeAsFlatbuffer();
  const RulesSet* rules_set =
      flatbuffers::GetRoot<RulesSet>(serialized_rules.data());
  Parser parser(unilib_.get(), rules_set);

  // "LX 38" followed by a sentence-ending period is recognized.
  const auto sentence_text =
      TextContextForText("My flight is: LX 38. Arriving later");
  EXPECT_THAT(
      ValidDeduplicatedDerivations(parser.Parse(sentence_text, &arena_)),
      ElementsAre(IsDerivation(kFlight, 14, 19)));

  // The grammar must not trigger inside a number such as "LX 38.00".
  const auto decimal_text = TextContextForText("LX 38.00");
  EXPECT_THAT(
      ValidDeduplicatedDerivations(parser.Parse(decimal_text, &arena_)),
      IsEmpty());
}
|
|
|
|
TEST_F(ParserTest, HandlesAnnotations) {
  grammar::LocaleShardMap locale_shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Rules rules(locale_shard_map);
  constexpr int kCallPhone = 0;
  // "dial" followed by a span that an external annotator marked as a phone.
  // The nonterminal name only needs to be unique; the derivation is identified
  // by kCallPhone.
  rules.Add("<call_phone>", {"dial", "<phone>"},
            static_cast<CallbackId>(DefaultCallback::kRootRule), kCallPhone);
  // Bind <phone> to input annotations classified as "phone".
  rules.BindAnnotation("<phone>", "phone");
  const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer();
  Parser parser(unilib_.get(),
                flatbuffers::GetRoot<RulesSet>(rules_buffer.data()));

  TextContext context = TextContextForText("Please dial 911");

  // Sanity check that we don't trigger if we don't feed the correct
  // annotations.
  EXPECT_THAT(ValidDeduplicatedDerivations(parser.Parse(context, &arena_)),
              IsEmpty());

  // Create a phone annotation covering "911" (codepoints [12, 15)).
  AnnotatedSpan phone_span;
  phone_span.span = CodepointSpan{12, 15};
  phone_span.classification.emplace_back("phone", 1.0);
  context.annotations.push_back(phone_span);
  // With the annotation present, "dial 911" (codepoints [7, 15)) matches.
  EXPECT_THAT(ValidDeduplicatedDerivations(parser.Parse(context, &arena_)),
              ElementsAre(IsDerivation(kCallPhone, 7, 15)));
}
|
|
|
|
TEST_F(ParserTest, HandlesRegexAnnotators) {
  // Callback parameter identifying the root <test> rule in derivations.
  constexpr int kScriptedReply = 0;

  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Rules grammar_rules(shard_map);
  // <code> is either a quoted word, or an optionally quoted token made of
  // upper-case letters (optionally followed by digits) or a single digit.
  grammar_rules.AddRegex(
      "<code>", "(\"([A-Za-z]+)\"|\\b\"?(?:[A-Z]+[0-9]*|[0-9])\"?\\b)");
  grammar_rules.Add("<test>", {"please?", "reply", "<code>"},
                    static_cast<CallbackId>(DefaultCallback::kRootRule),
                    kScriptedReply);

  const std::string serialized_rules =
      grammar_rules.Finalize().SerializeAsFlatbuffer();
  const RulesSet* rules_set =
      flatbuffers::GetRoot<RulesSet>(serialized_rules.data());
  Parser parser(unilib_.get(), rules_set);

  // "STOP" satisfies the regex, so "Reply STOP" matches.
  const auto uppercase_text = TextContextForText("Reply STOP to cancel.");
  EXPECT_THAT(
      ValidDeduplicatedDerivations(parser.Parse(uppercase_text, &arena_)),
      ElementsAre(IsDerivation(kScriptedReply, 0, 10)));

  // "Stop" (mixed case, unquoted) does not satisfy the regex.
  const auto mixed_case_text = TextContextForText("Reply Stop to cancel.");
  EXPECT_THAT(
      ValidDeduplicatedDerivations(parser.Parse(mixed_case_text, &arena_)),
      IsEmpty());
}
|
|
|
|
TEST_F(ParserTest, HandlesExclusions) {
  // Callback parameter identifying the root <set_reminder> rule.
  constexpr int kSetReminder = 0;

  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Rules grammar_rules(shard_map);
  // Any two tokens, except the phrase "be safe".
  grammar_rules.Add("<excluded>", {"be", "safe"});
  grammar_rules.AddWithExclusion("<tokens_but_not_excluded>",
                                 {"<token>", "<token>"},
                                 /*excluded_nonterminal=*/"<excluded>");
  grammar_rules.Add("<set_reminder>",
                    {"do", "not", "forget", "to", "<tokens_but_not_excluded>"},
                    static_cast<CallbackId>(DefaultCallback::kRootRule),
                    kSetReminder);

  const std::string serialized_rules =
      grammar_rules.Finalize().SerializeAsFlatbuffer();
  const RulesSet* rules_set =
      flatbuffers::GetRoot<RulesSet>(serialized_rules.data());
  Parser parser(unilib_.get(), rules_set);

  // "be there" is not excluded: the whole sentence matches.
  const auto allowed_text = TextContextForText("do not forget to be there");
  EXPECT_THAT(
      ValidDeduplicatedDerivations(parser.Parse(allowed_text, &arena_)),
      ElementsAre(IsDerivation(kSetReminder, 0, 25)));

  // "be safe" is the excluded phrase: no match.
  const auto excluded_text = TextContextForText("do not forget to be safe");
  EXPECT_THAT(
      ValidDeduplicatedDerivations(parser.Parse(excluded_text, &arena_)),
      IsEmpty());
}
|
|
|
|
TEST_F(ParserTest, HandlesFillers) {
  // Callback parameter identifying the root <set_reminder> rule.
  constexpr int kSetReminder = 0;

  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Rules grammar_rules(shard_map);
  grammar_rules.Add("<set_reminder>", {"do", "not", "forget", "to", "<filler>"},
                    static_cast<CallbackId>(DefaultCallback::kRootRule),
                    kSetReminder);

  const std::string serialized_rules =
      grammar_rules.Finalize().SerializeAsFlatbuffer();
  const RulesSet* rules_set =
      flatbuffers::GetRoot<RulesSet>(serialized_rules.data());
  Parser parser(unilib_.get(), rules_set);

  // <filler> absorbs the trailing "be there", so the match spans the full
  // 25-codepoint input.
  const auto reminder_text = TextContextForText("do not forget to be there");
  EXPECT_THAT(
      ValidDeduplicatedDerivations(parser.Parse(reminder_text, &arena_)),
      ElementsAre(IsDerivation(kSetReminder, 0, 25)));
}
|
|
|
|
TEST_F(ParserTest, HandlesAssertions) {
  // Callback parameter identifying the root <track_flight> rule.
  constexpr int kFlight = 0;

  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Rules grammar_rules(shard_map);
  grammar_rules.Add("<carrier>", {"lx"});
  grammar_rules.Add("<carrier>", {"aa"});
  grammar_rules.Add("<flight_code>", {"<2_digits>"});
  grammar_rules.Add("<flight_code>", {"<3_digits>"});
  grammar_rules.Add("<flight_code>", {"<4_digits>"});
  // Flight: carrier + flight code, with its right context checked by
  // <context_assertion>.
  grammar_rules.Add("<track_flight>",
                    {"<carrier>", "<flight_code>", "<context_assertion>?"},
                    static_cast<CallbackId>(DefaultCallback::kRootRule),
                    kFlight);
  // Negative assertion: reject a flight followed by an optional "." and
  // digits, i.e. matches like "LX 38.00".
  grammar_rules.AddAssertion("<context_assertion>", {".?", "<digits>"},
                             /*negative=*/true);

  const std::string serialized_rules =
      grammar_rules.Finalize().SerializeAsFlatbuffer();
  const RulesSet* rules_set =
      flatbuffers::GetRoot<RulesSet>(serialized_rules.data());
  Parser parser(unilib_.get(), rules_set);

  // "LX38" and "aa 44" match; "LX 38.38" is rejected by the assertion.
  const auto flights_text = TextContextForText("LX38 aa 44 LX 38.38");
  EXPECT_THAT(
      ValidDeduplicatedDerivations(parser.Parse(flights_text, &arena_)),
      ElementsAre(IsDerivation(kFlight, 0, 4), IsDerivation(kFlight, 5, 10)));
}
|
|
|
|
TEST_F(ParserTest, HandlesWhitespaceGapLimit) {
  grammar::LocaleShardMap locale_shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Rules rules(locale_shard_map);
  rules.Add("<carrier>", {"lx"});
  rules.Add("<carrier>", {"aa"});
  rules.Add("<flight_code>", {"<2_digits>"});
  rules.Add("<flight_code>", {"<3_digits>"});
  rules.Add("<flight_code>", {"<4_digits>"});
  // Flight: carrier immediately followed by the flight code. A maximum
  // whitespace gap of 0 forbids any whitespace between the two parts.
  constexpr int kFlight = 0;
  rules.Add("<track_flight>", {"<carrier>", "<flight_code>"},
            static_cast<CallbackId>(DefaultCallback::kRootRule), kFlight,
            /*max_whitespace_gap=*/0);
  const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer();
  Parser parser(unilib_.get(),
                flatbuffers::GetRoot<RulesSet>(rules_buffer.data()));

  // Only "LX38" is written without a space; "aa 44" and "LX 38" exceed the
  // whitespace gap limit and are rejected.
  EXPECT_THAT(ValidDeduplicatedDerivations(parser.Parse(
                  TextContextForText("LX38 aa 44 LX 38"), &arena_)),
              ElementsAre(IsDerivation(kFlight, 0, 4)));
}
|
|
|
|
TEST_F(ParserTest, HandlesCaseSensitiveMatching) {
  grammar::LocaleShardMap locale_shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Rules rules(locale_shard_map);
  // Carriers only match with the exact casing given here ("Lx", "AA").
  rules.Add("<carrier>", {"Lx"}, /*callback=*/kNoCallback, /*callback_param=*/0,
            /*max_whitespace_gap=*/-1, /*case_sensitive=*/true);
  rules.Add("<carrier>", {"AA"}, /*callback=*/kNoCallback, /*callback_param=*/0,
            /*max_whitespace_gap=*/-1, /*case_sensitive=*/true);
  rules.Add("<flight_code>", {"<2_digits>"});
  rules.Add("<flight_code>", {"<3_digits>"});
  rules.Add("<flight_code>", {"<4_digits>"});
  // Flight: carrier followed by a flight code.
  constexpr int kFlight = 0;
  rules.Add("<track_flight>", {"<carrier>", "<flight_code>"},
            static_cast<CallbackId>(DefaultCallback::kRootRule), kFlight);
  const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer();
  Parser parser(unilib_.get(),
                flatbuffers::GetRoot<RulesSet>(rules_buffer.data()));

  // "Lx38" and "AA 44" use the exact casing and match; "LX 38" does not,
  // because "LX" differs in case from the rule's "Lx".
  EXPECT_THAT(
      ValidDeduplicatedDerivations(
          parser.Parse(TextContextForText("Lx38 AA 44 LX 38"), &arena_)),
      ElementsAre(IsDerivation(kFlight, 0, 4), IsDerivation(kFlight, 5, 10)));
}
|
|
|
|
} // namespace
|
|
} // namespace libtextclassifier3::grammar
|