You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

214 lines
6.8 KiB

/**
* Copyright (c) Hisilicon Technologies Co., Ltd.. 2020-2020. All rights reserved.
*
*/
#ifndef __AI_MODEL_MANAGER_H__
#define __AI_MODEL_MANAGER_H__
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>
#include <utils/Mutex.h>
#include <utils/RefBase.h>
#include <vendor/huanglong/hardware/hwhlai/1.0/types.h>
#include "AiConfig.h"
namespace ai {
// Short names used unqualified by the declarations in this header.
using ::std::vector;
using ::std::string;
using ::android::sp;
using ::android::RefBase;
// NOTE(review): this is a using-directive inside a header, so every name of the
// HIDL V1_0 namespace becomes visible in namespace ai for each file that includes it.
// The unqualified types used below (TensorDescription, ModelInfo, ModelBuffer, ...)
// presumably come from here - confirm before narrowing it to using-declarations.
using namespace ::vendor::huanglong::hardware::hwhlai::V1_0;
/**
 * Callback interface accepted by the listener-taking AiModelManager constructor.
 *
 * Every on*() hook is virtual and has an empty default body (parameters are
 * only passed to UNUSED_PARAMETER), so a subclass overrides just the
 * notifications it is interested in.
 *
 * NOTE: keep the virtual members in their current order; on the common C++
 * ABIs reordering them changes the vtable layout of this exported class.
 */
class HLAI_API_EXPORT AiModelManagerListener : public virtual RefBase
{
public:
    AiModelManagerListener();
    virtual ~AiModelManagerListener();

    /* ---- V1.0 ---- */
    virtual void onBuildDone() {}
    virtual void onStartDone() {}

    /* ---- For Microsoft translation ---- */
    virtual void onSetInputsAndOutputsDone() {}
    virtual void onStartComputeDone() {}

    /* Both vectors are received by value; the default body ignores them. */
    virtual void onRunDone(vector<native_handle_t*> destDataVec, vector<TensorDescription> destTensorVec)
    {
        UNUSED_PARAMETER(destDataVec);
        UNUSED_PARAMETER(destTensorVec);
    }

    virtual void onStopDone() {}

    /* srcDataVec is received by value; the default body ignores it. */
    virtual void onTimeout(vector<native_handle_t*> srcDataVec)
    {
        UNUSED_PARAMETER(srcDataVec);
    }

    /* The default body ignores errCode. */
    virtual void onError(int32_t errCode)
    {
        UNUSED_PARAMETER(errCode);
    }

    virtual void onServiceDied() {}

    /*
     * Version bookkeeping, defined out of line. The meaning of "version" and of
     * the opaque `object` key is not visible in this header.
     */
    long getVersion();
    void registerVersion(const void* object, long version);
    void unregisterVersion(const void* object);
    long queryVersion(const void* object);

private:
    /* Presumably guards the version bookkeeping above - confirm in the .cpp. */
    android::Mutex mMutex;
};
/*
 * The V2 hooks below reuse the names of the AiModelManagerListener hooks with
 * different parameter lists. They are therefore new overloads that hide the
 * base-class versions rather than overriding them, which is what
 * -Woverloaded-virtual reports; the warning is silenced for this class only.
 */
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Woverloaded-virtual"
/**
 * Listener variant whose hooks carry a taskStamp argument.
 *
 * As in the base class, every hook has an empty default body that only
 * passes its parameters to UNUSED_PARAMETER.
 */
class HLAI_API_EXPORT AiModelManagerListenerV2 : public AiModelManagerListener
{
public:
    AiModelManagerListenerV2();
    ~AiModelManagerListenerV2();

    virtual void onStartDone(int32_t taskStamp)
    {
        UNUSED_PARAMETER(taskStamp);
    }

    virtual void onRunDone(int32_t taskStamp)
    {
        UNUSED_PARAMETER(taskStamp);
    }

    virtual void onStopDone(int32_t taskStamp)
    {
        UNUSED_PARAMETER(taskStamp);
    }

    virtual void onTimeout(int32_t taskStamp)
    {
        UNUSED_PARAMETER(taskStamp);
    }

    virtual void onError(int32_t taskStamp, int32_t errCode)
    {
        UNUSED_PARAMETER(taskStamp);
        UNUSED_PARAMETER(errCode);
    }
};
#pragma clang diagnostic pop
class AiModelManagerImpl; // forward declaration; the definition is not part of this header

/**
 * Public entry point for querying, building, starting and stopping models.
 *
 * All member functions are defined out of line; this header only fixes their
 * signatures. The int32_t results are presumably status codes - confirm the
 * value convention against the implementation.
 */
class HLAI_API_EXPORT AiModelManager : public virtual RefBase
{
public:
    /*
     * Result aggregates returned by the getModel*() queries. They are declared
     * as named structs (rather than `typedef struct { ... } Name;`), which keeps
     * the same type names for callers.
     */
    struct ModelTensorInfo {
        vector<TensorDescription> inputTensor;
        vector<TensorDescription> outputTensor;
    };

    struct ModelTensorName {
        vector<::android::hardware::hidl_string> inputName;
        vector<::android::hardware::hidl_string> outputName;
    };

    struct ModelDynamicBatch {
        vector<int> batch;
    };

    struct ModelDynamicHW {
        vector<int> height;
        vector<int> width;
    };

    /* Synchronous manager: constructed without a listener. */
    AiModelManager();
    /* Asynchronous manager: the caller supplies the listener. */
    AiModelManager(const sp<AiModelManagerListener>& listener);
    ~AiModelManager();

    /* ---- QUERY FUNCS ---- */
    int32_t checkModelInfoValid(ModelInfo& modelInfo);
    bool isModelDescValid(ModelDescription& modelPara);
    bool isModelDescVecValid(vector<ModelDescription>& modelParaVec);
    bool isModelBufferValid(ModelBuffer& modelBuffer);
    bool isModelBufferVecValid(vector<ModelBuffer>& modelBufferVec);
    ModelTensorInfo getModelTensor(const string& modelName);
    ModelTensorName getModelTensorName(const string& modelName);
    ModelDynamicBatch getModelDynamicBatch(const string& modelName);
    ModelDynamicHW getModelDynamicHW(const string& modelName);
    int32_t getMaxUsedMemory();
    bool isServiceDied();

    /* ---- OPERATION FUNCS V1.0 ---- */
    int32_t buildModel(ModelInfo& modelInfo, string buildPath);
    int32_t startModel(ModelDescription& modelPara);
    int32_t startModel(vector<ModelDescription>& modelParaVec);
    int32_t startModel(ModelBuffer& modelBuffer);
    int32_t startModel(vector<ModelBuffer>& modelBufferVec);
    int32_t stopModel();

    /*
     * ---- OPERATION FUNCS V1.5 ----
     * Same operations with an extra taskStamp pointer, presumably an
     * out-parameter matching the taskStamp passed to AiModelManagerListenerV2
     * hooks - confirm against the implementation.
     */
    int32_t startModel(ModelDescription& modelPara, int32_t* taskStamp);
    int32_t startModel(vector<ModelDescription>& modelParaVec, int32_t* taskStamp);
    int32_t startModel(ModelBuffer& modelBuffer, int32_t* taskStamp);
    int32_t startModel(vector<ModelBuffer>& modelBufferVec, int32_t* taskStamp);
    int32_t stopModel(int32_t* taskStamp);

    /* ---- FOR MICROSOFT ---- */
    int32_t setInputsAndOutputs(string modelname, vector<AINeuralNetworkBuffer>& nn_inputs,
                                vector<AINeuralNetworkBuffer>& nn_outputs);
    int32_t startCompute(string modelname);

private:
    /*
     * Pointer to the implementation object. Where it is allocated and released
     * is not visible in this header.
     */
    AiModelManagerImpl* mImpl;
};
/**
 * Reference-counted holder for a native_handle_t together with a tensor
 * description, a size, a buffer address and an offset.
 *
 * Constructors and the destructor are protected, so instances are obtained
 * through the static create*() factories, which return sp<NativeHandleWrapper>.
 */
class HLAI_API_EXPORT NativeHandleWrapper : public virtual RefBase
{
public:
    /* ---- Factories (defined out of line) ---- */
    static sp<NativeHandleWrapper> createFromTensor(TensorDescription tensor);
    static sp<NativeHandleWrapper> createFromTensorWithSize(TensorDescription tensor, int size);
    static sp<NativeHandleWrapper> createByFd(int fd, size_t offset, size_t size);
    static sp<NativeHandleWrapper> createAippFromTensorWithSize(TensorDescription tensor, int size);
    static sp<NativeHandleWrapper> createFromHandleTensor(native_handle_t* handle, TensorDescription tensor);
    static sp<NativeHandleWrapper> createFromSize(int size);

    /* ---- Accessors ---- */
    /* NOTE(review): declared as int although mSize is stored as unsigned int. */
    int getSize() const;
    void* getBuffer() const;
    native_handle_t* getHandle() const;
    const TensorDescription& getTensor() const;
    size_t getOffset() const;

protected:
    NativeHandleWrapper(native_handle_t* handle, TensorDescription tensor, int size, void* addr, size_t offset);
    NativeHandleWrapper(native_handle_t* handle, TensorDescription tensor, int size,
                        void* bufAddr, size_t offset, unsigned int fdFlag);

    /*
     * Empty wrapper: zero flag/size/offset and null pointers. mTensor is
     * default-constructed. Both pointer members use nullptr (mAddr was
     * previously written as literal 0).
     */
    NativeHandleWrapper()
        : mFdFlag(0),
          mSize(0),
          mAddr(nullptr),
          mHandle(nullptr),
          mOffset(0)
    {
    }

    ~NativeHandleWrapper();

private:
    unsigned int mFdFlag;
    unsigned int mSize;
    void* mAddr;
    native_handle_t* mHandle;
    TensorDescription mTensor;
    size_t mOffset;
};
/**
 * Reference-counted holder for a ModelBuffer.
 *
 * The constructor and destructor are protected, so instances are obtained
 * through the two static factories. Unlike the other classes in this header,
 * RefBase is inherited non-virtually here.
 */
class HLAI_API_EXPORT ModelBufferWrapper : public RefBase
{
public:
    /* Factory taking a model name and a path; perf defaults to DEV_HIGH_PROFILE. */
    static sp<ModelBufferWrapper> createFromModelFile(
        string modelName, string modelPath, AiDevPerf perf = AiDevPerf::DEV_HIGH_PROFILE);

    /* Factory taking a model name and a caller-supplied buffer of `size`; perf defaults to DEV_HIGH_PROFILE. */
    static sp<ModelBufferWrapper> createFromModelBuf(
        string modelName, const void* modelBuf, int size, AiDevPerf perf = AiDevPerf::DEV_HIGH_PROFILE);

    /* Returns the held ModelBuffer by value. */
    ModelBuffer getModelBuf() const;

protected:
    ModelBufferWrapper(string& modelName, native_handle_t* handle, uint32_t size, AiDevPerf perf);
    ~ModelBufferWrapper();

private:
    /* Helper defined out of line; presumably used by createFromModelFile - confirm. */
    static long GetFileLength(FILE* fp);

    ModelBuffer mModelBuf;
};
}; // namespace ai
#endif // __AI_MODEL_MANAGER_H__