/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "actions/feature-processor.h"

#include <vector>

#include "actions/actions_model_generated.h"
#include "annotator/model-executor.h"
#include "utils/tensor-view.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"

namespace libtextclassifier3 {
namespace {

using ::testing::FloatEq;
using ::testing::SizeIs;

// EmbeddingExecutor that always returns features based on
// the id of the sparse features.
//
// The produced "embedding" has four entries: the first sparse feature value
// twice, followed by its negation twice.
class FakeEmbeddingExecutor : public EmbeddingExecutor {
 public:
  // Fills `dest[0..3]` from the single value in `sparse_features`.
  // Requires `dest_size >= 4`; always returns true.
  bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
                    const int dest_size) const override {
    TC3_CHECK_GE(dest_size, 4);
    EXPECT_THAT(sparse_features, SizeIs(1));
    dest[0] = sparse_features.data()[0];
    dest[1] = sparse_features.data()[0];
    dest[2] = -sparse_features.data()[0];
    dest[3] = -sparse_features.data()[0];
    return true;
  }
};

// Fixture providing a fake embedding executor, a UniLib instance and a helper
// to serialize feature processor options.
class ActionsFeatureProcessorTest : public ::testing::Test {
 protected:
  ActionsFeatureProcessorTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}

  // Serializes `options` into a finished flatbuffer and returns the detached
  // buffer; the caller keeps it alive for as long as the root is in use.
  flatbuffers::DetachedBuffer PackFeatureProcessorOptions(
      ActionsTokenFeatureProcessorOptionsT* options) const {
    flatbuffers::FlatBufferBuilder builder;
    builder.Finish(CreateActionsTokenFeatureProcessorOptions(builder, options));
    return builder.Release();
  }

  FakeEmbeddingExecutor embedding_executor_;
  UniLib unilib_;
};

// A single token with embedding_size = 4 yields exactly four features.
TEST_F(ActionsFeatureProcessorTest, TokenEmbeddings) {
  ActionsTokenFeatureProcessorOptionsT options;
  options.embedding_size = 4;
  options.tokenizer_options.reset(new ActionsTokenizerOptionsT);

  flatbuffers::DetachedBuffer options_fb =
      PackFeatureProcessorOptions(&options);
  ActionsFeatureProcessor feature_processor(
      flatbuffers::GetRoot<ActionsTokenFeatureProcessorOptions>(
          options_fb.data()),
      &unilib_);

  Token token("aaa", 0, 3);
  std::vector<float> token_features;
  EXPECT_TRUE(feature_processor.AppendTokenFeatures(token, &embedding_executor_,
                                                    &token_features));
  EXPECT_THAT(token_features, SizeIs(4));
}

// With extract_case_feature enabled, one extra feature follows the embedding;
// for the token "Aaa" it is expected to be 1.0.
TEST_F(ActionsFeatureProcessorTest, TokenEmbeddingsCaseFeature) {
  ActionsTokenFeatureProcessorOptionsT options;
  options.embedding_size = 4;
  options.extract_case_feature = true;
  options.tokenizer_options.reset(new ActionsTokenizerOptionsT);

  flatbuffers::DetachedBuffer options_fb =
      PackFeatureProcessorOptions(&options);
  ActionsFeatureProcessor feature_processor(
      flatbuffers::GetRoot<ActionsTokenFeatureProcessorOptions>(
          options_fb.data()),
      &unilib_);

  Token token("Aaa", 0, 3);
  std::vector<float> token_features;
  EXPECT_TRUE(feature_processor.AppendTokenFeatures(token, &embedding_executor_,
                                                    &token_features));
  // Fatal size check: the element access below must not run on a short vector.
  ASSERT_THAT(token_features, SizeIs(5));
  EXPECT_THAT(token_features[4], FloatEq(1.0));
}

// Appending several tokens at once produces five features per token
// (four embedding values plus the case feature), laid out back to back.
TEST_F(ActionsFeatureProcessorTest, MultipleTokenEmbeddingsCaseFeature) {
  ActionsTokenFeatureProcessorOptionsT options;
  options.embedding_size = 4;
  options.extract_case_feature = true;
  options.tokenizer_options.reset(new ActionsTokenizerOptionsT);

  flatbuffers::DetachedBuffer options_fb =
      PackFeatureProcessorOptions(&options);
  ActionsFeatureProcessor feature_processor(
      flatbuffers::GetRoot<ActionsTokenFeatureProcessorOptions>(
          options_fb.data()),
      &unilib_);

  const std::vector<Token> tokens = {Token("Aaa", 0, 3), Token("bbb", 4, 7),
                                     Token("Cccc", 8, 12)};
  std::vector<float> token_features;
  EXPECT_TRUE(feature_processor.AppendTokenFeatures(
      tokens, &embedding_executor_, &token_features));
  // Fatal size check: the element accesses below must not run on a short
  // vector.
  ASSERT_THAT(token_features, SizeIs(15));
  EXPECT_THAT(token_features[4], FloatEq(1.0));    // "Aaa"
  EXPECT_THAT(token_features[9], FloatEq(-1.0));   // "bbb"
  EXPECT_THAT(token_features[14], FloatEq(1.0));   // "Cccc"
}

}  // namespace
}  // namespace libtextclassifier3