Caffe2 - C++ API
A deep learning, cross-platform ML framework
thread_pool.h
1 #ifndef CAFFE2_UTILS_THREAD_POOL_H_
2 #define CAFFE2_UTILS_THREAD_POOL_H_
3 
4 #include <condition_variable>
5 #include <functional>
6 #include <mutex>
7 #include <queue>
8 #include <thread>
9 #include <utility>
10 
11 #include "caffe2/core/numa.h"
12 
13 namespace caffe2 {
14 
16  private:
struct task_element_t {
  // Each constructor fills in only one of the two callables below;
  // run_with_id records which one the worker loop should invoke.
  bool run_with_id;
  const std::function<void()> no_id;
  const std::function<void(std::size_t)> with_id;

  // Wraps a task that takes no arguments.
  explicit task_element_t(const std::function<void()>& plain_fn)
      : run_with_id{false}, no_id{plain_fn}, with_id{} {}

  // Wraps a task that receives the index of the worker thread running it.
  explicit task_element_t(const std::function<void(std::size_t)>& indexed_fn)
      : run_with_id{true}, no_id{}, with_id{indexed_fn} {}
};
27 
28  std::queue<task_element_t> tasks_;  // Pending tasks, consumed front-first by the workers.
29  std::vector<std::thread> threads_;  // One worker thread per pool slot.
30  std::mutex mutex_;  // Held whenever tasks_, running_, complete_ or available_ is modified.
31  std::condition_variable condition_;  // Signaled when a task is queued or the pool shuts down.
32  std::condition_variable completed_;  // Signaled by a worker when the queue is drained and all workers are idle.
33  bool running_;  // Set to false by the destructor to make the workers exit.
34  bool complete_;  // Cleared on every enqueue; set once the queue is empty and every worker is idle.
35  std::size_t available_;  // Number of workers not currently running a task.
36  std::size_t total_;  // Pool size (number of worker threads).
37  int numa_node_id_;  // Passed to NUMABind() by each worker; NOTE(review): default -1 presumably means "no binding" — confirm in numa.h.
38 
39  public:
40  explicit TaskThreadPool(std::size_t pool_size, int numa_node_id = -1)
41  : threads_(pool_size),
42  running_(true),
43  complete_(true),
44  available_(pool_size),
45  total_(pool_size),
46  numa_node_id_(numa_node_id) {
47  for (std::size_t i = 0; i < pool_size; ++i) {
48  threads_[i] = std::thread(std::bind(&TaskThreadPool::main_loop, this, i));
49  }
50  }
51 
52  // Set running flag to false then notify all threads.
53  ~TaskThreadPool() {
54  {
55  std::unique_lock<std::mutex> lock(mutex_);
56  running_ = false;
57  condition_.notify_all();
58  }
59 
60  try {
61  for (auto& t : threads_) {
62  t.join();
63  }
64  } catch (const std::exception&) {
65  }
66  }
67 
69  template <typename Task>
70  void runTask(Task task) {
71  std::unique_lock<std::mutex> lock(mutex_);
72 
73  // Set task and signal condition variable so that a worker thread will
74  // wake up and use the task.
75  tasks_.push(task_element_t(static_cast<std::function<void()>>(task)));
76  complete_ = false;
77  condition_.notify_one();
78  }
79 
80  void run(const std::function<void()>& func) {
81  runTask(func);
82  }
83 
84  template <typename Task>
85  void runTaskWithID(Task task) {
86  std::unique_lock<std::mutex> lock(mutex_);
87 
88  // Set task and signal condition variable so that a worker thread will
89  // wake up and use the task.
90  tasks_.push(
91  task_element_t(static_cast<std::function<void(std::size_t)>>(task)));
92  complete_ = false;
93  condition_.notify_one();
94  }
95 
98  std::unique_lock<std::mutex> lock(mutex_);
99  while (!complete_) {
100  completed_.wait(lock);
101  }
102  }
103 
104  private:
106  void main_loop(std::size_t index) {
107  NUMABind(numa_node_id_);
108 
109  while (running_) {
110  // Wait on condition variable while the task is empty and
111  // the pool is still running.
112  std::unique_lock<std::mutex> lock(mutex_);
113  while (tasks_.empty() && running_) {
114  condition_.wait(lock);
115  }
116  // If pool is no longer running, break out of loop.
117  if (!running_) {
118  break;
119  }
120 
121  // Copy task locally and remove from the queue. This is
122  // done within its own scope so that the task object is
123  // destructed immediately after running the task. This is
124  // useful in the event that the function contains
125  // shared_ptr arguments bound via bind.
126  {
127  auto tasks = tasks_.front();
128  tasks_.pop();
129  // Decrement count, indicating thread is no longer available.
130  --available_;
131 
132  lock.unlock();
133 
134  // Run the task.
135  try {
136  if (tasks.run_with_id) {
137  tasks.with_id(index);
138  } else {
139  tasks.no_id();
140  }
141  } catch (const std::exception&) {
142  }
143 
144  // Update status of empty, maybe
145  // Need to recover the lock first
146  lock.lock();
147 
148  // Increment count, indicating thread is available.
149  ++available_;
150  if (tasks_.empty() && available_ == total_) {
151  complete_ = true;
152  completed_.notify_one();
153  }
154  }
155  } // while running_
156  }
157 };
158 
159 } // namespace caffe2
160 
161 #endif // CAFFE2_UTILS_THREAD_POOL_H_
void runTask(Task task)
Add a task to the thread pool's queue; it runs as soon as a worker thread is available.
Definition: thread_pool.h:70
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
void waitWorkComplete()
Wait until the task queue is empty and all worker threads are idle.
Definition: thread_pool.h:97