Caffe2 - C++ API
A deep learning, cross platform ML framework
net_async_scheduling.cc
1 #include "caffe2/core/net_async_scheduling.h"
2 
3 CAFFE2_DEFINE_bool(
4  caffe2_net_async_always_schedule_child,
5  false,
6  "Always schedule child chains from parent chain");
7 
8 namespace caffe2 {
9 
10 AsyncSchedulingNet::AsyncSchedulingNet(
11  const std::shared_ptr<const NetDef>& net_def,
12  Workspace* ws)
13  : AsyncNetBase(net_def, ws), running_(false) {
14  reset();
15 }
16 
17 void AsyncSchedulingNet::reset() {
18  processed_tasks_num_ = 0;
19  cleanup_ = false;
20  success_ = true;
21 
22  for (auto task_id = 0; task_id < tasksNum(); ++task_id) {
23  auto& task_ops = chains_[task_id];
24  auto& task_op_node = operator_nodes_[task_ops.front()];
25  task_op_node.runtime_parent_count_ = parents(task_id).size();
26  }
27  exception_messages_.clear();
28 }
29 
30 void AsyncSchedulingNet::Wait() {
31  std::unique_lock<std::mutex> lock(running_mutex_);
32  while (running_) {
33  running_cv_.wait(lock);
34  }
35 }
36 
37 void AsyncSchedulingNet::schedule(int task_id) {
38  const auto& device_option = event(task_id).GetDeviceOption();
39  pool(device_option)->run([this, task_id]() {
40  if (success_) {
41  int stream_id = stream(task_id);
42  asyncWait(task_id, stream_id, parents(task_id));
43  try {
44  run(task_id, stream_id);
45  } catch (const std::exception& e) {
46  std::unique_lock<std::mutex> lock(exception_mutex_);
47  exception_messages_.push_back(e.what());
48  success_ = false;
49  }
50  }
51 
52  auto task_count = ++processed_tasks_num_;
53 
54  for (auto child_id : children(task_id)) {
55  int parent_count = updateParentCount(child_id);
56  if (parent_count == 0) {
57  if (cleanup_ || FLAGS_caffe2_net_async_always_schedule_child ||
58  canSchedule(child_id)) {
59  schedule(child_id);
60  } else {
61  const auto& device_option = event(child_id).GetDeviceOption();
62  pool(device_option)
63  ->run(std::bind(
64  &AsyncSchedulingNet::pollAndSchedule, this, child_id));
65  }
66  }
67  }
68 
69  if (success_) {
70  if (task_count == tasksNum()) {
71  // All tasks are finished, polling thread is sleeping;
72  // only one thread enters here
73  finalizeEvents();
74  finishRun();
75  return;
76  }
77  } else {
78  // Before setting running_ to false and notifying waiters we need to
79  // 1. Ensure that only one thread does the cleanup
80  // 2. Ensure that all other pending tasks in workers and polling threads
81  // are finished and
82  // 3. Ensure that all tasks that were not scheduled have their events set
83  {
84  std::unique_lock<std::mutex> cleanup_lock(cleanup_mutex_);
85  if (cleanup_) {
86  return;
87  }
88  cleanup_ = true;
89  }
90 
91  // Errors are not recoverable and happen in exceptional cases,
92  // ok to busy wait
93  while (processed_tasks_num_ != tasksNum()) {
94  }
95 
96  // Make sure all events are set, wait for scheduled events
97  finalizeEvents();
98 
99  // Notify observers and waiters
100  finishRun();
101  }
102  });
103 }
104 
105 void AsyncSchedulingNet::pollAndSchedule(int task_id) {
106  if (canSchedule(task_id) || cleanup_) {
107  // force schedule the rest of the tasks if cleanup is started
108  schedule(task_id);
109  } else {
110  const auto& device_option = event(task_id).GetDeviceOption();
111  pool(device_option)
112  ->run(std::bind(&AsyncSchedulingNet::pollAndSchedule, this, task_id));
113  }
114 }
115 
116 int AsyncSchedulingNet::updateParentCount(int child_id) {
117  auto& child_ops = chains_[child_id];
118  auto& child_node = operator_nodes_[child_ops.front()];
119  int parent_count = --child_node.runtime_parent_count_;
120  CAFFE_ENFORCE_GE(parent_count, 0);
121  return parent_count;
122 }
123 
124 void AsyncSchedulingNet::finishRun() {
125  // notify observers and waiters
126  StopAllObservers();
127  running_ = false;
128  running_cv_.notify_all();
129 }
130 
131 bool AsyncSchedulingNet::DoRunAsync() {
132  std::unique_lock<std::mutex> lock(running_mutex_);
133  CAFFE_ENFORCE(!running_, "Concurrent RunAsync calls");
134  running_ = true;
135  reset();
136 
137  StartAllObservers();
138 
139  for (auto task_id = 0; task_id < tasksNum(); ++task_id) {
140  if (parents(task_id).empty()) {
141  schedule(task_id);
142  }
143  }
144 
145  return true;
146 }
147 
148 AsyncSchedulingNet::~AsyncSchedulingNet() {}
149 
150 REGISTER_NET(async_scheduling, AsyncSchedulingNet);
151 
152 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...