You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
221 lines
7.2 KiB
221 lines
7.2 KiB
// Copyright 2015 The Chromium OS Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style license that can be
|
|
// found in the LICENSE file.
|
|
|
|
#include <brillo/streams/stream_utils.h>
|
|
|
|
#include <algorithm>
|
|
#include <limits>
|
|
#include <memory>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
#include <base/bind.h>
|
|
#include <brillo/message_loops/message_loop.h>
|
|
#include <brillo/streams/stream_errors.h>
|
|
|
|
namespace brillo {
|
|
namespace stream_utils {
|
|
|
|
namespace {
|
|
|
|
// Status of asynchronous CopyData operation.
|
|
struct CopyDataState {
|
|
brillo::StreamPtr in_stream;
|
|
brillo::StreamPtr out_stream;
|
|
std::vector<uint8_t> buffer;
|
|
uint64_t remaining_to_copy;
|
|
uint64_t size_copied;
|
|
CopyDataSuccessCallback success_callback;
|
|
CopyDataErrorCallback error_callback;
|
|
};
|
|
|
|
// Async CopyData I/O error callback.
|
|
void OnCopyDataError(const std::shared_ptr<CopyDataState>& state,
|
|
const brillo::Error* error) {
|
|
state->error_callback.Run(std::move(state->in_stream),
|
|
std::move(state->out_stream), error);
|
|
}
|
|
|
|
// Forward declaration.
|
|
void PerformRead(const std::shared_ptr<CopyDataState>& state);
|
|
|
|
// Callback from read operation for CopyData. Writes the read data to the output
|
|
// stream and invokes PerformRead when done to restart the copy cycle.
|
|
void PerformWrite(const std::shared_ptr<CopyDataState>& state, size_t size) {
|
|
if (size == 0) {
|
|
state->success_callback.Run(std::move(state->in_stream),
|
|
std::move(state->out_stream),
|
|
state->size_copied);
|
|
return;
|
|
}
|
|
state->size_copied += size;
|
|
CHECK_GE(state->remaining_to_copy, size);
|
|
state->remaining_to_copy -= size;
|
|
|
|
brillo::ErrorPtr error;
|
|
bool success = state->out_stream->WriteAllAsync(
|
|
state->buffer.data(), size, base::Bind(&PerformRead, state),
|
|
base::Bind(&OnCopyDataError, state), &error);
|
|
|
|
if (!success)
|
|
OnCopyDataError(state, error.get());
|
|
}
|
|
|
|
// Performs the read part of asynchronous CopyData operation. Reads the data
|
|
// from input stream and invokes PerformWrite when done to write the data to
|
|
// the output stream.
|
|
void PerformRead(const std::shared_ptr<CopyDataState>& state) {
|
|
brillo::ErrorPtr error;
|
|
const uint64_t buffer_size = state->buffer.size();
|
|
// |buffer_size| is guaranteed to fit in size_t, so |size_to_read| value will
|
|
// also not overflow size_t, so the static_cast below is safe.
|
|
size_t size_to_read =
|
|
static_cast<size_t>(std::min(buffer_size, state->remaining_to_copy));
|
|
if (size_to_read == 0)
|
|
return PerformWrite(state, 0); // Nothing more to read. Finish operation.
|
|
bool success = state->in_stream->ReadAsync(
|
|
state->buffer.data(), size_to_read, base::Bind(PerformWrite, state),
|
|
base::Bind(OnCopyDataError, state), &error);
|
|
|
|
if (!success)
|
|
OnCopyDataError(state, error.get());
|
|
}
|
|
|
|
} // anonymous namespace
|
|
|
|
bool ErrorStreamClosed(const base::Location& location,
|
|
ErrorPtr* error) {
|
|
Error::AddTo(error,
|
|
location,
|
|
errors::stream::kDomain,
|
|
errors::stream::kStreamClosed,
|
|
"Stream is closed");
|
|
return false;
|
|
}
|
|
|
|
// Adds a kOperationNotSupported error in the stream error domain to |error|,
// tagged with |location|. Always returns false.
bool ErrorOperationNotSupported(const base::Location& location,
                                ErrorPtr* error) {
  constexpr char kMessage[] = "Stream operation not supported";
  Error::AddTo(error, location, errors::stream::kDomain,
               errors::stream::kOperationNotSupported, kMessage);
  return false;
}
|
|
|
|
// Adds a kPartialData error in the stream error domain to |error|, tagged with
// |location|, to signal a read past the end of the stream. Always returns
// false.
bool ErrorReadPastEndOfStream(const base::Location& location,
                              ErrorPtr* error) {
  constexpr char kMessage[] = "Reading past the end of stream";
  Error::AddTo(error, location, errors::stream::kDomain,
               errors::stream::kPartialData, kMessage);
  return false;
}
|
|
|
|
bool ErrorOperationTimeout(const base::Location& location,
|
|
ErrorPtr* error) {
|
|
Error::AddTo(error,
|
|
location,
|
|
errors::stream::kDomain,
|
|
errors::stream::kTimeout,
|
|
"Operation timed out");
|
|
return false;
|
|
}
|
|
|
|
// Checks whether the signed |offset| can be applied to the unsigned
// |position|.
//
// For a negative |offset| the check succeeds when the subtraction does not go
// below zero. For a non-negative |offset| the check succeeds when the sum
// neither overflows 64 unsigned bits nor exceeds the maximum int64_t value.
//
// Returns true when the check succeeds. Otherwise adds a kInvalidParameter
// error (tagged with |location|) to |error| and returns false.
bool CheckInt64Overflow(const base::Location& location,
                        uint64_t position,
                        int64_t offset,
                        ErrorPtr* error) {
  if (offset < 0) {
    // Subtracting the offset. Make sure we do not underflow.
    // The magnitude is computed in unsigned arithmetic on purpose: negating
    // |offset| as a signed value is undefined behavior when it equals
    // std::numeric_limits<int64_t>::min().
    const uint64_t unsigned_offset =
        uint64_t{0} - static_cast<uint64_t>(offset);
    if (position >= unsigned_offset)
      return true;
  } else {
    // Adding the offset. Make sure we do not overflow unsigned 64 bits first.
    const uint64_t unsigned_offset = static_cast<uint64_t>(offset);
    if (position <= std::numeric_limits<uint64_t>::max() - unsigned_offset) {
      // We definitely will not overflow the unsigned 64 bit integer.
      // Now check that we end up within the limits of signed 64 bit integer.
      const uint64_t new_position = position + unsigned_offset;
      const uint64_t max_signed =
          static_cast<uint64_t>(std::numeric_limits<int64_t>::max());
      if (new_position <= max_signed)
        return true;
    }
  }
  Error::AddTo(error,
               location,
               errors::stream::kDomain,
               errors::stream::kInvalidParameter,
               "The stream offset value is out of range");
  return false;
}
|
|
|
|
// Computes an absolute stream position from |offset| and |whence|.
//
// The reference point is 0 for FROM_BEGIN, |current_position| for FROM_CURRENT
// and |stream_size| for FROM_END. Any other |whence| value adds a
// kInvalidParameter error (tagged with |location|) to |error| and returns
// false. The combination of reference point and |offset| is validated with
// CheckInt64Overflow(); if that check fails, false is returned and
// |new_position| is left untouched. On success the result is stored in
// |new_position| and true is returned.
bool CalculateStreamPosition(const base::Location& location,
                             int64_t offset,
                             Stream::Whence whence,
                             uint64_t current_position,
                             uint64_t stream_size,
                             uint64_t* new_position,
                             ErrorPtr* error) {
  uint64_t base_position = 0;
  if (whence == Stream::Whence::FROM_BEGIN) {
    base_position = 0;
  } else if (whence == Stream::Whence::FROM_CURRENT) {
    base_position = current_position;
  } else if (whence == Stream::Whence::FROM_END) {
    base_position = stream_size;
  } else {
    Error::AddTo(error,
                 location,
                 errors::stream::kDomain,
                 errors::stream::kInvalidParameter,
                 "Invalid stream position whence");
    return false;
  }

  const bool in_range =
      CheckInt64Overflow(location, base_position, offset, error);
  if (!in_range)
    return false;

  // |offset| is converted to unsigned and added modulo 2^64, which yields the
  // intended position for negative offsets as well.
  *new_position = base_position + static_cast<uint64_t>(offset);
  return true;
}
|
|
|
|
// Convenience overload of CopyData: forwards to the full overload with no
// limit on the amount of data (maximum uint64_t value) and a 4096-byte buffer.
void CopyData(StreamPtr in_stream,
              StreamPtr out_stream,
              const CopyDataSuccessCallback& success_callback,
              const CopyDataErrorCallback& error_callback) {
  constexpr uint64_t kNoSizeLimit = std::numeric_limits<uint64_t>::max();
  constexpr size_t kDefaultBufferSize = 4096;
  CopyData(std::move(in_stream), std::move(out_stream), kNoSizeLimit,
           kDefaultBufferSize, success_callback, error_callback);
}
|
|
|
|
// Starts an asynchronous copy from |in_stream| to |out_stream|.
//
// At most |max_size_to_copy| bytes are copied, in chunks of up to
// |buffer_size| bytes. The function only sets up the shared operation state
// and posts the first PerformRead step to the current message loop.
// |success_callback| or |error_callback| is run later with both streams
// handed back to the caller.
void CopyData(StreamPtr in_stream,
              StreamPtr out_stream,
              uint64_t max_size_to_copy,
              size_t buffer_size,
              const CopyDataSuccessCallback& success_callback,
              const CopyDataErrorCallback& error_callback) {
  auto copy_state = std::make_shared<CopyDataState>();

  // Completion handlers.
  copy_state->success_callback = success_callback;
  copy_state->error_callback = error_callback;

  // Progress counters.
  copy_state->size_copied = 0;
  copy_state->remaining_to_copy = max_size_to_copy;

  // Streams and the scratch buffer used to shuttle data between them.
  copy_state->in_stream = std::move(in_stream);
  copy_state->out_stream = std::move(out_stream);
  copy_state->buffer.resize(buffer_size);

  brillo::MessageLoop::current()->PostTask(
      FROM_HERE, base::BindOnce(&PerformRead, copy_state));
}
|
|
|
|
} // namespace stream_utils
|
|
} // namespace brillo
|