You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
194 lines
6.0 KiB
194 lines
6.0 KiB
4 months ago
|
/*
|
||
|
* Copyright (C) 2017 The Android Open Source Project
|
||
|
*
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
*/
|
||
|
|
||
|
#include <gmock/gmock.h>
|
||
|
#include <gtest/gtest.h>
|
||
|
|
||
|
#include <vector>
|
||
|
|
||
|
#include "HashtableLookup.h"
|
||
|
#include "NeuralNetworksWrapper.h"
|
||
|
|
||
|
using ::testing::FloatNear;
|
||
|
using ::testing::Matcher;
|
||
|
|
||
|
namespace android {
|
||
|
namespace nn {
|
||
|
namespace wrapper {
|
||
|
|
||
|
namespace {
|
||
|
|
||
|
std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
|
||
|
float max_abs_error = 1.e-6) {
|
||
|
std::vector<Matcher<float>> matchers;
|
||
|
matchers.reserve(values.size());
|
||
|
for (const float& v : values) {
|
||
|
matchers.emplace_back(FloatNear(v, max_abs_error));
|
||
|
}
|
||
|
return matchers;
|
||
|
}
|
||
|
|
||
|
} // namespace
|
||
|
|
||
|
using ::testing::ElementsAreArray;
|
||
|
|
||
|
// X-macro listing the operands fed into the HASHTABLE_LOOKUP operation as
// (TensorName, C++ element type) pairs. ACTION(X, T) is expanded once per pair;
// the test model uses it to declare the backing buffer std::vector<T> X##_,
// a Set##X() setter, and the setInput() call for HashtableLookup::k##X##Tensor.
// Lookup and Key are TENSOR_INT32 operands, so their element type is spelled
// int32_t to keep the byte count (sizeof(T) * size) passed to setInput() exact.
#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
    ACTION(Lookup, int32_t)                      \
    ACTION(Key, int32_t)                         \
    ACTION(Value, float)

// For all output and intermediate states: (TensorName, element type) pairs for
// the operation's outputs. Output is a TENSOR_FLOAT32 operand; Hits is a
// TENSOR_QUANT8_ASYMM operand holding one byte per lookup.
#define FOR_ALL_OUTPUT_TENSORS(ACTION) \
    ACTION(Output, float)              \
    ACTION(Hits, uint8_t)
|
||
|
|
||
|
class HashtableLookupOpModel {
|
||
|
public:
|
||
|
HashtableLookupOpModel(std::initializer_list<uint32_t> lookup_shape,
|
||
|
std::initializer_list<uint32_t> key_shape,
|
||
|
std::initializer_list<uint32_t> value_shape) {
|
||
|
auto it_vs = value_shape.begin();
|
||
|
rows_ = *it_vs++;
|
||
|
features_ = *it_vs;
|
||
|
|
||
|
std::vector<uint32_t> inputs;
|
||
|
|
||
|
// Input and weights
|
||
|
OperandType LookupTy(Type::TENSOR_INT32, lookup_shape);
|
||
|
inputs.push_back(model_.addOperand(&LookupTy));
|
||
|
|
||
|
OperandType KeyTy(Type::TENSOR_INT32, key_shape);
|
||
|
inputs.push_back(model_.addOperand(&KeyTy));
|
||
|
|
||
|
OperandType ValueTy(Type::TENSOR_FLOAT32, value_shape);
|
||
|
inputs.push_back(model_.addOperand(&ValueTy));
|
||
|
|
||
|
// Output and other intermediate state
|
||
|
std::vector<uint32_t> outputs;
|
||
|
|
||
|
std::vector<uint32_t> out_dim(lookup_shape.begin(), lookup_shape.end());
|
||
|
out_dim.push_back(features_);
|
||
|
|
||
|
OperandType OutputOpndTy(Type::TENSOR_FLOAT32, out_dim);
|
||
|
outputs.push_back(model_.addOperand(&OutputOpndTy));
|
||
|
|
||
|
OperandType HitsOpndTy(Type::TENSOR_QUANT8_ASYMM, lookup_shape, 1.f, 0);
|
||
|
outputs.push_back(model_.addOperand(&HitsOpndTy));
|
||
|
|
||
|
auto multiAll = [](const std::vector<uint32_t>& dims) -> uint32_t {
|
||
|
uint32_t sz = 1;
|
||
|
for (uint32_t d : dims) {
|
||
|
sz *= d;
|
||
|
}
|
||
|
return sz;
|
||
|
};
|
||
|
|
||
|
Value_.insert(Value_.end(), multiAll(value_shape), 0.f);
|
||
|
Output_.insert(Output_.end(), multiAll(out_dim), 0.f);
|
||
|
Hits_.insert(Hits_.end(), multiAll(lookup_shape), 0);
|
||
|
|
||
|
model_.addOperation(ANEURALNETWORKS_HASHTABLE_LOOKUP, inputs, outputs);
|
||
|
model_.identifyInputsAndOutputs(inputs, outputs);
|
||
|
|
||
|
model_.finish();
|
||
|
}
|
||
|
|
||
|
void Invoke() {
|
||
|
ASSERT_TRUE(model_.isValid());
|
||
|
|
||
|
Compilation compilation(&model_);
|
||
|
compilation.finish();
|
||
|
Execution execution(&compilation);
|
||
|
|
||
|
#define SetInputOrWeight(X, T) \
|
||
|
ASSERT_EQ(execution.setInput(HashtableLookup::k##X##Tensor, X##_.data(), \
|
||
|
sizeof(T) * X##_.size()), \
|
||
|
Result::NO_ERROR);
|
||
|
|
||
|
FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
|
||
|
|
||
|
#undef SetInputOrWeight
|
||
|
|
||
|
#define SetOutput(X, T) \
|
||
|
ASSERT_EQ(execution.setOutput(HashtableLookup::k##X##Tensor, X##_.data(), \
|
||
|
sizeof(T) * X##_.size()), \
|
||
|
Result::NO_ERROR);
|
||
|
|
||
|
FOR_ALL_OUTPUT_TENSORS(SetOutput);
|
||
|
|
||
|
#undef SetOutput
|
||
|
|
||
|
ASSERT_EQ(execution.compute(), Result::NO_ERROR);
|
||
|
}
|
||
|
|
||
|
#define DefineSetter(X, T) \
|
||
|
void Set##X(const std::vector<T>& f) { X##_.insert(X##_.end(), f.begin(), f.end()); }
|
||
|
|
||
|
FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
|
||
|
|
||
|
#undef DefineSetter
|
||
|
|
||
|
void SetHashtableValue(const std::function<float(uint32_t, uint32_t)>& function) {
|
||
|
for (uint32_t i = 0; i < rows_; i++) {
|
||
|
for (uint32_t j = 0; j < features_; j++) {
|
||
|
Value_[i * features_ + j] = function(i, j);
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
const std::vector<float>& GetOutput() const { return Output_; }
|
||
|
const std::vector<uint8_t>& GetHits() const { return Hits_; }
|
||
|
|
||
|
private:
|
||
|
Model model_;
|
||
|
uint32_t rows_;
|
||
|
uint32_t features_;
|
||
|
|
||
|
#define DefineTensor(X, T) std::vector<T> X##_;
|
||
|
|
||
|
FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
|
||
|
FOR_ALL_OUTPUT_TENSORS(DefineTensor);
|
||
|
|
||
|
#undef DefineTensor
|
||
|
};
|
||
|
|
||
|
// End-to-end check of HASHTABLE_LOOKUP. It looks up four keys in a three-row
// table, one of which (-292) is absent, and verifies both the gathered rows
// and the per-lookup hit flags.
TEST(HashtableLookupOpTest, BlackBoxTest) {
    HashtableLookupOpModel m({4}, {3}, {3, 2});

    m.SetLookup({1234, -292, -11, 0});
    m.SetKey({-11, 0, 1234});
    // Row i, feature j holds i + j / 10; e.g. row 2 is {2.0, 2.1}. The
    // parameter types match SetHashtableValue's std::function signature.
    m.SetHashtableValue([](uint32_t i, uint32_t j) { return i + j / 10.0f; });

    m.Invoke();

    EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
                                   2.0f, 2.1f,  // 1234 is key #2 -> 2nd row
                                   0.0f, 0.0f,  // -292 not found -> zeros
                                   0.0f, 0.1f,  // -11 is key #0 -> 0th row
                                   1.0f, 1.1f,  // 0 is key #1 -> 1st row
                               })));
    // 1 = key found, 0 = not found, one entry per lookup.
    EXPECT_EQ(m.GetHits(), std::vector<uint8_t>({
                               1,
                               0,
                               1,
                               1,
                           }));
}
|
||
|
|
||
|
} // namespace wrapper
|
||
|
} // namespace nn
|
||
|
} // namespace android
|