You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
214 lines
8.0 KiB
214 lines
8.0 KiB
/*
|
|
* Copyright (C) 2020 The Android Open Source Project
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include <SampleDriverPartial.h>
|
|
#include <gtest/gtest.h>
|
|
|
|
#include <algorithm>
|
|
#include <memory>
|
|
#include <vector>
|
|
|
|
#include "CompilationBuilder.h"
|
|
#include "ExecutionPlan.h"
|
|
#include "HalUtils.h"
|
|
#include "Manager.h"
|
|
#include "TestNeuralNetworksWrapper.h"
|
|
|
|
namespace android::nn {
|
|
namespace {
|
|
|
|
using sample_driver::SampleDriverPartial;
|
|
using Result = test_wrapper::Result;
|
|
using WrapperOperandType = test_wrapper::OperandType;
|
|
using WrapperCompilation = test_wrapper::Compilation;
|
|
using WrapperExecution = test_wrapper::Execution;
|
|
using WrapperType = test_wrapper::Type;
|
|
using WrapperModel = test_wrapper::Model;
|
|
|
|
class EmptyOperationResolver : public IOperationResolver {
|
|
public:
|
|
const OperationRegistration* findOperation(OperationType) const override { return nullptr; }
|
|
};
|
|
|
|
// Name used both to construct the failing test driver and to register the
// corresponding device. constexpr makes the pointer itself immutable (the
// plain `const char*` form only protected the characters) and guarantees
// constant initialization.
constexpr const char* kTestDriverName = "nnapi-test-sqrt-failing";
|
|
|
|
// A driver that only supports SQRT and fails during execution.
|
|
class FailingTestDriver : public SampleDriverPartial {
|
|
public:
|
|
// EmptyOperationResolver causes execution to fail.
|
|
FailingTestDriver() : SampleDriverPartial(kTestDriverName, &mEmptyOperationResolver) {}
|
|
|
|
hardware::Return<void> getCapabilities_1_3(getCapabilities_1_3_cb cb) override {
|
|
cb(V1_3::ErrorStatus::NONE, makeCapabilities(0.1)); // Faster than CPU.
|
|
return hardware::Void();
|
|
}
|
|
|
|
private:
|
|
std::vector<bool> getSupportedOperationsImpl(const V1_3::Model& model) const override {
|
|
std::vector<bool> supported(model.main.operations.size());
|
|
std::transform(model.main.operations.begin(), model.main.operations.end(),
|
|
supported.begin(), [](const V1_3::Operation& operation) {
|
|
return operation.type == V1_3::OperationType::SQRT;
|
|
});
|
|
return supported;
|
|
}
|
|
|
|
const EmptyOperationResolver mEmptyOperationResolver;
|
|
};
|
|
|
|
class FailingDriverTest : public ::testing::Test {
|
|
virtual void SetUp() {
|
|
DeviceManager* deviceManager = DeviceManager::get();
|
|
if (deviceManager->getUseCpuOnly() ||
|
|
!DeviceManager::partitioningAllowsFallback(deviceManager->getPartitioning())) {
|
|
GTEST_SKIP();
|
|
}
|
|
mTestDevice = DeviceManager::forTest_makeDriverDevice(
|
|
makeSharedDevice(kTestDriverName, new FailingTestDriver()));
|
|
deviceManager->forTest_setDevices({
|
|
mTestDevice,
|
|
DeviceManager::getCpuDevice(),
|
|
});
|
|
}
|
|
|
|
virtual void TearDown() { DeviceManager::get()->forTest_reInitializeDeviceList(); }
|
|
|
|
protected:
|
|
std::shared_ptr<Device> mTestDevice;
|
|
};
|
|
|
|
// Regression test for b/152623150.
|
|
TEST_F(FailingDriverTest, FailAfterInterpretedWhile) {
|
|
// Model:
|
|
// f = input0
|
|
// b = input1
|
|
// while CAST(b): # Identity cast.
|
|
// f = CAST(f)
|
|
// # FailingTestDriver fails here. When partial CPU fallback happens,
|
|
// # it should not loop forever.
|
|
// output0 = SQRT(f)
|
|
|
|
WrapperOperandType floatType(WrapperType::TENSOR_FLOAT32, {2});
|
|
WrapperOperandType boolType(WrapperType::TENSOR_BOOL8, {1});
|
|
|
|
WrapperModel conditionModel;
|
|
{
|
|
uint32_t f = conditionModel.addOperand(&floatType);
|
|
uint32_t b = conditionModel.addOperand(&boolType);
|
|
uint32_t out = conditionModel.addOperand(&boolType);
|
|
conditionModel.addOperation(ANEURALNETWORKS_CAST, {b}, {out});
|
|
conditionModel.identifyInputsAndOutputs({f, b}, {out});
|
|
ASSERT_EQ(conditionModel.finish(), Result::NO_ERROR);
|
|
ASSERT_TRUE(conditionModel.isValid());
|
|
}
|
|
|
|
WrapperModel bodyModel;
|
|
{
|
|
uint32_t f = bodyModel.addOperand(&floatType);
|
|
uint32_t b = bodyModel.addOperand(&boolType);
|
|
uint32_t out = bodyModel.addOperand(&floatType);
|
|
bodyModel.addOperation(ANEURALNETWORKS_CAST, {f}, {out});
|
|
bodyModel.identifyInputsAndOutputs({f, b}, {out});
|
|
ASSERT_EQ(bodyModel.finish(), Result::NO_ERROR);
|
|
ASSERT_TRUE(bodyModel.isValid());
|
|
}
|
|
|
|
WrapperModel model;
|
|
{
|
|
uint32_t fInput = model.addOperand(&floatType);
|
|
uint32_t bInput = model.addOperand(&boolType);
|
|
uint32_t fTmp = model.addOperand(&floatType);
|
|
uint32_t fSqrt = model.addOperand(&floatType);
|
|
uint32_t cond = model.addModelOperand(&conditionModel);
|
|
uint32_t body = model.addModelOperand(&bodyModel);
|
|
model.addOperation(ANEURALNETWORKS_WHILE, {cond, body, fInput, bInput}, {fTmp});
|
|
model.addOperation(ANEURALNETWORKS_SQRT, {fTmp}, {fSqrt});
|
|
model.identifyInputsAndOutputs({fInput, bInput}, {fSqrt});
|
|
ASSERT_TRUE(model.isValid());
|
|
ASSERT_EQ(model.finish(), Result::NO_ERROR);
|
|
}
|
|
|
|
WrapperCompilation compilation(&model);
|
|
ASSERT_EQ(compilation.finish(), Result::NO_ERROR);
|
|
|
|
const CompilationBuilder* compilationBuilder =
|
|
reinterpret_cast<CompilationBuilder*>(compilation.getHandle());
|
|
const ExecutionPlan& plan = compilationBuilder->forTest_getExecutionPlan();
|
|
const std::vector<std::shared_ptr<LogicalStep>>& steps = plan.forTest_compoundGetSteps();
|
|
ASSERT_EQ(steps.size(), 6u);
|
|
ASSERT_TRUE(steps[0]->isWhile());
|
|
ASSERT_TRUE(steps[1]->isExecution());
|
|
ASSERT_EQ(steps[1]->executionStep()->getDevice(), DeviceManager::getCpuDevice());
|
|
ASSERT_TRUE(steps[2]->isGoto());
|
|
ASSERT_TRUE(steps[3]->isExecution());
|
|
ASSERT_EQ(steps[3]->executionStep()->getDevice(), DeviceManager::getCpuDevice());
|
|
ASSERT_TRUE(steps[4]->isGoto());
|
|
ASSERT_TRUE(steps[5]->isExecution());
|
|
ASSERT_EQ(steps[5]->executionStep()->getDevice(), mTestDevice);
|
|
|
|
WrapperExecution execution(&compilation);
|
|
const float fInput[] = {12 * 12, 5 * 5};
|
|
const bool8 bInput = false;
|
|
float fSqrt[] = {0, 0};
|
|
ASSERT_EQ(execution.setInput(0, &fInput), Result::NO_ERROR);
|
|
ASSERT_EQ(execution.setInput(1, &bInput), Result::NO_ERROR);
|
|
ASSERT_EQ(execution.setOutput(0, &fSqrt), Result::NO_ERROR);
|
|
ASSERT_EQ(execution.compute(), Result::NO_ERROR);
|
|
ASSERT_EQ(fSqrt[0], 12);
|
|
ASSERT_EQ(fSqrt[1], 5);
|
|
}
|
|
|
|
// Regression test for b/155923033.
|
|
TEST_F(FailingDriverTest, SimplePlan) {
|
|
// Model:
|
|
// output0 = SQRT(input0)
|
|
//
|
|
// This results in a SIMPLE execution plan. When FailingTestDriver fails,
|
|
// partial CPU fallback should complete the execution.
|
|
|
|
WrapperOperandType floatType(WrapperType::TENSOR_FLOAT32, {2});
|
|
|
|
WrapperModel model;
|
|
{
|
|
uint32_t fInput = model.addOperand(&floatType);
|
|
uint32_t fSqrt = model.addOperand(&floatType);
|
|
model.addOperation(ANEURALNETWORKS_SQRT, {fInput}, {fSqrt});
|
|
model.identifyInputsAndOutputs({fInput}, {fSqrt});
|
|
ASSERT_TRUE(model.isValid());
|
|
ASSERT_EQ(model.finish(), Result::NO_ERROR);
|
|
}
|
|
|
|
WrapperCompilation compilation(&model);
|
|
ASSERT_EQ(compilation.finish(), Result::NO_ERROR);
|
|
|
|
const CompilationBuilder* compilationBuilder =
|
|
reinterpret_cast<CompilationBuilder*>(compilation.getHandle());
|
|
const ExecutionPlan& plan = compilationBuilder->forTest_getExecutionPlan();
|
|
ASSERT_TRUE(plan.isSimple());
|
|
|
|
WrapperExecution execution(&compilation);
|
|
const float fInput[] = {12 * 12, 5 * 5};
|
|
float fSqrt[] = {0, 0};
|
|
ASSERT_EQ(execution.setInput(0, &fInput), Result::NO_ERROR);
|
|
ASSERT_EQ(execution.setOutput(0, &fSqrt), Result::NO_ERROR);
|
|
ASSERT_EQ(execution.compute(), Result::NO_ERROR);
|
|
ASSERT_EQ(fSqrt[0], 12);
|
|
ASSERT_EQ(fSqrt[1], 5);
|
|
}
|
|
|
|
} // namespace
|
|
} // namespace android::nn
|