You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
176 lines
5.6 KiB
176 lines
5.6 KiB
4 months ago
|
/*
|
||
|
* Copyright (C) 2017 The Android Open Source Project
|
||
|
*
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
*/
|
||
|
|
||
|
#include <gmock/gmock.h>
|
||
|
#include <gtest/gtest.h>
|
||
|
|
||
|
#include <vector>
|
||
|
|
||
|
#include "LSHProjection.h"
|
||
|
#include "NeuralNetworksWrapper.h"
|
||
|
|
||
|
using ::testing::FloatNear;
|
||
|
using ::testing::Matcher;
|
||
|
|
||
|
namespace android {
|
||
|
namespace nn {
|
||
|
namespace wrapper {
|
||
|
|
||
|
using ::testing::ElementsAre;
|
||
|
|
||
|
#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
|
||
|
ACTION(Hash, float) \
|
||
|
ACTION(Input, int) \
|
||
|
ACTION(Weight, float)
|
||
|
|
||
|
// For all output and intermediate states
|
||
|
#define FOR_ALL_OUTPUT_TENSORS(ACTION) ACTION(Output, int)
|
||
|
|
||
|
class LSHProjectionOpModel {
|
||
|
public:
|
||
|
LSHProjectionOpModel(LSHProjectionType type, std::initializer_list<uint32_t> hash_shape,
|
||
|
std::initializer_list<uint32_t> input_shape,
|
||
|
std::initializer_list<uint32_t> weight_shape)
|
||
|
: type_(type) {
|
||
|
std::vector<uint32_t> inputs;
|
||
|
|
||
|
OperandType HashTy(Type::TENSOR_FLOAT32, hash_shape);
|
||
|
inputs.push_back(model_.addOperand(&HashTy));
|
||
|
OperandType InputTy(Type::TENSOR_INT32, input_shape);
|
||
|
inputs.push_back(model_.addOperand(&InputTy));
|
||
|
OperandType WeightTy(Type::TENSOR_FLOAT32, weight_shape);
|
||
|
inputs.push_back(model_.addOperand(&WeightTy));
|
||
|
|
||
|
OperandType TypeParamTy(Type::INT32, {});
|
||
|
inputs.push_back(model_.addOperand(&TypeParamTy));
|
||
|
|
||
|
std::vector<uint32_t> outputs;
|
||
|
|
||
|
auto multiAll = [](const std::vector<uint32_t>& dims) -> uint32_t {
|
||
|
uint32_t sz = 1;
|
||
|
for (uint32_t d : dims) {
|
||
|
sz *= d;
|
||
|
}
|
||
|
return sz;
|
||
|
};
|
||
|
|
||
|
uint32_t outShapeDimension = 0;
|
||
|
if (type == LSHProjectionType_SPARSE || type == LSHProjectionType_SPARSE_DEPRECATED) {
|
||
|
auto it = hash_shape.begin();
|
||
|
Output_.insert(Output_.end(), *it, 0.f);
|
||
|
outShapeDimension = *it;
|
||
|
} else {
|
||
|
Output_.insert(Output_.end(), multiAll(hash_shape), 0.f);
|
||
|
outShapeDimension = multiAll(hash_shape);
|
||
|
}
|
||
|
|
||
|
OperandType OutputTy(Type::TENSOR_INT32, {outShapeDimension});
|
||
|
outputs.push_back(model_.addOperand(&OutputTy));
|
||
|
|
||
|
model_.addOperation(ANEURALNETWORKS_LSH_PROJECTION, inputs, outputs);
|
||
|
model_.identifyInputsAndOutputs(inputs, outputs);
|
||
|
|
||
|
model_.finish();
|
||
|
}
|
||
|
|
||
|
#define DefineSetter(X, T) \
|
||
|
void Set##X(const std::vector<T>& f) { X##_.insert(X##_.end(), f.begin(), f.end()); }
|
||
|
|
||
|
FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
|
||
|
|
||
|
#undef DefineSetter
|
||
|
|
||
|
const std::vector<int>& GetOutput() const { return Output_; }
|
||
|
|
||
|
void Invoke() {
|
||
|
ASSERT_TRUE(model_.isValid());
|
||
|
|
||
|
Compilation compilation(&model_);
|
||
|
compilation.finish();
|
||
|
Execution execution(&compilation);
|
||
|
|
||
|
#define SetInputOrWeight(X, T) \
|
||
|
ASSERT_EQ( \
|
||
|
execution.setInput(LSHProjection::k##X##Tensor, X##_.data(), sizeof(T) * X##_.size()), \
|
||
|
Result::NO_ERROR);
|
||
|
|
||
|
FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
|
||
|
|
||
|
#undef SetInputOrWeight
|
||
|
|
||
|
#define SetOutput(X, T) \
|
||
|
ASSERT_EQ(execution.setOutput(LSHProjection::k##X##Tensor, X##_.data(), \
|
||
|
sizeof(T) * X##_.size()), \
|
||
|
Result::NO_ERROR);
|
||
|
|
||
|
FOR_ALL_OUTPUT_TENSORS(SetOutput);
|
||
|
|
||
|
#undef SetOutput
|
||
|
|
||
|
ASSERT_EQ(execution.setInput(LSHProjection::kTypeParam, &type_, sizeof(type_)),
|
||
|
Result::NO_ERROR);
|
||
|
|
||
|
ASSERT_EQ(execution.compute(), Result::NO_ERROR);
|
||
|
}
|
||
|
|
||
|
private:
|
||
|
Model model_;
|
||
|
LSHProjectionType type_;
|
||
|
|
||
|
std::vector<float> Hash_;
|
||
|
std::vector<int> Input_;
|
||
|
std::vector<float> Weight_;
|
||
|
std::vector<int> Output_;
|
||
|
}; // namespace wrapper
|
||
|
|
||
|
TEST(LSHProjectionOpTest2, DenseWithThreeInputs) {
|
||
|
LSHProjectionOpModel m(LSHProjectionType_DENSE, {4, 2}, {3, 2}, {3});
|
||
|
|
||
|
m.SetInput({12345, 54321, 67890, 9876, -12345678, -87654321});
|
||
|
m.SetHash({0.123, 0.456, -0.321, -0.654, 1.234, 5.678, -4.321, -8.765});
|
||
|
m.SetWeight({0.12, 0.34, 0.56});
|
||
|
|
||
|
m.Invoke();
|
||
|
|
||
|
EXPECT_THAT(m.GetOutput(), ElementsAre(1, 1, 1, 0, 1, 1, 1, 0));
|
||
|
}
|
||
|
|
||
|
TEST(LSHProjectionOpTest2, SparseDeprecatedWithTwoInputs) {
|
||
|
LSHProjectionOpModel m(LSHProjectionType_SPARSE_DEPRECATED, {4, 2}, {3, 2}, {0});
|
||
|
|
||
|
m.SetInput({12345, 54321, 67890, 9876, -12345678, -87654321});
|
||
|
m.SetHash({0.123, 0.456, -0.321, -0.654, 1.234, 5.678, -4.321, -8.765});
|
||
|
|
||
|
m.Invoke();
|
||
|
|
||
|
EXPECT_THAT(m.GetOutput(), ElementsAre(1, 2, 2, 0));
|
||
|
}
|
||
|
|
||
|
TEST(LSHProjectionOpTest2, SparseWithTwoInputs) {
|
||
|
LSHProjectionOpModel m(LSHProjectionType_SPARSE, {4, 2}, {3, 2}, {0});
|
||
|
|
||
|
m.SetInput({12345, 54321, 67890, 9876, -12345678, -87654321});
|
||
|
m.SetHash({0.123, 0.456, -0.321, -0.654, 1.234, 5.678, -4.321, -8.765});
|
||
|
|
||
|
m.Invoke();
|
||
|
|
||
|
EXPECT_THAT(m.GetOutput(), ElementsAre(1, 6, 10, 12));
|
||
|
}
|
||
|
|
||
|
} // namespace wrapper
|
||
|
} // namespace nn
|
||
|
} // namespace android
|