You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
219 lines
6.9 KiB
219 lines
6.9 KiB
/* Copyright 2019 Google LLC. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include "ruy/thread_pool.h"
|
|
|
|
#include <atomic>
|
|
#include <chrono> // NOLINT(build/c++11)
|
|
#include <condition_variable> // NOLINT(build/c++11)
|
|
#include <cstdint>
|
|
#include <cstdlib>
|
|
#include <memory>
|
|
#include <mutex> // NOLINT(build/c++11)
|
|
#include <thread> // NOLINT(build/c++11)
|
|
|
|
#include "ruy/check_macros.h"
|
|
#include "ruy/trace.h"
|
|
#include "ruy/wait.h"
|
|
|
|
namespace ruy {
|
|
|
|
// A worker thread.
|
|
class Thread {
|
|
public:
|
|
enum class State {
|
|
Startup, // The initial state before the thread main loop runs.
|
|
Ready, // Is not working, has not yet received new work to do.
|
|
HasWork, // Has work to do.
|
|
ExitAsSoonAsPossible // Should exit at earliest convenience.
|
|
};
|
|
|
|
explicit Thread(BlockingCounter* counter_to_decrement_when_ready,
|
|
Duration spin_duration)
|
|
: task_(nullptr),
|
|
state_(State::Startup),
|
|
counter_to_decrement_when_ready_(counter_to_decrement_when_ready),
|
|
spin_duration_(spin_duration) {
|
|
thread_.reset(new std::thread(ThreadFunc, this));
|
|
}
|
|
|
|
~Thread() {
|
|
ChangeState(State::ExitAsSoonAsPossible);
|
|
thread_->join();
|
|
}
|
|
|
|
// Changes State; may be called from either the worker thread
|
|
// or the master thread; however, not all state transitions are legal,
|
|
// which is guarded by assertions.
|
|
//
|
|
// The Task argument is to be used only with new_state==HasWork.
|
|
// It specifies the Task being handed to this Thread.
|
|
void ChangeState(State new_state, Task* task = nullptr) {
|
|
state_mutex_.lock();
|
|
State old_state = state_.load(std::memory_order_relaxed);
|
|
RUY_DCHECK_NE(old_state, new_state);
|
|
switch (old_state) {
|
|
case State::Startup:
|
|
RUY_DCHECK_EQ(new_state, State::Ready);
|
|
break;
|
|
case State::Ready:
|
|
RUY_DCHECK(new_state == State::HasWork ||
|
|
new_state == State::ExitAsSoonAsPossible);
|
|
break;
|
|
case State::HasWork:
|
|
RUY_DCHECK(new_state == State::Ready ||
|
|
new_state == State::ExitAsSoonAsPossible);
|
|
break;
|
|
default:
|
|
abort();
|
|
}
|
|
switch (new_state) {
|
|
case State::Ready:
|
|
if (task_) {
|
|
// Doing work is part of reverting to 'ready' state.
|
|
task_->Run();
|
|
task_ = nullptr;
|
|
}
|
|
break;
|
|
case State::HasWork:
|
|
RUY_DCHECK(!task_);
|
|
task_ = task;
|
|
break;
|
|
default:
|
|
break;
|
|
}
|
|
state_.store(new_state, std::memory_order_relaxed);
|
|
state_cond_.notify_all();
|
|
state_mutex_.unlock();
|
|
if (new_state == State::Ready) {
|
|
counter_to_decrement_when_ready_->DecrementCount();
|
|
}
|
|
}
|
|
|
|
static void ThreadFunc(Thread* arg) { arg->ThreadFuncImpl(); }
|
|
|
|
// Called by the master thead to give this thread work to do.
|
|
void StartWork(Task* task) { ChangeState(State::HasWork, task); }
|
|
|
|
private:
|
|
// Thread entry point.
|
|
void ThreadFuncImpl() {
|
|
RUY_TRACE_SCOPE_NAME("Ruy worker thread function");
|
|
ChangeState(State::Ready);
|
|
|
|
// Thread main loop
|
|
while (true) {
|
|
RUY_TRACE_SCOPE_NAME("Ruy worker thread loop iteration");
|
|
// In the 'Ready' state, we have nothing to do but to wait until
|
|
// we switch to another state.
|
|
const auto& condition = [this]() {
|
|
return state_.load(std::memory_order_acquire) != State::Ready;
|
|
};
|
|
RUY_TRACE_INFO(THREAD_FUNC_IMPL_WAITING);
|
|
Wait(condition, spin_duration_, &state_cond_, &state_mutex_);
|
|
|
|
// Act on new state.
|
|
switch (state_.load(std::memory_order_acquire)) {
|
|
case State::HasWork: {
|
|
RUY_TRACE_SCOPE_NAME("Worker thread task");
|
|
// Got work to do! So do it, and then revert to 'Ready' state.
|
|
ChangeState(State::Ready);
|
|
break;
|
|
}
|
|
case State::ExitAsSoonAsPossible:
|
|
return;
|
|
default:
|
|
abort();
|
|
}
|
|
}
|
|
}
|
|
|
|
// The underlying thread.
|
|
std::unique_ptr<std::thread> thread_;
|
|
|
|
// The task to be worked on.
|
|
Task* task_;
|
|
|
|
// The condition variable and mutex guarding state changes.
|
|
std::condition_variable state_cond_;
|
|
std::mutex state_mutex_;
|
|
|
|
// The state enum tells if we're currently working, waiting for work, etc.
|
|
// Its concurrent accesses by the thread and main threads are guarded by
|
|
// state_mutex_, and can thus use memory_order_relaxed. This still needs
|
|
// to be a std::atomic because we use WaitForVariableChange.
|
|
std::atomic<State> state_;
|
|
|
|
// pointer to the master's thread BlockingCounter object, to notify the
|
|
// master thread of when this thread switches to the 'Ready' state.
|
|
BlockingCounter* const counter_to_decrement_when_ready_;
|
|
|
|
// See ThreadPool::spin_duration_.
|
|
const Duration spin_duration_;
|
|
};
|
|
|
|
void ThreadPool::ExecuteImpl(int task_count, int stride, Task* tasks) {
|
|
RUY_TRACE_SCOPE_NAME("ThreadPool::Execute");
|
|
RUY_DCHECK_GE(task_count, 1);
|
|
|
|
// Case of 1 thread: just run the single task on the current thread.
|
|
if (task_count == 1) {
|
|
(tasks + 0)->Run();
|
|
return;
|
|
}
|
|
|
|
// Task #0 will be run on the current thread.
|
|
CreateThreads(task_count - 1);
|
|
counter_to_decrement_when_ready_.Reset(task_count - 1);
|
|
for (int i = 1; i < task_count; i++) {
|
|
RUY_TRACE_INFO(THREADPOOL_EXECUTE_STARTING_TASK);
|
|
auto task_address = reinterpret_cast<std::uintptr_t>(tasks) + i * stride;
|
|
threads_[i - 1]->StartWork(reinterpret_cast<Task*>(task_address));
|
|
}
|
|
|
|
RUY_TRACE_INFO(THREADPOOL_EXECUTE_STARTING_TASK_ZERO_ON_CUR_THREAD);
|
|
// Execute task #0 immediately on the current thread.
|
|
(tasks + 0)->Run();
|
|
|
|
RUY_TRACE_INFO(THREADPOOL_EXECUTE_WAITING_FOR_THREADS);
|
|
// Wait for the threads submitted above to finish.
|
|
counter_to_decrement_when_ready_.Wait(spin_duration_);
|
|
}
|
|
|
|
// Ensures that the pool has at least the given count of threads.
|
|
// If any new thread has to be created, this function waits for it to
|
|
// be ready.
|
|
void ThreadPool::CreateThreads(int threads_count) {
|
|
RUY_DCHECK_GE(threads_count, 0);
|
|
unsigned int unsigned_threads_count = threads_count;
|
|
if (threads_.size() >= unsigned_threads_count) {
|
|
return;
|
|
}
|
|
counter_to_decrement_when_ready_.Reset(threads_count - threads_.size());
|
|
while (threads_.size() < unsigned_threads_count) {
|
|
threads_.push_back(
|
|
new Thread(&counter_to_decrement_when_ready_, spin_duration_));
|
|
}
|
|
counter_to_decrement_when_ready_.Wait(spin_duration_);
|
|
}
|
|
|
|
// Destroys every worker thread owned by the pool, in creation order.
// Deleting a Thread tells its worker to exit and joins it, so this returns
// only after all workers have terminated.
ThreadPool::~ThreadPool() {
  for (auto* const worker : threads_) {
    delete worker;
  }
}
|
|
|
|
} // end namespace ruy
|