/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "utils/grammar/semantics/composer.h"

#include <memory>
#include <string>
#include <vector>

#include "utils/flatbuffers/flatbuffers.h"
#include "utils/flatbuffers/reflection.h"
#include "utils/grammar/parsing/derivation.h"
#include "utils/grammar/parsing/parser.h"
#include "utils/grammar/rules_generated.h"
#include "utils/grammar/semantics/expression_generated.h"
#include "utils/grammar/testing/utils.h"
#include "utils/grammar/testing/value_generated.h"
#include "utils/grammar/types.h"
#include "utils/grammar/utils/rules.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"

namespace libtextclassifier3::grammar {
namespace {

using ::testing::ElementsAre;

class SemanticComposerTest : public GrammarTest {};

// Appends to `model->semantic_expression` a constant-value expression whose
// value is a serialized `TestValue` holding `value`.
// Returns the index of the new expression, which is what a rule with the
// `kSemanticExpression` callback takes as its callback parameter.
int AddConstTestValueExpression(const int test_value_type, const int value,
                                RulesSetT* model) {
  TestValueT test_value;
  test_value.value = value;
  const std::string serialized_value = PackFlatbuffer<TestValue>(&test_value);

  ConstValueExpressionT const_value;
  const_value.base_type = reflection::BaseType::Obj;
  const_value.type = test_value_type;
  const_value.value.assign(serialized_value.begin(), serialized_value.end());

  const int index = static_cast<int>(model->semantic_expression.size());
  model->semantic_expression.push_back(std::make_unique<SemanticExpressionT>());
  model->semantic_expression.back()->expression.Set(const_value);
  return index;
}

// A rule whose semantic expression is a constant evaluates to that constant.
TEST_F(SemanticComposerTest, EvaluatesSimpleMapping) {
  RulesSetT model;
  grammar::LocaleShardMap locale_shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Rules rules(locale_shard_map);
  const int test_value_type =
      TypeIdForName(semantic_values_schema_.get(),
                    "libtextclassifier3.grammar.TestValue")
          .value();

  // Each month name maps to its own constant `TestValue`.
  const int january_expression =
      AddConstTestValueExpression(test_value_type, /*value=*/1, &model);
  rules.Add("<month>", {"january"},
            static_cast<CallbackId>(DefaultCallback::kSemanticExpression),
            /*callback_param=*/january_expression);
  const int february_expression =
      AddConstTestValueExpression(test_value_type, /*value=*/2, &model);
  rules.Add("<month>", {"february"},
            static_cast<CallbackId>(DefaultCallback::kSemanticExpression),
            /*callback_param=*/february_expression);

  const int kMonth = 0;
  rules.Add("<month_rule>", {"<month>"},
            static_cast<CallbackId>(DefaultCallback::kRootRule), kMonth);
  rules.Finalize().Serialize(/*include_debug_information=*/false, &model);
  const std::string model_buffer = PackFlatbuffer<RulesSet>(&model);
  Parser parser(unilib_.get(),
                flatbuffers::GetRoot<RulesSet>(model_buffer.data()));
  SemanticComposer composer(semantic_values_schema_.get());

  {
    const TextContext text = TextContextForText("Month: January");
    const std::vector<Derivation> derivations = parser.Parse(text, &arena_);
    // ASSERT rather than EXPECT: `derivations.front()` below is undefined
    // behaviour on an empty vector.
    ASSERT_THAT(derivations, ElementsAre(IsDerivation(kMonth, 7, 14)));

    StatusOr<const SemanticValue*> maybe_value =
        composer.Eval(text, derivations.front(), &arena_);
    // ASSERT rather than EXPECT: `ValueOrDie()` must not be called on an error.
    ASSERT_TRUE(maybe_value.ok());

    const TestValue* value = maybe_value.ValueOrDie()->Table<TestValue>();
    ASSERT_NE(value, nullptr);
    EXPECT_EQ(value->value(), 1);
  }

  {
    const TextContext text = TextContextForText("Month: February");
    const std::vector<Derivation> derivations = parser.Parse(text, &arena_);
    ASSERT_THAT(derivations, ElementsAre(IsDerivation(kMonth, 7, 15)));

    StatusOr<const SemanticValue*> maybe_value =
        composer.Eval(text, derivations.front(), &arena_);
    ASSERT_TRUE(maybe_value.ok());

    const TestValue* value = maybe_value.ValueOrDie()->Table<TestValue>();
    ASSERT_NE(value, nullptr);
    EXPECT_EQ(value->value(), 2);
  }
}

// A constituent expression evaluates to the semantic value of the referenced
// constituent (here: the constant attached to `<month>`).
TEST_F(SemanticComposerTest, RecursivelyEvaluatesConstituents) {
  RulesSetT model;
  grammar::LocaleShardMap locale_shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Rules rules(locale_shard_map);
  const int test_value_type =
      TypeIdForName(semantic_values_schema_.get(),
                    "libtextclassifier3.grammar.TestValue")
          .value();
  constexpr int kDateRule = 0;

  const int month_expression =
      AddConstTestValueExpression(test_value_type, /*value=*/42, &model);
  rules.Add("<month>", {"january"},
            static_cast<CallbackId>(DefaultCallback::kSemanticExpression),
            /*callback_param=*/month_expression);

  {
    // Define constituents of the rule.
    // TODO(smillius): Add support in the rules builder to directly specify
    // constituent ids in the rule, e.g. `<date> ::= <month>@0? <4_digits>`.
    // Until then, a `kMapping` rule tags `<month>` matches with constituent id
    // `kMonthConstituentId`.
    constexpr int kMonthConstituentId = 1;
    rules.Add("<date_@0>", {"<month>"},
              static_cast<CallbackId>(DefaultCallback::kMapping),
              /*callback_param=*/kMonthConstituentId);

    // `<date>` evaluates to the value of its (optional) month constituent.
    ConstituentExpressionT constituent;
    constituent.id = kMonthConstituentId;
    const int date_expression =
        static_cast<int>(model.semantic_expression.size());
    model.semantic_expression.push_back(std::make_unique<SemanticExpressionT>());
    model.semantic_expression.back()->expression.Set(constituent);
    rules.Add("<date>", {"<date_@0>?", "<4_digits>"},
              static_cast<CallbackId>(DefaultCallback::kSemanticExpression),
              /*callback_param=*/date_expression);

    rules.Add("<date_rule>", {"<date>"},
              static_cast<CallbackId>(DefaultCallback::kRootRule),
              /*callback_param=*/kDateRule);
  }

  rules.Finalize().Serialize(/*include_debug_information=*/false, &model);
  const std::string model_buffer = PackFlatbuffer<RulesSet>(&model);
  Parser parser(unilib_.get(),
                flatbuffers::GetRoot<RulesSet>(model_buffer.data()));
  SemanticComposer composer(semantic_values_schema_.get());

  {
    const TextContext text = TextContextForText("Event: January 2020");
    const std::vector<Derivation> derivations =
        ValidDeduplicatedDerivations(parser.Parse(text, &arena_));
    // ASSERT rather than EXPECT: `derivations.front()` below is undefined
    // behaviour on an empty vector.
    ASSERT_THAT(derivations, ElementsAre(IsDerivation(kDateRule, 7, 19)));

    StatusOr<const SemanticValue*> maybe_value =
        composer.Eval(text, derivations.front(), &arena_);
    ASSERT_TRUE(maybe_value.ok());

    const TestValue* value = maybe_value.ValueOrDie()->Table<TestValue>();
    ASSERT_NE(value, nullptr);
    EXPECT_EQ(value->value(), 42);
  }
}

}  // namespace
}  // namespace libtextclassifier3::grammar