/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef ANDROID_FRAMEWORKS_ML_NN_DRIVER_SAMPLE_SAMPLE_DRIVER_H
#define ANDROID_FRAMEWORKS_ML_NN_DRIVER_SAMPLE_SAMPLE_DRIVER_H

#include <CpuExecutor.h>
#include <HalBufferTracker.h>
#include <HalInterfaces.h>
#include <hwbinder/IPCThreadState.h>

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "NeuralNetworks.h"

namespace android {
namespace nn {
namespace sample_driver {

using hardware::MQDescriptorSync;

// Manages the data buffer for an operand.
class SampleBuffer : public V1_3::IBuffer {
   public:
    SampleBuffer(std::shared_ptr<HalManagedBuffer> buffer,
                 std::unique_ptr<HalBufferTracker::Token> token)
        : kBuffer(std::move(buffer)), kToken(std::move(token)) {
        CHECK(kBuffer != nullptr);
        CHECK(kToken != nullptr);
    }
    hardware::Return<V1_3::ErrorStatus> copyTo(const hardware::hidl_memory& dst) override;
    hardware::Return<V1_3::ErrorStatus> copyFrom(
            const hardware::hidl_memory& src,
            const hardware::hidl_vec<uint32_t>& dimensions) override;

   private:
    const std::shared_ptr<HalManagedBuffer> kBuffer;
    const std::unique_ptr<HalBufferTracker::Token> kToken;
};

// Base class used to create sample drivers for the NN HAL.  This class
// provides some implementation of the more common functions.
//
// Since these drivers simulate hardware, they must run the computations
// on the CPU.  An actual driver would not do that.
class SampleDriver : public V1_3::IDevice {
   public:
    SampleDriver(const char* name,
                 const IOperationResolver* operationResolver = BuiltinOperationResolver::get())
        : mName(name),
          mOperationResolver(operationResolver),
          mHalBufferTracker(HalBufferTracker::create()) {
        android::nn::initVLogMask();
    }
    hardware::Return<void> getCapabilities(getCapabilities_cb cb) override;
    hardware::Return<void> getCapabilities_1_1(getCapabilities_1_1_cb cb) override;
    hardware::Return<void> getCapabilities_1_2(getCapabilities_1_2_cb cb) override;
    hardware::Return<void> getVersionString(getVersionString_cb cb) override;
    hardware::Return<void> getType(getType_cb cb) override;
    hardware::Return<void> getSupportedExtensions(getSupportedExtensions_cb) override;
    hardware::Return<void> getSupportedOperations(const V1_0::Model& model,
                                                  getSupportedOperations_cb cb) override;
    hardware::Return<void> getSupportedOperations_1_1(const V1_1::Model& model,
                                                      getSupportedOperations_1_1_cb cb) override;
    hardware::Return<void> getSupportedOperations_1_2(const V1_2::Model& model,
                                                      getSupportedOperations_1_2_cb cb) override;
    hardware::Return<void> getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb cb) override;
    hardware::Return<V1_0::ErrorStatus> prepareModel(
            const V1_0::Model& model, const sp<V1_0::IPreparedModelCallback>& callback) override;
    hardware::Return<V1_0::ErrorStatus> prepareModel_1_1(
            const V1_1::Model& model, V1_1::ExecutionPreference preference,
            const sp<V1_0::IPreparedModelCallback>& callback) override;
    hardware::Return<V1_0::ErrorStatus> prepareModel_1_2(
            const V1_2::Model& model, V1_1::ExecutionPreference preference,
            const hardware::hidl_vec<hardware::hidl_handle>& modelCache,
            const hardware::hidl_vec<hardware::hidl_handle>& dataCache, const HalCacheToken& token,
            const sp<V1_2::IPreparedModelCallback>& callback) override;
    hardware::Return<V1_3::ErrorStatus> prepareModel_1_3(
            const V1_3::Model& model, V1_1::ExecutionPreference preference, V1_3::Priority priority,
            const V1_3::OptionalTimePoint& deadline,
            const hardware::hidl_vec<hardware::hidl_handle>& modelCache,
            const hardware::hidl_vec<hardware::hidl_handle>& dataCache, const HalCacheToken& token,
            const sp<V1_3::IPreparedModelCallback>& callback) override;
    hardware::Return<V1_0::ErrorStatus> prepareModelFromCache(
            const hardware::hidl_vec<hardware::hidl_handle>& modelCache,
            const hardware::hidl_vec<hardware::hidl_handle>& dataCache, const HalCacheToken& token,
            const sp<V1_2::IPreparedModelCallback>& callback) override;
    hardware::Return<V1_3::ErrorStatus> prepareModelFromCache_1_3(
            const V1_3::OptionalTimePoint& deadline,
            const hardware::hidl_vec<hardware::hidl_handle>& modelCache,
            const hardware::hidl_vec<hardware::hidl_handle>& dataCache, const HalCacheToken& token,
            const sp<V1_3::IPreparedModelCallback>& callback) override;
    hardware::Return<V1_0::DeviceStatus> getStatus() override;
    hardware::Return<void> allocate(
            const V1_3::BufferDesc& desc,
            const hardware::hidl_vec<sp<V1_3::IPreparedModel>>& preparedModels,
            const hardware::hidl_vec<V1_3::BufferRole>& inputRoles,
            const hardware::hidl_vec<V1_3::BufferRole>& outputRoles, allocate_cb cb) override;

    // Starts and runs the driver service.  Typically called from main().
    // This will return only once the service shuts down.
    int run();

    CpuExecutor getExecutor() const { return CpuExecutor(mOperationResolver); }
    const std::shared_ptr<HalBufferTracker>& getHalBufferTracker() const {
        return mHalBufferTracker;
    }

   protected:
    std::string mName;
    const IOperationResolver* mOperationResolver;
    const std::shared_ptr<HalBufferTracker> mHalBufferTracker;
};

class SamplePreparedModel : public V1_3::IPreparedModel {
   public:
    SamplePreparedModel(const V1_3::Model& model, const SampleDriver* driver,
                        V1_1::ExecutionPreference preference, uid_t userId, V1_3::Priority priority)
        : mModel(model),
          mDriver(driver),
          kPreference(preference),
          kUserId(userId),
          kPriority(priority) {
        (void)kUserId;
        (void)kPriority;
    }
    bool initialize();
    hardware::Return<V1_0::ErrorStatus> execute(
            const V1_0::Request& request, const sp<V1_0::IExecutionCallback>& callback) override;
    hardware::Return<V1_0::ErrorStatus> execute_1_2(
            const V1_0::Request& request, V1_2::MeasureTiming measure,
            const sp<V1_2::IExecutionCallback>& callback) override;
    hardware::Return<V1_3::ErrorStatus> execute_1_3(
            const V1_3::Request& request, V1_2::MeasureTiming measure,
            const V1_3::OptionalTimePoint& deadline,
            const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
            const sp<V1_3::IExecutionCallback>& callback) override;
    hardware::Return<void> executeSynchronously(const V1_0::Request& request,
                                                V1_2::MeasureTiming measure,
                                                executeSynchronously_cb cb) override;
    hardware::Return<void> executeSynchronously_1_3(
            const V1_3::Request& request, V1_2::MeasureTiming measure,
            const V1_3::OptionalTimePoint& deadline,
            const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
            executeSynchronously_1_3_cb cb) override;
    hardware::Return<void> configureExecutionBurst(
            const sp<V1_2::IBurstCallback>& callback,
            const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
            const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
            configureExecutionBurst_cb cb) override;
    hardware::Return<void> executeFenced(const V1_3::Request& request,
                                         const hardware::hidl_vec<hardware::hidl_handle>& wait_for,
                                         V1_2::MeasureTiming measure,
                                         const V1_3::OptionalTimePoint& deadline,
                                         const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
                                         const V1_3::OptionalTimeoutDuration& duration,
                                         executeFenced_cb callback) override;
    const V1_3::Model* getModel() const { return &mModel; }

   protected:
    V1_3::Model mModel;
    const SampleDriver* mDriver;
    std::vector<RunTimePoolInfo> mPoolInfos;
    const V1_1::ExecutionPreference kPreference;
    const uid_t kUserId;
    const V1_3::Priority kPriority;
};

class SampleFencedExecutionCallback : public V1_3::IFencedExecutionCallback {
   public:
    SampleFencedExecutionCallback(V1_2::Timing timingSinceLaunch, V1_2::Timing timingAfterFence,
                                  V1_3::ErrorStatus error)
        : kTimingSinceLaunch(timingSinceLaunch),
          kTimingAfterFence(timingAfterFence),
          kErrorStatus(error) {}
    hardware::Return<void> getExecutionInfo(getExecutionInfo_cb callback) override {
        callback(kErrorStatus, kTimingSinceLaunch, kTimingAfterFence);
        return hardware::Void();
    }

   private:
    const V1_2::Timing kTimingSinceLaunch;
    const V1_2::Timing kTimingAfterFence;
    const V1_3::ErrorStatus kErrorStatus;
};

}  // namespace sample_driver
}  // namespace nn
}  // namespace android

#endif  // ANDROID_FRAMEWORKS_ML_NN_DRIVER_SAMPLE_SAMPLE_DRIVER_H