// include/SpartaWorkQueue.h
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <algorithm>
#include <atomic>
#include <boost/optional/optional.hpp>
#include <boost/thread/thread.hpp>
#include <cassert>
#include <chrono>
#include <condition_variable>
#include <mutex>
#include <numeric>
#include <queue>
#include <random>
#include <utility>
#include "Arity.h"
namespace sparta {
namespace parallel {
/**
* Sparta uses the number of physical cores.
*/
static inline unsigned int default_num_threads() {
unsigned int threads = boost::thread::physical_concurrency();
return std::max(1u, threads);
}
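// Note: boost::thread::physical_concurrency() counts physical cores, not
// hyperthreads, and returns 0 when the count cannot be determined; the
// std::max above clamps that case to a single thread.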
} // namespace parallel
namespace workqueue_impl {
/**
 * Creates a random ordering of which threads to visit when stealing work.
 * This prevents some queues from being drained prematurely (as would happen
 * if every thread targeted thread 0 first, for example).
 *
 * Each thread should empty its own queue first, so we explicitly move the
 * thread's own index to the front of the list.
 */
inline std::vector<unsigned int> create_permutation(unsigned int num,
unsigned int thread_idx) {
std::vector<unsigned int> attempts(num);
std::iota(attempts.begin(), attempts.end(), 0);
auto seed = std::chrono::system_clock::now().time_since_epoch().count();
std::shuffle(
attempts.begin(), attempts.end(), std::default_random_engine(seed));
std::iter_swap(attempts.begin(),
std::find(attempts.begin(), attempts.end(), thread_idx));
return attempts;
}
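// Illustrative example (the actual ordering depends on the clock-based
// seed): create_permutation(4, 2) could return {2, 0, 3, 1}, i.e. thread 2
// visits its own queue first and then the others in random order.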
class Semaphore {
public:
explicit Semaphore(size_t initial = 0u) : m_count(initial) {}
inline void give(size_t n = 1u) {
std::unique_lock<std::mutex> lock(m_mtx);
m_count += n;
if (n == 1) {
m_cv.notify_one();
} else {
m_cv.notify_all(); // A bit suboptimal, but easier than precise counting.
}
}
inline void take() {
std::unique_lock<std::mutex> lock(m_mtx);
while (m_count == 0) {
m_cv.wait(lock);
}
--m_count;
}
inline void take_all() {
std::unique_lock<std::mutex> lock(m_mtx);
m_count = 0;
}
private:
std::mutex m_mtx;
std::condition_variable m_cv;
size_t m_count;
};
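// A minimal usage sketch for Semaphore (hypothetical, not part of this
// file): take() blocks until a matching give() arrives.
//
//   workqueue_impl::Semaphore sem;
//   std::thread consumer([&sem] { sem.take(); }); // blocks until give()
//   sem.give();
//   consumer.join();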
struct StateCounters {
std::atomic_uint num_non_empty;
std::atomic_uint num_running;
const unsigned int num_all;
// Mutexes aren't movable.
std::unique_ptr<Semaphore> waiter;
explicit StateCounters(unsigned int num)
: num_non_empty(0),
num_running(0),
num_all(num),
waiter(new Semaphore(0)) {}
StateCounters(StateCounters&& other)
: num_non_empty(other.num_non_empty.load()),
num_running(other.num_running.load()),
num_all(other.num_all),
waiter(std::move(other.waiter)) {}
};
} // namespace workqueue_impl
template <class Input, typename Executor>
class SpartaWorkQueue;
template <class Input>
class SpartaWorkerState final {
public:
SpartaWorkerState(size_t id, workqueue_impl::StateCounters* sc, bool can_push)
: m_id(id), m_state_counters(sc), m_can_push_task(can_push) {}
/*
 * Add more items to the queue of the currently-running worker. When a
 * SpartaWorkQueue is running, this should be used instead of
 * SpartaWorkQueue::add_item(), as the latter is not thread-safe. Requires
 * the queue to have been constructed with push_tasks_while_running = true.
 */
void push_task(Input task) {
assert(m_can_push_task);
std::lock_guard<std::mutex> guard(m_queue_mtx);
if (m_queue.empty()) {
++m_state_counters->num_non_empty;
}
if (m_state_counters->num_running < m_state_counters->num_all) {
m_state_counters->waiter->give(1u); // May consider waking all.
}
m_queue.push(task);
}
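/*
 * A minimal sketch of a worker that grows its own queue (hypothetical
 * usage; requires a queue constructed with push_tasks_while_running =
 * true):
 *
 *   auto wq = sparta::work_queue<int>(
 *       [](sparta::SpartaWorkerState<int>* state, int n) {
 *         if (n > 0) {
 *           state->push_task(n - 1); // safe while run_all() is running
 *         }
 *       },
 *       sparta::parallel::default_num_threads(),
 *       true); // push_tasks_while_running
 *   wq.add_item(10);
 *   wq.run_all();
 */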
size_t worker_id() const { return m_id; }
void set_running(bool running) {
if (m_running && !running) {
assert(m_state_counters->num_running > 0);
--m_state_counters->num_running;
} else if (!m_running && running) {
++m_state_counters->num_running;
}
m_running = running;
}
private:
boost::optional<Input> pop_task(SpartaWorkerState<Input>* other) {
std::lock_guard<std::mutex> guard(m_queue_mtx);
if (!m_queue.empty()) {
other->set_running(true);
if (m_queue.size() == 1) {
assert(m_state_counters->num_non_empty > 0);
--m_state_counters->num_non_empty;
}
auto task = std::move(m_queue.front());
m_queue.pop();
return task;
}
return boost::none;
}
size_t m_id;
bool m_running{false};
std::queue<Input> m_queue;
std::mutex m_queue_mtx;
workqueue_impl::StateCounters* m_state_counters;
const bool m_can_push_task{false};
template <class, typename>
friend class SpartaWorkQueue;
};
template <class Input, typename Executor>
class SpartaWorkQueue {
private:
// Using templates for Executor to avoid the performance overhead of
// std::function
Executor m_executor;
std::vector<std::unique_ptr<SpartaWorkerState<Input>>> m_states;
const size_t m_num_threads{1};
size_t m_insert_idx{0};
workqueue_impl::StateCounters m_state_counters;
const bool m_can_push_task{false};
void consume(SpartaWorkerState<Input>* state, Input task) {
m_executor(state, task);
}
public:
SpartaWorkQueue(Executor,
unsigned int num_threads = parallel::default_num_threads(),
// push_tasks_while_running:
// * When this flag is true, all threads stay alive until the
//   last task is finished. Useful when workers add more work
//   to the queue via SpartaWorkerState::push_task.
// * When this flag is false, threads may exit as soon as there
//   is no more work (to avoid preempting a thread that has
//   useful work).
bool push_tasks_while_running = false);
// copies are not allowed
SpartaWorkQueue(const SpartaWorkQueue&) = delete;
// moves are allowed
SpartaWorkQueue(SpartaWorkQueue&&) = default;
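/* Add an item, distributing items round-robin over the workers' queues.
 * Not thread-safe; only call this before run_all(). */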
void add_item(Input task);
/* Add an item on the queue of the given worker. */
void add_item(Input task, size_t worker_id);
/**
 * Spawn the worker threads and process all queued items. This method
 * blocks until all work is complete.
 */
void run_all();
template <class>
friend class SpartaWorkerState;
};
template <class Input, typename Executor>
SpartaWorkQueue<Input, Executor>::SpartaWorkQueue(Executor executor,
unsigned int num_threads,
bool push_tasks_while_running)
: m_executor(executor),
m_num_threads(num_threads),
m_state_counters(num_threads),
m_can_push_task(push_tasks_while_running) {
assert(num_threads >= 1);
for (unsigned int i = 0; i < m_num_threads; ++i) {
m_states.emplace_back(std::make_unique<SpartaWorkerState<Input>>(
i, &m_state_counters, m_can_push_task));
}
}
template <class Input, typename Executor>
void SpartaWorkQueue<Input, Executor>::add_item(Input task) {
m_insert_idx = (m_insert_idx + 1) % m_num_threads;
assert(m_insert_idx < m_states.size());
m_states[m_insert_idx]->m_queue.push(task);
}
template <class Input, typename Executor>
void SpartaWorkQueue<Input, Executor>::add_item(Input task, size_t worker_id) {
assert(worker_id < m_states.size());
m_states[worker_id]->m_queue.push(task);
}
/*
 * Each worker thread pulls from its own queue first; once that is empty,
 * it visits the other queues in random order and tries to steal work.
 */
template <class Input, typename Executor>
void SpartaWorkQueue<Input, Executor>::run_all() {
m_state_counters.num_non_empty = 0;
m_state_counters.num_running = 0;
m_state_counters.waiter->take_all();
auto worker = [&](SpartaWorkerState<Input>* state, size_t state_idx) {
auto attempts =
workqueue_impl::create_permutation(m_num_threads, state_idx);
while (true) {
auto have_task = false;
for (auto idx : attempts) {
auto other_state = m_states[idx].get();
auto task = other_state->pop_task(state);
if (task) {
have_task = true;
consume(state, *task);
break;
}
}
if (have_task) {
continue;
}
state->set_running(false);
if (!m_can_push_task) {
// New tasks can't be added. We don't need to wait for the currently
// running jobs to finish.
return;
}
// Let the thread quit if no thread is running a task and there is
// no task left in any queue.
if (m_state_counters.num_running == 0 &&
m_state_counters.num_non_empty == 0) {
// Wake up everyone who might be waiting, so they can quit.
m_state_counters.waiter->give(m_state_counters.num_all);
return;
}
m_state_counters.waiter->take(); // Wait for work.
}
};
for (size_t i = 0; i < m_num_threads; ++i) {
if (!m_states[i]->m_queue.empty()) {
++m_state_counters.num_non_empty;
}
}
std::vector<boost::thread> all_threads;
all_threads.reserve(m_num_threads);
for (size_t i = 0; i < m_num_threads; ++i) {
boost::thread::attributes attrs;
attrs.set_stack_size(8 * 1024 * 1024);
all_threads.emplace_back(attrs,
std::bind<void>(worker, m_states[i].get(), i));
}
for (auto& thread : all_threads) {
thread.join();
}
for (size_t i = 0; i < m_num_threads; ++i) {
assert(m_states[i]->m_queue.empty());
}
}
namespace workqueue_impl {
// Helper classes so the type of Executor can be inferred
template <typename Input, typename Fn>
struct NoStateWorkQueueHelper {
Fn fn;
void operator()(SpartaWorkerState<Input>*, Input a) { fn(a); }
};
template <typename Input, typename Fn>
struct WithStateWorkQueueHelper {
Fn fn;
void operator()(SpartaWorkerState<Input>* state, Input a) { fn(state, a); }
};
} // namespace workqueue_impl
// These functions are the most convenient way to create a SpartaWorkQueue:
// the appropriate Executor type is inferred from the arity of the given
// function.
template <class Input,
typename Fn,
typename std::enable_if<Arity<Fn>::value == 1, int>::type = 0>
SpartaWorkQueue<Input, workqueue_impl::NoStateWorkQueueHelper<Input, Fn>>
work_queue(const Fn& fn,
unsigned int num_threads = parallel::default_num_threads(),
bool push_tasks_while_running = false) {
return SpartaWorkQueue<Input,
workqueue_impl::NoStateWorkQueueHelper<Input, Fn>>(
workqueue_impl::NoStateWorkQueueHelper<Input, Fn>{fn},
num_threads,
push_tasks_while_running);
}
template <class Input,
typename Fn,
typename std::enable_if<Arity<Fn>::value == 2, int>::type = 0>
SpartaWorkQueue<Input, workqueue_impl::WithStateWorkQueueHelper<Input, Fn>>
work_queue(const Fn& fn,
unsigned int num_threads = parallel::default_num_threads(),
bool push_tasks_while_running = false) {
return SpartaWorkQueue<Input,
workqueue_impl::WithStateWorkQueueHelper<Input, Fn>>(
workqueue_impl::WithStateWorkQueueHelper<Input, Fn>{fn},
num_threads,
push_tasks_while_running);
}
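// A minimal end-to-end sketch (illustrative; do_work is a placeholder):
//
//   auto wq = sparta::work_queue<int>([](int x) { do_work(x); });
//   for (int i = 0; i < 100; ++i) {
//     wq.add_item(i);
//   }
//   wq.run_all(); // blocks until all 100 items have been processed
//
// A binary function instead receives a SpartaWorkerState<Input>*, which can
// be used to push follow-up tasks while the queue is running (see
// SpartaWorkerState::push_task).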
} // namespace sparta