You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
122 lines
4.4 KiB
122 lines
4.4 KiB
4 months ago
|
/*
|
||
|
* Copyright (C) 2020 The Android Open Source Project
|
||
|
*
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
*/
|
||
|
|
||
|
#ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_HAL_BUFFER_TRACKER_H
|
||
|
#define ANDROID_FRAMEWORKS_ML_NN_COMMON_HAL_BUFFER_TRACKER_H
|
||
|
|
||
|
#include <android-base/macros.h>
|
||
|
|
||
|
#include <map>
|
||
|
#include <memory>
|
||
|
#include <mutex>
|
||
|
#include <set>
|
||
|
#include <stack>
|
||
|
#include <utility>
|
||
|
#include <vector>
|
||
|
|
||
|
#include "CpuExecutor.h"
|
||
|
#include "HalInterfaces.h"
|
||
|
#include "Utils.h"
|
||
|
#include "ValidateHal.h"
|
||
|
|
||
|
namespace android::nn {
|
||
|
|
||
|
// This class manages a CPU buffer allocated on heap and provides validation methods.
|
||
|
class HalManagedBuffer {
|
||
|
public:
|
||
|
static std::shared_ptr<HalManagedBuffer> create(uint32_t size,
|
||
|
std::set<HalPreparedModelRole> roles,
|
||
|
const Operand& operand);
|
||
|
|
||
|
// Prefer HalManagedBuffer::create.
|
||
|
HalManagedBuffer(std::unique_ptr<uint8_t[]> buffer, uint32_t size,
|
||
|
std::set<HalPreparedModelRole> roles, const Operand& operand);
|
||
|
|
||
|
RunTimePoolInfo createRunTimePoolInfo() const {
|
||
|
return RunTimePoolInfo::createFromExistingBuffer(kBuffer.get(), kSize);
|
||
|
}
|
||
|
|
||
|
// "poolIndex" is the index of this buffer in the request.pools.
|
||
|
ErrorStatus validateRequest(uint32_t poolIndex, const Request& request,
|
||
|
const V1_3::IPreparedModel* preparedModel) const;
|
||
|
|
||
|
// "size" is the byte size of the Memory provided to the copyFrom or copyTo method.
|
||
|
ErrorStatus validateCopyFrom(const std::vector<uint32_t>& dimensions, uint32_t size) const;
|
||
|
ErrorStatus validateCopyTo(uint32_t size) const;
|
||
|
|
||
|
bool updateDimensions(const std::vector<uint32_t>& dimensions);
|
||
|
void setInitialized(bool initialized);
|
||
|
|
||
|
private:
|
||
|
mutable std::mutex mMutex;
|
||
|
const std::unique_ptr<uint8_t[]> kBuffer;
|
||
|
const uint32_t kSize;
|
||
|
const std::set<HalPreparedModelRole> kRoles;
|
||
|
const OperandType kOperandType;
|
||
|
const std::vector<uint32_t> kInitialDimensions;
|
||
|
std::vector<uint32_t> mUpdatedDimensions;
|
||
|
bool mInitialized = false;
|
||
|
};
|
||
|
|
||
|
// Keep track of all HalManagedBuffers and assign each with a unique token.
|
||
|
class HalBufferTracker : public std::enable_shared_from_this<HalBufferTracker> {
|
||
|
DISALLOW_COPY_AND_ASSIGN(HalBufferTracker);
|
||
|
|
||
|
public:
|
||
|
// A RAII class to help manage the lifetime of the token.
|
||
|
// It is only supposed to be constructed in HalBufferTracker::add.
|
||
|
class Token {
|
||
|
DISALLOW_COPY_AND_ASSIGN(Token);
|
||
|
|
||
|
public:
|
||
|
Token(uint32_t token, std::shared_ptr<HalBufferTracker> tracker)
|
||
|
: kToken(token), kHalBufferTracker(std::move(tracker)) {}
|
||
|
~Token() { kHalBufferTracker->free(kToken); }
|
||
|
uint32_t get() const { return kToken; }
|
||
|
|
||
|
private:
|
||
|
const uint32_t kToken;
|
||
|
const std::shared_ptr<HalBufferTracker> kHalBufferTracker;
|
||
|
};
|
||
|
|
||
|
// The factory of HalBufferTracker. This ensures that the HalBufferTracker is always managed by
|
||
|
// a shared_ptr.
|
||
|
static std::shared_ptr<HalBufferTracker> create() {
|
||
|
return std::make_shared<HalBufferTracker>();
|
||
|
}
|
||
|
|
||
|
// Prefer HalBufferTracker::create.
|
||
|
HalBufferTracker() : mTokenToBuffers(1) {}
|
||
|
|
||
|
std::unique_ptr<Token> add(std::shared_ptr<HalManagedBuffer> buffer);
|
||
|
std::shared_ptr<HalManagedBuffer> get(uint32_t token) const;
|
||
|
|
||
|
private:
|
||
|
void free(uint32_t token);
|
||
|
|
||
|
mutable std::mutex mMutex;
|
||
|
std::stack<uint32_t, std::vector<uint32_t>> mFreeTokens;
|
||
|
|
||
|
// Since the tokens are allocated in a non-sparse way, we use a vector to represent the mapping.
|
||
|
// The index of the vector is the token. When the token gets freed, the corresponding entry is
|
||
|
// set to nullptr. mTokenToBuffers[0] is always set to nullptr because 0 is an invalid token.
|
||
|
std::vector<std::shared_ptr<HalManagedBuffer>> mTokenToBuffers;
|
||
|
};
|
||
|
|
||
|
} // namespace android::nn
|
||
|
|
||
|
#endif // ANDROID_FRAMEWORKS_ML_NN_COMMON_HAL_BUFFER_TRACKER_H
|