You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
147 lines
5.9 KiB
147 lines
5.9 KiB
4 months ago
|
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
||
|
|
||
|
#ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATION_RESOLVER_H
|
||
|
#define ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATION_RESOLVER_H
|
||
|
|
||
|
#include <utility>
|
||
|
|
||
|
#include "OperationsUtils.h"
|
||
|
|
||
|
namespace android {
|
||
|
namespace nn {
|
||
|
|
||
|
// Encapsulates an operation implementation.
|
||
|
struct OperationRegistration {
|
||
|
OperationType type;
|
||
|
const char* name;
|
||
|
|
||
|
// Validates operand types, shapes, and any values known during graph creation.
|
||
|
std::function<Result<Version>(const IOperationValidationContext*)> validate;
|
||
|
|
||
|
// prepare is called when the inputs this operation depends on have been
|
||
|
// computed. Typically, prepare does any remaining validation and sets
|
||
|
// output shapes via context->setOutputShape(...).
|
||
|
std::function<bool(IOperationExecutionContext*)> prepare;
|
||
|
|
||
|
// Executes the operation, reading from context->getInputBuffer(...)
|
||
|
// and writing to context->getOutputBuffer(...).
|
||
|
std::function<bool(IOperationExecutionContext*)> execute;
|
||
|
|
||
|
struct Flag {
|
||
|
// Whether the operation allows at least one operand to be omitted.
|
||
|
bool allowOmittedOperand = false;
|
||
|
// Whether the operation allows at least one input operand to be a zero-sized tensor.
|
||
|
bool allowZeroSizedInput = false;
|
||
|
} flags;
|
||
|
|
||
|
OperationRegistration(
|
||
|
OperationType type, const char* name,
|
||
|
std::function<Result<Version>(const IOperationValidationContext*)> validate,
|
||
|
std::function<bool(IOperationExecutionContext*)> prepare,
|
||
|
std::function<bool(IOperationExecutionContext*)> execute, Flag flags)
|
||
|
: type(type),
|
||
|
name(name),
|
||
|
validate(std::move(validate)),
|
||
|
prepare(std::move(prepare)),
|
||
|
execute(std::move(execute)),
|
||
|
flags(flags) {}
|
||
|
};
|
||
|
|
||
|
// A registry of operation implementations.
|
||
|
class IOperationResolver {
|
||
|
public:
|
||
|
virtual const OperationRegistration* findOperation(OperationType operationType) const = 0;
|
||
|
virtual ~IOperationResolver() {}
|
||
|
};
|
||
|
|
||
|
// A registry of builtin operation implementations.
|
||
|
//
|
||
|
// Note that some operations bypass BuiltinOperationResolver (b/124041202).
|
||
|
//
|
||
|
// Usage:
|
||
|
// const OperationRegistration* operationRegistration =
|
||
|
// BuiltinOperationResolver::get()->findOperation(operationType);
|
||
|
// NN_RET_CHECK(operationRegistration != nullptr);
|
||
|
// NN_RET_CHECK(operationRegistration->validate != nullptr);
|
||
|
// NN_RET_CHECK(operationRegistration->validate(&context));
|
||
|
//
|
||
|
class BuiltinOperationResolver : public IOperationResolver {
|
||
|
DISALLOW_COPY_AND_ASSIGN(BuiltinOperationResolver);
|
||
|
|
||
|
public:
|
||
|
static const BuiltinOperationResolver* get() {
|
||
|
static BuiltinOperationResolver instance;
|
||
|
return &instance;
|
||
|
}
|
||
|
|
||
|
const OperationRegistration* findOperation(OperationType operationType) const override;
|
||
|
|
||
|
// The number of operation types (OperationCode) defined in NeuralNetworks.h.
|
||
|
static constexpr int kNumberOfOperationTypes = 102;
|
||
|
|
||
|
private:
|
||
|
BuiltinOperationResolver();
|
||
|
|
||
|
void registerOperation(const OperationRegistration* operationRegistration);
|
||
|
|
||
|
const OperationRegistration* mRegistrations[kNumberOfOperationTypes] = {};
|
||
|
};
|
||
|
|
||
|
// NN_REGISTER_OPERATION creates an OperationRegistration for consumption by
// OperationResolver.
//
// Each use defines a function `register_<identifier>()` that returns a pointer
// to a function-local static OperationRegistration for
// OperationType::<identifier>.
//
// Usage:
//   (See OperationRegistration::Flag for available fields and default values.)
//
//   - With default flags.
//     NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate,
//                           foo_op::prepare, foo_op::execute);
//
//   - With a customized flag.
//     NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate,
//                           foo_op::prepare, foo_op::execute, .allowZeroSizedInput = true);
//
//   - With multiple customized flags.
//     NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate,
//                           foo_op::prepare, foo_op::execute, .allowOmittedOperand = true,
//                           .allowZeroSizedInput = true);
//
// The trailing flag arguments are pasted verbatim into a braced initializer
// for OperationRegistration::Flag. Designated fields must therefore appear in
// the order they are declared in Flag; C++ rejects out-of-order designators.
//
#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
#define NN_REGISTER_OPERATION(identifier, operationName, validate, prepare, execute, ...)     \
    const OperationRegistration* register_##identifier() {                                    \
        static OperationRegistration registration(OperationType::identifier, operationName,   \
                                                  validate, prepare, execute, {__VA_ARGS__}); \
        return &registration;                                                                 \
    }
#else
// This version ignores CPU execution logic (prepare and execute): those
// arguments are never referenced, and nullptr is stored instead. The compiler
// is supposed to omit that code so that only validation logic makes it into
// libneuralnetworks_utils.
#define NN_REGISTER_OPERATION(identifier, operationName, validate, unused_prepare, unused_execute, \
                              ...)                                                                \
    const OperationRegistration* register_##identifier() {                                        \
        static OperationRegistration registration(OperationType::identifier, operationName,       \
                                                  validate, nullptr, nullptr, {__VA_ARGS__});     \
        return &registration;                                                                     \
    }
#endif
|
||
|
|
||
|
} // namespace nn
|
||
|
} // namespace android
|
||
|
|
||
|
#endif // ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATION_RESOLVER_H
|