You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
183 lines
7.9 KiB
183 lines
7.9 KiB
/*
|
|
* Copyright (C) 2021 The Android Open Source Project
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#ifndef ANDROID_FRAMEWORKS_ML_NN_DRIVER_SAMPLE_AIDL_SAMPLE_DRIVER_H
|
|
#define ANDROID_FRAMEWORKS_ML_NN_DRIVER_SAMPLE_AIDL_SAMPLE_DRIVER_H
|
|
|
|
#include <android/binder_auto_utils.h>
#include <sys/types.h>

#include <atomic>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "AidlBufferTracker.h"
#include "AidlHalInterfaces.h"
#include "CpuExecutor.h"
#include "NeuralNetworks.h"
|
|
|
|
namespace android {
|
|
namespace nn {
|
|
namespace sample_driver {
|
|
|
|
// Manages the data buffer for an operand.
|
|
class SampleBuffer : public aidl_hal::BnBuffer {
|
|
public:
|
|
SampleBuffer(std::shared_ptr<AidlManagedBuffer> buffer,
|
|
std::unique_ptr<AidlBufferTracker::Token> token)
|
|
: kBuffer(std::move(buffer)), kToken(std::move(token)) {
|
|
CHECK(kBuffer != nullptr);
|
|
CHECK(kToken != nullptr);
|
|
}
|
|
ndk::ScopedAStatus copyFrom(const aidl_hal::Memory& src,
|
|
const std::vector<int32_t>& dimensions) override;
|
|
ndk::ScopedAStatus copyTo(const aidl_hal::Memory& dst) override;
|
|
|
|
private:
|
|
const std::shared_ptr<AidlManagedBuffer> kBuffer;
|
|
const std::unique_ptr<AidlBufferTracker::Token> kToken;
|
|
};
|
|
|
|
// Base class used to create sample drivers for the NN HAL. This class
|
|
// provides some implementation of the more common functions.
|
|
//
|
|
// Since these drivers simulate hardware, they must run the computations
|
|
// on the CPU. An actual driver would not do that.
|
|
class SampleDriver : public aidl_hal::BnDevice {
|
|
public:
|
|
SampleDriver(const char* name,
|
|
const IOperationResolver* operationResolver = BuiltinOperationResolver::get())
|
|
: mName(name),
|
|
mOperationResolver(operationResolver),
|
|
mBufferTracker(AidlBufferTracker::create()) {
|
|
android::nn::initVLogMask();
|
|
}
|
|
ndk::ScopedAStatus allocate(const aidl_hal::BufferDesc& desc,
|
|
const std::vector<aidl_hal::IPreparedModelParcel>& preparedModels,
|
|
const std::vector<aidl_hal::BufferRole>& inputRoles,
|
|
const std::vector<aidl_hal::BufferRole>& outputRoles,
|
|
aidl_hal::DeviceBuffer* buffer) override;
|
|
ndk::ScopedAStatus getNumberOfCacheFilesNeeded(
|
|
aidl_hal::NumberOfCacheFiles* numberOfCacheFiles) override;
|
|
ndk::ScopedAStatus getSupportedExtensions(
|
|
std::vector<aidl_hal::Extension>* extensions) override;
|
|
ndk::ScopedAStatus getType(aidl_hal::DeviceType* deviceType) override;
|
|
ndk::ScopedAStatus getVersionString(std::string* version) override;
|
|
ndk::ScopedAStatus prepareModel(
|
|
const aidl_hal::Model& model, aidl_hal::ExecutionPreference preference,
|
|
aidl_hal::Priority priority, int64_t deadlineNs,
|
|
const std::vector<ndk::ScopedFileDescriptor>& modelCache,
|
|
const std::vector<ndk::ScopedFileDescriptor>& dataCache,
|
|
const std::vector<uint8_t>& token,
|
|
const std::shared_ptr<aidl_hal::IPreparedModelCallback>& callback) override;
|
|
ndk::ScopedAStatus prepareModelFromCache(
|
|
int64_t deadlineNs, const std::vector<ndk::ScopedFileDescriptor>& modelCache,
|
|
const std::vector<ndk::ScopedFileDescriptor>& dataCache,
|
|
const std::vector<uint8_t>& token,
|
|
const std::shared_ptr<aidl_hal::IPreparedModelCallback>& callback) override;
|
|
|
|
// Starts and runs the driver service. Typically called from main().
|
|
// This will return only once the service shuts down.
|
|
int run();
|
|
|
|
CpuExecutor getExecutor() const { return CpuExecutor(mOperationResolver); }
|
|
const std::shared_ptr<AidlBufferTracker>& getBufferTracker() const { return mBufferTracker; }
|
|
|
|
protected:
|
|
std::string mName;
|
|
const IOperationResolver* mOperationResolver;
|
|
const std::shared_ptr<AidlBufferTracker> mBufferTracker;
|
|
};
|
|
|
|
class SamplePreparedModel : public aidl_hal::BnPreparedModel {
|
|
public:
|
|
SamplePreparedModel(aidl_hal::Model&& model, const SampleDriver* driver,
|
|
aidl_hal::ExecutionPreference preference, uid_t userId,
|
|
aidl_hal::Priority priority)
|
|
: mModel(std::move(model)),
|
|
mDriver(driver),
|
|
kPreference(preference),
|
|
kUserId(userId),
|
|
kPriority(priority) {
|
|
(void)kUserId;
|
|
(void)kPriority;
|
|
}
|
|
bool initialize();
|
|
ndk::ScopedAStatus executeSynchronously(const aidl_hal::Request& request, bool measureTiming,
|
|
int64_t deadlineNs, int64_t loopTimeoutDurationNs,
|
|
aidl_hal::ExecutionResult* executionResult) override;
|
|
ndk::ScopedAStatus executeFenced(const aidl_hal::Request& request,
|
|
const std::vector<ndk::ScopedFileDescriptor>& waitFor,
|
|
bool measureTiming, int64_t deadlineNs,
|
|
int64_t loopTimeoutDurationNs, int64_t durationNs,
|
|
aidl_hal::FencedExecutionResult* executionResult) override;
|
|
ndk::ScopedAStatus configureExecutionBurst(std::shared_ptr<aidl_hal::IBurst>* burst) override;
|
|
const aidl_hal::Model* getModel() const { return &mModel; }
|
|
|
|
protected:
|
|
aidl_hal::Model mModel;
|
|
const SampleDriver* mDriver;
|
|
std::vector<RunTimePoolInfo> mPoolInfos;
|
|
const aidl_hal::ExecutionPreference kPreference;
|
|
const uid_t kUserId;
|
|
const aidl_hal::Priority kPriority;
|
|
};
|
|
|
|
class SampleFencedExecutionCallback : public aidl_hal::BnFencedExecutionCallback {
|
|
public:
|
|
SampleFencedExecutionCallback(aidl_hal::Timing timingSinceLaunch,
|
|
aidl_hal::Timing timingAfterFence, aidl_hal::ErrorStatus error)
|
|
: kTimingSinceLaunch(timingSinceLaunch),
|
|
kTimingAfterFence(timingAfterFence),
|
|
kErrorStatus(error) {}
|
|
ndk::ScopedAStatus getExecutionInfo(aidl_hal::Timing* timingLaunched,
|
|
aidl_hal::Timing* timingFenced,
|
|
aidl_hal::ErrorStatus* errorStatus) override {
|
|
*timingLaunched = kTimingSinceLaunch;
|
|
*timingFenced = kTimingAfterFence;
|
|
*errorStatus = kErrorStatus;
|
|
return ndk::ScopedAStatus::ok();
|
|
}
|
|
|
|
private:
|
|
const aidl_hal::Timing kTimingSinceLaunch;
|
|
const aidl_hal::Timing kTimingAfterFence;
|
|
const aidl_hal::ErrorStatus kErrorStatus;
|
|
};
|
|
|
|
class SampleBurst : public aidl_hal::BnBurst {
|
|
public:
|
|
// Precondition: preparedModel != nullptr
|
|
explicit SampleBurst(std::shared_ptr<SamplePreparedModel> preparedModel);
|
|
|
|
ndk::ScopedAStatus executeSynchronously(const aidl_hal::Request& request,
|
|
const std::vector<int64_t>& memoryIdentifierTokens,
|
|
bool measureTiming, int64_t deadlineNs,
|
|
int64_t loopTimeoutDurationNs,
|
|
aidl_hal::ExecutionResult* executionResult) override;
|
|
ndk::ScopedAStatus releaseMemoryResource(int64_t memoryIdentifierToken) override;
|
|
|
|
protected:
|
|
std::atomic_flag mExecutionInFlight = ATOMIC_FLAG_INIT;
|
|
const std::shared_ptr<SamplePreparedModel> kPreparedModel;
|
|
};
|
|
|
|
} // namespace sample_driver
|
|
} // namespace nn
|
|
} // namespace android
|
|
|
|
#endif // ANDROID_FRAMEWORKS_ML_NN_DRIVER_SAMPLE_AIDL_SAMPLE_DRIVER_H
|