1 #ifndef CAFFE2_UTILS_THREAD_POOL_H_ 2 #define CAFFE2_UTILS_THREAD_POOL_H_ 4 #include <condition_variable> 11 #include "caffe2/core/numa.h" 17 struct task_element_t {
19 const std::function<void()> no_id;
20 const std::function<void(std::size_t)> with_id;
22 explicit task_element_t(
const std::function<
void()>& f)
23 : run_with_id(
false), no_id(f), with_id(
nullptr) {}
24 explicit task_element_t(
const std::function<
void(std::size_t)>& f)
25 : run_with_id(
true), no_id(
nullptr), with_id(f) {}
28 std::queue<task_element_t> tasks_;
29 std::vector<std::thread> threads_;
31 std::condition_variable condition_;
32 std::condition_variable completed_;
35 std::size_t available_;
40 explicit TaskThreadPool(std::size_t pool_size,
int numa_node_id = -1)
41 : threads_(pool_size),
44 available_(pool_size),
46 numa_node_id_(numa_node_id) {
47 for (std::size_t i = 0; i < pool_size; ++i) {
48 threads_[i] = std::thread(std::bind(&TaskThreadPool::main_loop,
this, i));
55 std::unique_lock<std::mutex> lock(mutex_);
57 condition_.notify_all();
61 for (
auto& t : threads_) {
64 }
catch (
const std::exception&) {
69 template <
typename Task>
71 std::unique_lock<std::mutex> lock(mutex_);
75 tasks_.push(task_element_t(
static_cast<std::function<void()>
>(task)));
77 condition_.notify_one();
80 void run(
const std::function<
void()>& func) {
84 template <
typename Task>
85 void runTaskWithID(
Task task) {
86 std::unique_lock<std::mutex> lock(mutex_);
91 task_element_t(
static_cast<std::function<void(std::size_t)>
>(task)));
93 condition_.notify_one();
98 std::unique_lock<std::mutex> lock(mutex_);
100 completed_.wait(lock);
106 void main_loop(std::size_t index) {
107 NUMABind(numa_node_id_);
112 std::unique_lock<std::mutex> lock(mutex_);
113 while (tasks_.empty() && running_) {
114 condition_.wait(lock);
127 auto tasks = tasks_.front();
136 if (tasks.run_with_id) {
137 tasks.with_id(index);
141 }
catch (
const std::exception&) {
150 if (tasks_.empty() && available_ == total_) {
152 completed_.notify_one();
161 #endif // CAFFE2_UTILS_THREAD_POOL_H_ void runTask(Task task)
Add task to the thread pool if a thread is currently available.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
void waitWorkComplete()
Wait for queue to be empty.