You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

167 lines
5.6 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

/**
* Copyright (c) Hisilicon Technologies Co., Ltd.. 2020-2020. All rights reserved.
*
* @file HLAIModelManagerService.h
*
*
*/
#ifndef __AI_MODEL_MANGER_SERVICE_H__
#define __AI_MODEL_MANGER_SERVICE_H__
#include <string>
#include <vector>
#include <map>
#include <mutex>
#include "HLAIModelManagerType.h"
namespace hlai {
class AiModelMngerClientImpl;
class AiModelMngerClient;
class AiModelBuilderImpl;
class AiModelBuilder {
public:
explicit AiModelBuilder(std::shared_ptr<AiModelMngerClient> client = nullptr);
virtual ~AiModelBuilder();
/*
* @brief 从文件读取OM离线模型proto信息
* @param [in] path, 模型文件路径
* @return MemBuffer * proto信息存储地址
* @return nullptr 获取失败
*/
MemBuffer* ReadBinaryProto(const std::string path);
/*
* @brief 从内存读取OM离线模型proto信息
* @param [in] data, OM离线模型内存地址
* @param [in] size, OM离线模型内存存储大小
* @return MemBuffer * proto信息存储地址
* @return nullptr 获取失败
*/
MemBuffer* ReadBinaryProto(void* data, uint32_t size);
/*
* @brief 为输入OM离线模型用户内存buffer创建通用MemBuffer
* @param [in] data, 模型用户内存地址
* @param [in] size, 模型内存存储大小
* @return MemBuffer * proto信息存储地址
* @return nullptr 获取失败
*/
MemBuffer* InputMemBufferCreate(void *data, uint32_t size);
/*
* @brief 为输入OM离线模型从文件创建MemBuffer
* @param [in] path 文件路径
* @return MemBuffer * 创建的输入MemBuffer内存指针
* @return nullptr 创建失败
*/
MemBuffer* InputMemBufferCreate(const std::string path);
/*
* @brief 为在线编译输出模型创建MemBuffer
* @param [in] framework 模型平台类型
* @param [in] pinputMemBuffer 输入的OM离线模型buffer
* @return MemBuffer * 创建的输出模型MemBuffer内存指针
* @return nullptr 创建失败
*/
MemBuffer* OutputMemBufferCreate(const int32_t framework, const std::vector<MemBuffer *> &pinputMemBuffer);
/*
* @brief 注销MemBuffer内存通过上述MemBuffer申请的内存最终都需要由此接口进行释放
* @param [in] membuf, 创建的MemBuffer内存
* @return void
*/
void MemBufferDestroy(MemBuffer *membuf);
/*
* @brief 将模型buffer导出到文件
* @param [in] membuf, 存储离线模型信息内存指针
* @param [in] pbuildSize, 离线模型大小
* @param [in] pbuildPath, 离线模型导出文件存储路径
* @return AIStatus::AI_SUCCESS 导出成功
* @return Others 导出失败
*/
AIStatus MemBufferExportFile(MemBuffer *membuf, const uint32_t pbuildSize, const std::string pbuildPath);
private:
std::shared_ptr<AiModelBuilderImpl> builderImpl_;
};
class AiModelMngerClient {
public:
AiModelMngerClient();
virtual ~AiModelMngerClient();
/*
* @brief 初始化接口
* @param [in] listener 监听接口, 设置为nullptr时为同步调用, 否则为异步
* @return AIStatus::AI_SUCCESS 成功
* @return Others 失败
*/
AIStatus Init(std::shared_ptr<AiModelManagerClientListener> listener);
/*
* @brief 加载模型
* @param [in] pmodelDesc 模型信息
* @param [in] piStamp 异步返回标识,基于该标识和模型名称做回调索引
* @return AIStatus::AI_SUCCESS 成功
* @return AIStatus::AI_INVALID_API 失败表示设备不支持NPU
* @return Others 失败
*/
AIStatus Load(std::vector<std::shared_ptr<AiModelDescription>> &pmodelDesc);
/*
* @brief 模型处理接口, 运行加载模型的模型推理
* @param [in] context, 模型运行上下文, 必须带model_name字段
* @param [in] pinputTensor, 模型输入节点tensor信息
* @param [in] poutputTensor, 模型输出节点tensor信息
* @param [in] timeout, 推理超时时间
* @param [in] piStamp 异步返回标识,基于该标识和模型名称做回调索引
* @return AIStatus::AI_SUCCESS 成功
* @return Others 失败
*/
AIStatus Process(AiContext &context, std::vector<std::shared_ptr<AiTensor>> &pinputTensor,
std::vector<std::shared_ptr<AiTensor>> &poutputTensor, uint32_t timeout, int32_t &piStamp);
/*
* @brief 模型兼容性检查接口
* @param [in] pmodelDesc, 模型描述
* @param [out] pisModelCompatibility, 兼容性检查标识
* @return AIStatus::AI_SUCCESS 兼容性检查通过
* @return Others 兼容性检查失败
*/
AIStatus CheckModelCompatibility(AiModelDescription &pmodelDesc, bool &pisModelCompatibility);
/*
* @brief 获取模型输入输出tensor信息
* @param [in] pmodelName, 模型名
* @param [out] pinputTensor 输出参数, 存储模型输入节点tensor信息
* @param [out] poutputTensor 输出参数, 存储模型输出节点tensor信息
* @return AIStatus::AI_SUCCESS 获取成功
* @return Others 获取失败
*/
AIStatus GetModelIOTensorDim(const std::string& pmodelName, std::vector<TensorDimension>& pinputTensor,
std::vector<TensorDimension>& poutputTensor);
/*
* @brief 获取DDK版本
* @return char * DDK版本
* @return nullptr 获取失败
*/
const char* GetVersion();
/*
* @brief 卸载模型
* @return AIStatus::AI_SUCCESS 卸载成功
* @return Others 卸载失败
*/
AIStatus UnLoadModel();
private:
friend class AiModelBuilderImpl;
std::shared_ptr<AiModelMngerClientImpl> clientImpl_;
};
} // end namespace hlai
#endif