source/static_thread_pool.cpp (116 lines of code) (raw):
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License Version 2.0 with LLVM Exceptions
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* https://llvm.org/LICENSE.txt
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <unifex/static_thread_pool.hpp>
namespace unifex {
namespace _static_thread_pool {
context::context()
: context(std::thread::hardware_concurrency()) {}
// Starts `threadCount` worker threads; worker number `n` executes run(n).
// If starting any worker throws, every worker that did start is asked to stop
// and joined before the exception is rethrown, so no running std::thread is
// left behind by a failed construction.
context::context(std::uint32_t threadCount)
  : threadCount_(threadCount)
  , threadStates_(threadCount)
  , nextThread_(0) {
  UNIFEX_ASSERT(threadCount > 0);
  // Reserve up front so the loop below never has to grow (reallocate) threads_.
  threads_.reserve(threadCount);
  UNIFEX_TRY {
    std::uint32_t started = 0;
    while (started != threadCount) {
      threads_.emplace_back([this, started] { run(started); });
      ++started;
    }
  } UNIFEX_CATCH (...) {
    // Unwind: stop and join only the workers that were actually launched.
    request_stop();
    join();
    UNIFEX_RETHROW();
  }
}
// Shuts the pool down: signals every worker to stop, then waits for all of
// them to exit. A worker keeps draining its own queue before it acts on the
// stop request (thread_state::pop() only reports a stop once its queue is empty).
context::~context() {
  this->request_stop();
  this->join();
}
void context::request_stop() noexcept {
for (auto& state : threadStates_) {
state.request_stop();
}
}
void context::run(std::uint32_t index) noexcept {
while (true) {
task_base* task = nullptr;
for (std::uint32_t i = 0; i < threadCount_; ++i) {
auto queueIndex = (index + i) < threadCount_
? (index + i)
: (index + i - threadCount_);
auto& state = threadStates_[queueIndex];
task = state.try_pop();
if (task != nullptr) {
break;
}
}
if (task == nullptr) {
task = threadStates_[index].pop();
if (task == nullptr) {
// request_stop() was called.
return;
}
}
task->execute(task);
}
}
// Waits for every started worker thread to finish and then discards them.
// A second call is a no-op because threads_ is left empty.
void context::join() noexcept {
  for (auto& worker : threads_) {
    worker.join();
  }
  threads_.clear();
}
// Hands `task` to one of the worker queues.
//
// A starting queue is picked round-robin via nextThread_ to spread load.
// Starting there (and wrapping around), each queue is tried once with a
// non-blocking try_push, which fails only if that queue's mutex is currently
// held. If every queue is contended, the task is pushed onto the starting
// queue, waiting for its lock.
//
// Indices are bounded by threadCount_, the same bound run() uses; it is fixed
// at construction and equals the number of entries in threadStates_. The
// previous bound, threads_.size(), is grown by the constructor and emptied by
// join(), so it is not a stable bound. After join() it is 0, and the modulo
// below would divide by zero.
void context::enqueue(task_base* task) noexcept {
  const std::uint32_t queueCount = threadCount_;
  const std::uint32_t startIndex =
      nextThread_.fetch_add(1, std::memory_order_relaxed) % queueCount;

  // First try to enqueue to one of the queues without blocking.
  for (std::uint32_t i = 0; i < queueCount; ++i) {
    std::uint32_t index = startIndex + i;
    if (index >= queueCount) {
      index -= queueCount;
    }
    if (threadStates_[index].try_push(task)) {
      return;
    }
  }

  // Otherwise, do a blocking enqueue on the selected queue.
  threadStates_[startIndex].push(task);
}
// Non-blocking pop. Returns nullptr if this queue's mutex is currently held
// elsewhere or the queue is empty; otherwise removes and returns the front task.
task_base* context::thread_state::try_pop() {
  std::unique_lock guard{mut_, std::try_to_lock};
  if (!guard.owns_lock()) {
    return nullptr;
  }
  if (queue_.empty()) {
    return nullptr;
  }
  return queue_.pop_front();
}
// Blocking pop: waits on cv_ until a task is queued or a stop is requested.
// Queued tasks take priority over the stop flag, so nullptr is returned only
// when the queue is empty and request_stop() has been called.
task_base* context::thread_state::pop() {
  std::unique_lock guard{mut_};
  for (;;) {
    if (!queue_.empty()) {
      return queue_.pop_front();
    }
    if (stopRequested_) {
      return nullptr;
    }
    cv_.wait(guard);
  }
}
// Non-blocking push. Returns false, without enqueuing, if this queue's mutex
// is currently held elsewhere. Otherwise appends `task` and returns true.
// A waiter on cv_ is notified only when the queue goes from empty to
// non-empty, because pop() only waits while the queue is empty.
bool context::thread_state::try_push(task_base* task) {
  std::unique_lock guard{mut_, std::try_to_lock};
  if (!guard.owns_lock()) {
    return false;
  }
  const bool needsWakeup = queue_.empty();
  queue_.push_back(task);
  if (needsWakeup) {
    cv_.notify_one();
  }
  return true;
}
// Blocking push: waits for this queue's mutex, appends `task`, and wakes a
// waiter on cv_ if the queue was empty beforehand.
void context::thread_state::push(task_base* task) {
  std::unique_lock guard{mut_};
  const bool needsWakeup = queue_.empty();
  queue_.push_back(task);
  if (needsWakeup) {
    cv_.notify_one();
  }
}
// Marks this queue as stopping and wakes its waiter, if any. The flag is set
// under mut_, so a pop() that is about to wait cannot miss it.
void context::thread_state::request_stop() {
  std::unique_lock guard{mut_};
  stopRequested_ = true;
  cv_.notify_one();
}
} // namespace _static_thread_pool
} // namespace unifex