source/backend/cpu/ThreadPool.cpp (173 lines of code) (raw):

//
//  ThreadPool.cpp
//  MNN
//
//  Created by MNN on 2019/06/30.
//  Copyright © 2018, Alibaba Group Holding Limited
//
#ifdef MNN_USE_THREAD_POOL
#include "backend/cpu/ThreadPool.hpp"
#include <string.h>
#include <MNN/MNNDefine.h>

// Number of task slots held by the pool. A caller reserves one slot with
// acquireWorkIndex() and hands it back with releaseWorkIndex().
#define MNN_THREAD_POOL_MAX_TASKS 2

namespace MNN {
ThreadPool* ThreadPool::gInstance = nullptr;

// Serializes creation and destruction of the singleton.
static std::mutex gInitMutex;

// Limits a caller-supplied thread count to the number of per-thread entries
// the pool actually allocated, so it can be used safely as an index bound.
static int _clampThreadNumber(int requested, int poolSize) {
    return requested > poolSize ? poolSize : requested;
}

/**
 * Creates the singleton pool on first use.
 *
 * @param number requested thread count, including the calling thread.
 * @return 1 when `number` <= 1 (no pool is created or consulted); the size of
 *         the existing pool when it is smaller than `number`; otherwise `number`.
 */
int ThreadPool::init(int number) {
    if (1 >= number) {
        return 1;
    }
    std::lock_guard<std::mutex> _l(gInitMutex);
    if (nullptr == gInstance) {
        gInstance = new ThreadPool(number);
        return number;
    }
    // A pool already exists and is never resized: report what it can provide.
    const int existing = gInstance->number();
    return existing < number ? existing : number;
}

/** Deletes the singleton (joining its workers) if it exists. */
void ThreadPool::destroy() {
    std::lock_guard<std::mutex> _l(gInitMutex);
    if (nullptr != gInstance) {
        delete gInstance;
        gInstance = nullptr;
    }
}

/**
 * Allocates the per-thread counters and per-slot flags, then starts
 * `numberThread - 1` worker threads with thread ids 1..numberThread-1.
 * Thread id 0 is the thread that calls enqueue().
 */
ThreadPool::ThreadPool(int numberThread) {
    mNumberThread = numberThread;
    mActiveCount.resize(numberThread);
    for (int i = 0; i < numberThread; ++i) {
        mActiveCount[i] = new std::atomic_int(0);
    }
    mTaskAvailable.resize(MNN_THREAD_POOL_MAX_TASKS);
    mTasks.resize(MNN_THREAD_POOL_MAX_TASKS);
    for (size_t slot = 0; slot < mTasks.size(); ++slot) {
        mTaskAvailable[slot] = true;
        // One "pending" flag per thread id for every task slot.
        for (int i = 0; i < mNumberThread; ++i) {
            mTasks[slot].second.emplace_back(new std::atomic_bool{false});
        }
    }
    for (int i = 1; i < mNumberThread; ++i) {
        int threadIndex = i;
        mWorkers.emplace_back([this, threadIndex]() {
            while (!mStop) {
                // Spin while this thread is marked active. The loop also checks
                // mStop so that destroying the pool while the active count is
                // still positive cannot leave the destructor blocked in join().
                while (!mStop && *mActiveCount[threadIndex] > 0) {
                    for (int slot = 0; slot < MNN_THREAD_POOL_MAX_TASKS; ++slot) {
                        if (*mTasks[slot].second[threadIndex]) {
                            mTasks[slot].first.first(threadIndex);
                            // Clearing the flag is what enqueueInternal() waits for.
                            *mTasks[slot].second[threadIndex] = false;
                        }
                    }
                    std::this_thread::yield();
                }
                // Not active: sleep until active() or the destructor notifies.
                std::unique_lock<std::mutex> _l(mQueueMutex);
                mCondition.wait(_l, [this, threadIndex] { return mStop || *mActiveCount[threadIndex] > 0; });
            }
        });
    }
}

/** Signals the workers to stop, joins them and frees the flags and counters. */
ThreadPool::~ThreadPool() {
    {
        std::lock_guard<std::mutex> _l(mQueueMutex);
        mStop = true;
    }
    mCondition.notify_all();
    for (auto& worker : mWorkers) {
        worker.join();
    }
    for (auto& task : mTasks) {
        for (auto flag : task.second) {
            delete flag;
        }
    }
    for (auto counter : mActiveCount) {
        delete counter;
    }
}

/**
 * Reserves the first free task slot.
 *
 * @return the slot index, or -1 when no pool exists or every slot is taken.
 */
int ThreadPool::acquireWorkIndex() {
    if (nullptr == gInstance) {
        return -1;
    }
    std::lock_guard<std::mutex> _l(gInstance->mQueueMutex);
    for (int i = 0; i < MNN_THREAD_POOL_MAX_TASKS; ++i) {
        if (gInstance->mTaskAvailable[i]) {
            gInstance->mTaskAvailable[i] = false;
            return i;
        }
    }
    return -1;
}

/** Marks slot `index` as free again; out-of-range indices are ignored. */
void ThreadPool::releaseWorkIndex(int index) {
    if (nullptr == gInstance) {
        return;
    }
    if (index < 0 || index >= MNN_THREAD_POOL_MAX_TASKS) {
        return;
    }
    std::lock_guard<std::mutex> _l(gInstance->mQueueMutex);
    gInstance->mTaskAvailable[index] = true;
}

/**
 * Increments the active counter of thread ids [0, threadNumber) and wakes the
 * sleeping workers. `threadNumber` is limited to the pool size because it is
 * used as an index bound into the per-thread counters.
 */
void ThreadPool::active(int threadNumber) {
    if (nullptr == gInstance) {
        return;
    }
    const int count = _clampThreadNumber(threadNumber, gInstance->number());
    {
        // Counters change under the same mutex the workers wait on, so a
        // worker cannot evaluate its wait predicate between the increment and
        // the notification.
        std::lock_guard<std::mutex> _l(gInstance->mQueueMutex);
        for (int i = 0; i < count; ++i) {
            (*gInstance->mActiveCount[i])++;
        }
    }
    gInstance->mCondition.notify_all();
}

/**
 * Decrements the active counter of thread ids [0, threadNumber), limited to
 * the pool size in the same way as active().
 */
void ThreadPool::deactive(int threadNumber) {
    if (nullptr == gInstance) {
        return;
    }
    const int count = _clampThreadNumber(threadNumber, gInstance->number());
    for (int i = 0; i < count; ++i) {
        (*gInstance->mActiveCount[i])--;
    }
}

/**
 * Runs `task.first(i)` for every i in [0, task.second).
 *
 * The work is executed on the calling thread when there is at most one work
 * item, when `index` is not a valid slot, or when no pool exists; otherwise it
 * is distributed through slot `index`.
 *
 * @param task         callable plus number of work items.
 * @param index        task slot previously returned by acquireWorkIndex().
 * @param threadNumber number of threads (including the caller) to use.
 */
void ThreadPool::enqueue(TASK&& task, int index, int threadNumber) {
    if (1 >= task.second || 0 > index) {
        for (int i = 0; i < task.second; ++i) {
            task.first(i);
        }
        return;
    }
    MNN_ASSERT(nullptr != gInstance);
    // The assertion may be compiled out, and `index` has no upper bound check
    // above: degrade to serial execution instead of dereferencing a null pool
    // or indexing past the slot table.
    if (nullptr == gInstance || index >= MNN_THREAD_POOL_MAX_TASKS) {
        for (int i = 0; i < task.second; ++i) {
            task.first(i);
        }
        return;
    }
    gInstance->enqueueInternal(std::move(task), index, threadNumber);
}

/**
 * Publishes `task` in slot `index`, runs thread id 0's share on the calling
 * thread and then waits until the workers have cleared their pending flags.
 */
void ThreadPool::enqueueInternal(TASK&& task, int index, int threadNumber) {
    // Pending flags exist only for thread ids below mNumberThread.
    threadNumber = _clampThreadNumber(threadNumber, mNumberThread);
    if (threadNumber <= 1) {
        for (int i = 0; i < task.second; ++i) {
            task.first(i);
        }
        return;
    }
    int workSize = task.second;
    if (workSize > threadNumber) {
        // More work items than threads: thread tId handles items
        // tId, tId + threadNumber, tId + 2 * threadNumber, ...
        // `task` is captured by reference; this function does not return until
        // every worker has finished, so the reference stays valid while used.
        mTasks[index].first = std::make_pair(
            [workSize, &task, threadNumber](int tId) {
                for (int v = tId; v < workSize; v += threadNumber) {
                    task.first(v);
                }
            },
            threadNumber);
        workSize = threadNumber;
    } else {
        mTasks[index].first = std::move(task);
    }
    // Publish the work to thread ids 1..workSize-1.
    for (int i = 1; i < workSize; ++i) {
        *mTasks[index].second[i] = true;
    }
    // The calling thread acts as thread id 0.
    mTasks[index].first.first(0);
    // Wait for every worker to clear its flag, yielding between polls.
    for (;;) {
        bool complete = true;
        for (int i = 1; i < workSize; ++i) {
            if (*mTasks[index].second[i]) {
                complete = false;
                break;
            }
        }
        if (complete) {
            break;
        }
        std::this_thread::yield();
    }
}
} // namespace MNN
#endif