/* * Copyright (C) 2017 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include "HashtableLookup.h" #include "NeuralNetworksWrapper.h" using ::testing::FloatNear; using ::testing::Matcher; namespace android { namespace nn { namespace wrapper { namespace { std::vector> ArrayFloatNear(const std::vector& values, float max_abs_error = 1.e-6) { std::vector> matchers; matchers.reserve(values.size()); for (const float& v : values) { matchers.emplace_back(FloatNear(v, max_abs_error)); } return matchers; } } // namespace using ::testing::ElementsAreArray; #define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \ ACTION(Lookup, int) \ ACTION(Key, int) \ ACTION(Value, float) // For all output and intermediate states #define FOR_ALL_OUTPUT_TENSORS(ACTION) \ ACTION(Output, float) \ ACTION(Hits, uint8_t) class HashtableLookupOpModel { public: HashtableLookupOpModel(std::initializer_list lookup_shape, std::initializer_list key_shape, std::initializer_list value_shape) { auto it_vs = value_shape.begin(); rows_ = *it_vs++; features_ = *it_vs; std::vector inputs; // Input and weights OperandType LookupTy(Type::TENSOR_INT32, lookup_shape); inputs.push_back(model_.addOperand(&LookupTy)); OperandType KeyTy(Type::TENSOR_INT32, key_shape); inputs.push_back(model_.addOperand(&KeyTy)); OperandType ValueTy(Type::TENSOR_FLOAT32, value_shape); inputs.push_back(model_.addOperand(&ValueTy)); // Output and other intermediate state std::vector outputs; std::vector out_dim(lookup_shape.begin(), lookup_shape.end()); out_dim.push_back(features_); OperandType OutputOpndTy(Type::TENSOR_FLOAT32, out_dim); outputs.push_back(model_.addOperand(&OutputOpndTy)); OperandType HitsOpndTy(Type::TENSOR_QUANT8_ASYMM, lookup_shape, 1.f, 0); outputs.push_back(model_.addOperand(&HitsOpndTy)); auto multiAll = [](const std::vector& dims) -> uint32_t { uint32_t sz = 1; for (uint32_t d : dims) { sz *= d; } return sz; }; Value_.insert(Value_.end(), multiAll(value_shape), 0.f); Output_.insert(Output_.end(), multiAll(out_dim), 0.f); Hits_.insert(Hits_.end(), multiAll(lookup_shape), 0); model_.addOperation(ANEURALNETWORKS_HASHTABLE_LOOKUP, inputs, outputs); model_.identifyInputsAndOutputs(inputs, outputs); model_.finish(); } void Invoke() { ASSERT_TRUE(model_.isValid()); Compilation compilation(&model_); compilation.finish(); Execution execution(&compilation); #define SetInputOrWeight(X, T) \ ASSERT_EQ(execution.setInput(HashtableLookup::k##X##Tensor, X##_.data(), \ sizeof(T) * X##_.size()), \ Result::NO_ERROR); FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight); #undef SetInputOrWeight #define SetOutput(X, T) \ ASSERT_EQ(execution.setOutput(HashtableLookup::k##X##Tensor, X##_.data(), \ sizeof(T) * X##_.size()), \ Result::NO_ERROR); FOR_ALL_OUTPUT_TENSORS(SetOutput); #undef SetOutput ASSERT_EQ(execution.compute(), Result::NO_ERROR); } #define DefineSetter(X, T) \ void Set##X(const std::vector& f) { X##_.insert(X##_.end(), f.begin(), f.end()); } FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter); #undef DefineSetter void SetHashtableValue(const std::function& function) { for (uint32_t i = 0; i < rows_; i++) { for (uint32_t j = 0; j < features_; j++) { Value_[i * features_ + j] = function(i, j); } } } const std::vector& GetOutput() const { return Output_; } const std::vector& GetHits() const { return Hits_; } private: Model model_; uint32_t rows_; uint32_t features_; #define DefineTensor(X, T) std::vector X##_; FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor); FOR_ALL_OUTPUT_TENSORS(DefineTensor); #undef DefineTensor }; TEST(HashtableLookupOpTest, BlackBoxTest) { HashtableLookupOpModel m({4}, {3}, {3, 2}); m.SetLookup({1234, -292, -11, 0}); m.SetKey({-11, 0, 1234}); m.SetHashtableValue([](int i, int j) { return i + j / 10.0f; }); m.Invoke(); EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ 2.0, 2.1, // 2-rd item 0, 0, // Not found 0.0, 0.1, // 0-th item 1.0, 1.1, // 1-st item }))); EXPECT_EQ(m.GetHits(), std::vector({ 1, 0, 1, 1, })); } } // namespace wrapper } // namespace nn } // namespace android