Caffe2 - C++ API
A deep learning, cross-platform ML framework
WorkersPool.h
#pragma once

#include "caffe2/core/common.h"
#include "caffe2/core/logging.h"
#include <atomic>
#include <condition_variable>
#include <cstdlib>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>

#if defined(_MSC_VER)
#include <intrin.h>
#endif

#if defined(__ANDROID__)
#include <malloc.h> // for memalign()
#endif

namespace caffe2 {

// Uses code derived from gemmlowp,
// https://github.com/google/gemmlowp/blob/6c91e1ed0c2eff1182d804310b92911fe9c18019/internal/multi_thread_gemm.h
// Changes:
// - Allocation-free Execute().
// - Use RAII where possible.
// - Run the first task on the main thread (since that is the largest task).
// - Removed the custom allocator.
// - Removed some ifdefs.
// - Cache-line align Worker.
// - Use std::atomic instead of volatile and custom barriers.
// - Use std::mutex/std::condition_variable instead of raw pthreads.

constexpr size_t kGEMMLOWPCacheLineSize = 64;

template <typename T>
struct AllocAligned {
  // Allocate a T aligned to a kGEMMLOWPCacheLineSize-byte boundary.
  template <typename... Args>
  static T* alloc(Args&&... args) {
    void* p = nullptr;

#if defined(__ANDROID__)
    p = memalign(kGEMMLOWPCacheLineSize, sizeof(T));
#elif defined(_MSC_VER)
    p = _aligned_malloc(sizeof(T), kGEMMLOWPCacheLineSize);
#else
    posix_memalign((void**)&p, kGEMMLOWPCacheLineSize, sizeof(T));
#endif

    if (p) {
      return new (p) T(std::forward<Args>(args)...);
    }

    return nullptr;
  }

  // Free a T previously allocated via AllocAligned<T>::alloc().
  static void release(T* p) {
    if (p) {
      p->~T();
#if defined(_MSC_VER)
      // Memory from _aligned_malloc() must be freed with _aligned_free().
      _aligned_free((void*)p);
#else
      free((void*)p);
#endif
    }
  }
};

// Deleter object for unique_ptr of an aligned object.
template <typename T>
struct AlignedDeleter {
  void operator()(T* p) const { AllocAligned<T>::release(p); }
};

// make_unique-style factory that guarantees cache-line alignment.
template <typename T>
struct MakeAligned {
  template <typename... Args>
  static std::unique_ptr<T, AlignedDeleter<T>> make(Args&&... args) {
    return std::unique_ptr<T, AlignedDeleter<T>>(
        AllocAligned<T>::alloc(std::forward<Args>(args)...));
  }
};

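// Illustrative usage (not part of the original header): a minimal sketch of
// how MakeAligned pairs with AlignedDeleter so release() runs automatically.
// `Foo` is a hypothetical type used only for this example.
//
//   struct Foo { int x; explicit Foo(int x_) : x(x_) {} };
//
//   std::unique_ptr<Foo, AlignedDeleter<Foo>> foo = MakeAligned<Foo>::make(42);
//   // foo.get() is aligned to kGEMMLOWPCacheLineSize (64 bytes), and
//   // AllocAligned<Foo>::release() is invoked when foo goes out of scope.
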
const int kMaxBusyWaitNOPs = 32 * 1000 * 1000;

#if defined(_MSC_VER)
#define GEMMLOWP_NOP __nop();
#else
#define GEMMLOWP_NOP "nop\n"
#endif

#define GEMMLOWP_STRING_CONCAT_4(X) X X X X
#define GEMMLOWP_NOP4 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP)
#define GEMMLOWP_NOP16 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP4)
#define GEMMLOWP_NOP64 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP16)

inline int Do256NOPs() {
#if defined(_MSC_VER)
  GEMMLOWP_NOP64;
#else
  asm volatile(GEMMLOWP_NOP64);
#endif
  return 64;
}

#undef GEMMLOWP_STRING_CONCAT_4
#undef GEMMLOWP_NOP64
#undef GEMMLOWP_NOP16
#undef GEMMLOWP_NOP4
#undef GEMMLOWP_NOP

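// Note on the macros above (explanatory, not part of the original header):
// GEMMLOWP_STRING_CONCAT_4(X) pastes four copies of X, so GEMMLOWP_NOP4,
// GEMMLOWP_NOP16, and GEMMLOWP_NOP64 expand to 4, 16, and 64 `nop`
// instructions respectively. Despite its gemmlowp-inherited name,
// Do256NOPs() therefore executes 64 NOPs per call and returns 64, which is
// what the busy-wait loop below adds to its `nops` counter.
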
// Waits until *var != initial_value.
//
// Returns the new value of *var. The guarantee here is that
// the return value is different from initial_value, and that
// *var held that new value at some point during the execution
// of this function. There is no guarantee that this is still
// the value of *var when this function returns, since *var is
// not assumed to be guarded by any lock.
//
// First does some busy-waiting for a fixed number of no-op cycles,
// then falls back to passive waiting on the given condition variable,
// guarded by the given mutex.
//
// The idea of doing some initial busy-waiting is to get better and
// more consistent multithreading benefits for small GEMM sizes.
// Busy-waiting helps ensure that if we need to wake up soon after
// having started waiting, then we can wake up quickly (as opposed to,
// say, having to wait to be scheduled again by the OS). On the other
// hand, we must still eventually revert to passive waiting for longer
// waits (e.g. worker threads having finished a GEMM and waiting until
// the next GEMM), so as to avoid permanently spinning.
//
template <typename T>
T WaitForVariableChange(std::atomic<T>* var,
                        T initial_value,
                        std::condition_variable* cond,
                        std::mutex* mutex) {
  // If we are on a platform that supports it, spin for some time.
  {
    int nops = 0;
    // First, the trivial case where the variable already changed value.
    T new_value = var->load(std::memory_order_relaxed);
    if (new_value != initial_value) {
      std::atomic_thread_fence(std::memory_order_acquire);
      return new_value;
    }
    // Then try busy-waiting.
    while (nops < kMaxBusyWaitNOPs) {
      nops += Do256NOPs();
      new_value = var->load(std::memory_order_relaxed);
      if (new_value != initial_value) {
        std::atomic_thread_fence(std::memory_order_acquire);
        return new_value;
      }
    }
  }

  // Finally, do real passive waiting.
  {
    std::unique_lock<std::mutex> g(*mutex);
    T new_value = var->load(std::memory_order_relaxed);
    // Handle spurious wakeups.
    cond->wait(g, [&]() {
      new_value = var->load(std::memory_order_relaxed);
      return new_value != initial_value;
    });
    DCHECK_NE(
        static_cast<size_t>(new_value), static_cast<size_t>(initial_value));
    return new_value;
  }
}

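// Illustrative usage (not part of the original header): a minimal sketch of a
// waiter/notifier pair built on WaitForVariableChange. The names `flag`,
// `flag_cond`, and `flag_mutex` are hypothetical.
//
//   std::atomic<int> flag{0};
//   std::condition_variable flag_cond;
//   std::mutex flag_mutex;
//
//   // Waiter: spins briefly, then blocks until flag != 0.
//   int seen = WaitForVariableChange(&flag, 0, &flag_cond, &flag_mutex);
//
//   // Notifier: publish the new value, then signal under the mutex so a
//   // passively waiting thread is guaranteed to observe the change.
//   flag.store(1, std::memory_order_release);
//   {
//     std::lock_guard<std::mutex> g(flag_mutex);
//     flag_cond.notify_one();
//   }
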
// A BlockingCounter lets one thread wait for N events to occur.
// This is how the master thread waits for all the worker threads
// to have finished working.
class BlockingCounter {
 public:
  // Sets/resets the counter; initial_count is the number of
  // decrementing events that the Wait() call will be waiting for.
  void Reset(std::size_t initial_count) {
    std::lock_guard<std::mutex> g(mutex_);
    DCHECK_EQ(count_, 0);
    count_ = initial_count;
  }

  // Decrements the counter; if the counter hits zero, signals
  // the thread that was waiting for that, and returns true.
  // Otherwise (if the decremented count is still nonzero),
  // returns false.
  bool DecrementCount() {
    const auto count_value = count_.fetch_sub(1, std::memory_order_relaxed) - 1;
    DCHECK_GE(count_value, 0);
    if (count_value == 0) {
      std::lock_guard<std::mutex> g(mutex_);
      cond_.notify_one();
    }
    bool retval = count_value == 0;
    return retval;
  }

  // Waits for the N other threads (N having been set by Reset())
  // to hit the BlockingCounter.
  void Wait() {
    while (size_t count_value = count_.load(std::memory_order_relaxed)) {
      WaitForVariableChange(&count_, count_value, &cond_, &mutex_);
    }
  }

 private:
  std::condition_variable cond_;
  std::mutex mutex_;
  std::atomic<std::size_t> count_{0};
};

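// Illustrative usage (not part of the original header): a minimal sketch of
// the Reset/DecrementCount/Wait protocol. The thread count 3 is arbitrary.
//
//   BlockingCounter counter;
//   counter.Reset(3);  // Expect three completion events.
//
//   // Each of three worker threads eventually calls:
//   counter.DecrementCount();  // Returns true only on the final decrement.
//
//   // The master thread blocks until all three decrements have happened:
//   counter.Wait();
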
// A workload for a worker.
struct Task {
  Task() {}
  virtual ~Task() {}
  virtual void Run() = 0;
};

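// Illustrative subclass (not part of the original header): a minimal sketch
// of adapting a callable into a Task. `LambdaTask` is a hypothetical name,
// and <functional> is assumed to be included.
//
//   struct LambdaTask : public Task {
//     explicit LambdaTask(std::function<void()> fn) : fn_(std::move(fn)) {}
//     void Run() override { fn_(); }
//     std::function<void()> fn_;
//   };
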
// A worker thread.
class alignas(kGEMMLOWPCacheLineSize) Worker {
 public:
  enum class State : uint8_t {
    ThreadStartup, // The initial state before the thread main loop runs.
    Ready, // Is not working, has not yet received new work to do.
    HasWork, // Has work to do.
    ExitAsSoonAsPossible // Should exit at earliest convenience.
  };

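  // Legal state transitions, as enforced by the DCHECKs in ChangeState()
  // (explanatory summary, not part of the original header):
  //
  //   ThreadStartup -> Ready                 (worker thread, on startup)
  //   Ready         -> HasWork               (master, via StartWork())
  //   Ready         -> ExitAsSoonAsPossible  (master, via ~Worker())
  //   HasWork       -> Ready                 (worker, after Task::Run())
  //   HasWork       -> ExitAsSoonAsPossible  (master, via ~Worker())
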
  explicit Worker(BlockingCounter* counter_to_decrement_when_ready)
      : task_(nullptr),
        state_(State::ThreadStartup),
        counter_to_decrement_when_ready_(counter_to_decrement_when_ready) {
    thread_ = caffe2::make_unique<std::thread>([this]() { this->ThreadFunc(); });
  }

  ~Worker() {
    ChangeState(State::ExitAsSoonAsPossible);
    thread_->join();
  }

  // Changes State; may be called from either the worker thread
  // or the master thread; however, not all state transitions are legal,
  // which is guarded by assertions.
  void ChangeState(State new_state) {
    std::lock_guard<std::mutex> g(state_mutex_);
    DCHECK(new_state != state_.load(std::memory_order_relaxed));
    switch (state_.load(std::memory_order_relaxed)) {
      case State::ThreadStartup:
        DCHECK(new_state == State::Ready);
        break;
      case State::Ready:
        DCHECK(new_state == State::HasWork || new_state == State::ExitAsSoonAsPossible);
        break;
      case State::HasWork:
        DCHECK(new_state == State::Ready || new_state == State::ExitAsSoonAsPossible);
        break;
      default:
        abort();
    }
    state_.store(new_state, std::memory_order_relaxed);
    state_cond_.notify_one();
    if (new_state == State::Ready) {
      counter_to_decrement_when_ready_->DecrementCount();
    }
  }

  // Thread entry point.
  void ThreadFunc() {
    ChangeState(State::Ready);

    // Thread main loop.
    while (true) {
      // Get a state to act on.
      // In the 'Ready' state, we have nothing to do but to wait until
      // we switch to another state.
      State state_to_act_upon =
          WaitForVariableChange(&state_, State::Ready, &state_cond_, &state_mutex_);

      // We now have a state to act on, so act.
      switch (state_to_act_upon) {
        case State::HasWork:
          // Got work to do! So do it, and then revert to the 'Ready' state.
          DCHECK(task_);
          task_->Run();
          task_ = nullptr;
          ChangeState(State::Ready);
          break;
        case State::ExitAsSoonAsPossible:
          return;
        default:
          abort();
      }
    }
  }

  // pthread-style entry point.
  static void* ThreadFunc(void* arg) {
    static_cast<Worker*>(arg)->ThreadFunc();
    return nullptr;
  }

  // Called by the master thread to give this worker work to do.
  // It is only legal to call this if the worker is currently in the
  // 'Ready' state.
  void StartWork(Task* task) {
    DCHECK(!task_);
    task_ = task;
    DCHECK(state_.load(std::memory_order_acquire) == State::Ready);
    ChangeState(State::HasWork);
  }

 private:
  // The underlying thread.
  std::unique_ptr<std::thread> thread_;

  // The task to be worked on.
  // Visibility of writes to task_ is guarded by state_mutex_.
  Task* task_;

  // The condition variable and mutex guarding state changes.
  std::condition_variable state_cond_;
  std::mutex state_mutex_;

  // The state enum tells if we're currently working, waiting for work, etc.
  std::atomic<State> state_;

  // Pointer to the master thread's BlockingCounter object, used to notify
  // the master thread when this worker switches to the 'Ready' state.
  BlockingCounter* const counter_to_decrement_when_ready_;
};

class WorkersPool {
 public:
  WorkersPool() {}

  void Execute(const std::vector<std::shared_ptr<Task>>& tasks) {
    CAFFE_ENFORCE_GE(tasks.size(), 1);
    // One of the tasks will be run on the current thread.
    size_t workers_count = tasks.size() - 1;
    CreateWorkers(workers_count);
    DCHECK_LE(workers_count, workers_.size());
    counter_to_decrement_when_ready_.Reset(workers_count);
    for (size_t task = 1; task < tasks.size(); ++task) {
      workers_[task - 1]->StartWork(tasks[task].get());
    }
    // Execute the remaining workload immediately on the current thread.
    auto& task = tasks.front();
    task->Run();
    // Wait for the workers submitted above to finish.
    counter_to_decrement_when_ready_.Wait();
  }

 private:
  // Ensures that the pool has at least the given count of workers.
  // If any new worker has to be created, this function waits for it to
  // be ready.
  void CreateWorkers(std::size_t workers_count) {
    if (workers_.size() >= workers_count) {
      return;
    }
    counter_to_decrement_when_ready_.Reset(workers_count - workers_.size());
    while (workers_.size() < workers_count) {
      workers_.push_back(MakeAligned<Worker>::make(&counter_to_decrement_when_ready_));
    }
    counter_to_decrement_when_ready_.Wait();
  }

  DISABLE_COPY_AND_ASSIGN(WorkersPool);
  std::vector<std::unique_ptr<Worker, AlignedDeleter<Worker>>> workers_;
  // The BlockingCounter used to wait for the workers.
  BlockingCounter counter_to_decrement_when_ready_;
};
} // namespace caffe2
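
// Illustrative usage (not part of the original header): a minimal sketch of
// driving WorkersPool end to end. `AddRangeTask` is a hypothetical Task that
// sums a slice of an array; `x` and `n` are hypothetical inputs. Execute()
// runs tasks[0] on the calling thread and tasks[1..n-1] on pool workers,
// returning once all of them have finished.
//
//   struct AddRangeTask : public caffe2::Task {
//     AddRangeTask(const float* data, size_t n, float* out)
//         : data_(data), n_(n), out_(out) {}
//     void Run() override {
//       float acc = 0;
//       for (size_t i = 0; i < n_; ++i) acc += data_[i];
//       *out_ = acc;
//     }
//     const float* data_; size_t n_; float* out_;
//   };
//
//   caffe2::WorkersPool pool;
//   std::vector<std::shared_ptr<caffe2::Task>> tasks;
//   float partial[2];
//   tasks.push_back(std::make_shared<AddRangeTask>(x, n / 2, &partial[0]));
//   tasks.push_back(
//       std::make_shared<AddRangeTask>(x + n / 2, n - n / 2, &partial[1]));
//   pool.Execute(tasks); // Blocks until both partial sums are ready.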