You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
101 lines
3.3 KiB
101 lines
3.3 KiB
/*
|
|
* Copyright (C) 2021 The Android Open Source Project
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include "CanonicalBuffer.h"
|
|
|
|
#include <android-base/logging.h>
|
|
#include <nnapi/IPreparedModel.h>
|
|
#include <nnapi/Result.h>
|
|
#include <nnapi/Types.h>
|
|
|
|
#include <algorithm>
|
|
#include <memory>
|
|
#include <utility>
|
|
|
|
namespace android::nn::sample {
|
|
namespace {
|
|
|
|
void copyRunTimePoolInfos(const RunTimePoolInfo& srcPool, const RunTimePoolInfo& dstPool) {
|
|
CHECK(srcPool.getBuffer() != nullptr);
|
|
CHECK(dstPool.getBuffer() != nullptr);
|
|
CHECK(srcPool.getSize() == dstPool.getSize());
|
|
std::copy(srcPool.getBuffer(), srcPool.getBuffer() + srcPool.getSize(), dstPool.getBuffer());
|
|
dstPool.flush();
|
|
}
|
|
|
|
GeneralResult<void> copyFromInternal(const SharedMemory& src, const Dimensions& dimensions,
|
|
const std::shared_ptr<ManagedBuffer>& bufferWrapper) {
|
|
CHECK(bufferWrapper != nullptr);
|
|
const auto srcPool = RunTimePoolInfo::createFromMemory(src);
|
|
if (!srcPool.has_value()) {
|
|
return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
|
|
<< "SampleBuffer::copyFrom -- unable to map src memory.";
|
|
}
|
|
const ErrorStatus validationStatus =
|
|
bufferWrapper->validateCopyFrom(dimensions, srcPool->getSize());
|
|
if (validationStatus != ErrorStatus::NONE) {
|
|
return NN_ERROR(validationStatus);
|
|
}
|
|
const auto dstPool = bufferWrapper->createRunTimePoolInfo();
|
|
copyRunTimePoolInfos(srcPool.value(), dstPool);
|
|
|
|
return {};
|
|
}
|
|
|
|
} // namespace
|
|
|
|
Buffer::Buffer(std::shared_ptr<ManagedBuffer> buffer, std::unique_ptr<BufferTracker::Token> token)
|
|
: kBuffer(std::move(buffer)), kToken(std::move(token)) {
|
|
CHECK(kBuffer != nullptr);
|
|
CHECK(kToken != nullptr);
|
|
}
|
|
|
|
Request::MemoryDomainToken Buffer::getToken() const {
|
|
return Request::MemoryDomainToken{kToken->get()};
|
|
}
|
|
|
|
GeneralResult<void> Buffer::copyTo(const SharedMemory& dst) const {
|
|
const auto dstPool = RunTimePoolInfo::createFromMemory(dst);
|
|
if (!dstPool.has_value()) {
|
|
return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
|
|
<< "SampleBuffer::copyTo -- unable to map dst memory.";
|
|
}
|
|
|
|
const ErrorStatus validationStatus = kBuffer->validateCopyTo(dstPool->getSize());
|
|
if (validationStatus != ErrorStatus::NONE) {
|
|
return NN_ERROR(validationStatus);
|
|
}
|
|
|
|
const auto srcPool = kBuffer->createRunTimePoolInfo();
|
|
copyRunTimePoolInfos(srcPool, dstPool.value());
|
|
|
|
return {};
|
|
}
|
|
|
|
GeneralResult<void> Buffer::copyFrom(const SharedMemory& src, const Dimensions& dimensions) const {
|
|
if (const auto result = copyFromInternal(src, dimensions, kBuffer); !result.ok()) {
|
|
kBuffer->setInitialized(false);
|
|
NN_TRY(result);
|
|
}
|
|
|
|
kBuffer->updateDimensions(dimensions);
|
|
kBuffer->setInitialized(true);
|
|
|
|
return {};
|
|
}
|
|
|
|
} // namespace android::nn::sample
|