You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
153 lines
5.8 KiB
153 lines
5.8 KiB
/*
|
|
* Copyright (C) 2019 The Android Open Source Project
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#ifndef ANDROID_FRAMEWORKS_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_EXTENSIONS_H
|
|
#define ANDROID_FRAMEWORKS_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_EXTENSIONS_H
|
|
|
|
#include <cstdint>
#include <memory>
#include <type_traits>
#include <utility>
#include <variant>
#include <vector>

#include "NeuralNetworksExtensions.h"
#include "NeuralNetworksWrapper.h"
|
|
|
|
namespace android {
|
|
namespace nn {
|
|
namespace extension_wrapper {
|
|
|
|
using wrapper::SymmPerChannelQuantParams;
|
|
using wrapper::Type;
|
|
|
|
// Opaque, extension-specific operand data. Model::addOperand() passes these
// bytes unchanged to ANeuralNetworksModel_setOperandExtensionData().
struct ExtensionOperandParams {
    // Raw bytes handed to the runtime.
    std::vector<uint8_t> data;

    // Takes ownership of an already-serialized byte buffer.
    ExtensionOperandParams(std::vector<uint8_t> data) : data(std::move(data)) {}

    // Captures the object representation (all sizeof(T) bytes) of a
    // trivially copyable value, e.g. a plain parameter struct.
    // std::addressof is used rather than `&data` so that a type overloading
    // unary operator& still has its actual storage copied.
    template <typename T>
    ExtensionOperandParams(const T& data)
        : ExtensionOperandParams(std::vector<uint8_t>(
                  reinterpret_cast<const uint8_t*>(std::addressof(data)),
                  reinterpret_cast<const uint8_t*>(std::addressof(data)) + sizeof(data))) {
        static_assert(std::is_trivially_copyable_v<T>, "data must be trivially copyable");
    }
};
|
|
|
|
struct OperandType {
|
|
using ExtraParams =
|
|
std::variant<std::monostate, SymmPerChannelQuantParams, ExtensionOperandParams>;
|
|
|
|
ANeuralNetworksOperandType operandType;
|
|
std::vector<uint32_t> dimensions;
|
|
ExtraParams extraParams;
|
|
|
|
OperandType(const OperandType& other)
|
|
: operandType(other.operandType),
|
|
dimensions(other.dimensions),
|
|
extraParams(other.extraParams) {
|
|
operandType.dimensions = dimensions.size() > 0 ? dimensions.data() : nullptr;
|
|
}
|
|
|
|
OperandType& operator=(const OperandType& other) {
|
|
if (this != &other) {
|
|
operandType = other.operandType;
|
|
dimensions = other.dimensions;
|
|
extraParams = other.extraParams;
|
|
operandType.dimensions = dimensions.size() > 0 ? dimensions.data() : nullptr;
|
|
}
|
|
return *this;
|
|
}
|
|
|
|
OperandType(Type type, std::vector<uint32_t> d, float scale = 0.0f, int32_t zeroPoint = 0,
|
|
ExtraParams&& extraParams = std::monostate())
|
|
: dimensions(std::move(d)), extraParams(std::move(extraParams)) {
|
|
operandType = {
|
|
.type = static_cast<int32_t>(type),
|
|
.dimensionCount = static_cast<uint32_t>(dimensions.size()),
|
|
.dimensions = dimensions.size() > 0 ? dimensions.data() : nullptr,
|
|
.scale = scale,
|
|
.zeroPoint = zeroPoint,
|
|
};
|
|
}
|
|
|
|
OperandType(Type type, std::vector<uint32_t> dimensions, float scale, int32_t zeroPoint,
|
|
SymmPerChannelQuantParams&& channelQuant)
|
|
: OperandType(type, dimensions, scale, zeroPoint, ExtraParams(std::move(channelQuant))) {}
|
|
|
|
OperandType(Type type, std::vector<uint32_t> dimensions, ExtraParams&& extraParams)
|
|
: OperandType(type, dimensions, 0.0f, 0, std::move(extraParams)) {}
|
|
};
|
|
|
|
class Model : public wrapper::Model {
|
|
public:
|
|
using wrapper::Model::Model; // Inherit constructors.
|
|
|
|
int32_t getExtensionOperandType(const char* extensionName, uint16_t typeWithinExtension) {
|
|
int32_t result;
|
|
if (ANeuralNetworksModel_getExtensionOperandType(mModel, extensionName, typeWithinExtension,
|
|
&result) != ANEURALNETWORKS_NO_ERROR) {
|
|
mValid = false;
|
|
}
|
|
return result;
|
|
}
|
|
|
|
ANeuralNetworksOperationType getExtensionOperationType(const char* extensionName,
|
|
uint16_t typeWithinExtension) {
|
|
ANeuralNetworksOperationType result;
|
|
if (ANeuralNetworksModel_getExtensionOperationType(mModel, extensionName,
|
|
typeWithinExtension,
|
|
&result) != ANEURALNETWORKS_NO_ERROR) {
|
|
mValid = false;
|
|
}
|
|
return result;
|
|
}
|
|
|
|
uint32_t addOperand(const OperandType* type) {
|
|
if (ANeuralNetworksModel_addOperand(mModel, &(type->operandType)) !=
|
|
ANEURALNETWORKS_NO_ERROR) {
|
|
mValid = false;
|
|
}
|
|
if (std::holds_alternative<SymmPerChannelQuantParams>(type->extraParams)) {
|
|
const auto& channelQuant = std::get<SymmPerChannelQuantParams>(type->extraParams);
|
|
if (ANeuralNetworksModel_setOperandSymmPerChannelQuantParams(
|
|
mModel, mNextOperandId, &channelQuant.params) != ANEURALNETWORKS_NO_ERROR) {
|
|
mValid = false;
|
|
}
|
|
} else if (std::holds_alternative<ExtensionOperandParams>(type->extraParams)) {
|
|
const auto& extension = std::get<ExtensionOperandParams>(type->extraParams);
|
|
if (ANeuralNetworksModel_setOperandExtensionData(
|
|
mModel, mNextOperandId, extension.data.data(), extension.data.size()) !=
|
|
ANEURALNETWORKS_NO_ERROR) {
|
|
mValid = false;
|
|
}
|
|
}
|
|
return mNextOperandId++;
|
|
}
|
|
};
|
|
|
|
} // namespace extension_wrapper
|
|
|
|
namespace wrapper {
|
|
|
|
using ExtensionModel = extension_wrapper::Model;
|
|
using ExtensionOperandType = extension_wrapper::OperandType;
|
|
using ExtensionOperandParams = extension_wrapper::ExtensionOperandParams;
|
|
|
|
} // namespace wrapper
|
|
} // namespace nn
|
|
} // namespace android
|
|
|
|
#endif // ANDROID_FRAMEWORKS_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_EXTENSIONS_H
|