Caffe2 - C++ API
A deep-learning, cross-platform ML framework
net_dag.cc
1 #include "caffe2/core/net_dag.h"
2 
3 #include <iostream>
4 #include <set>
5 #include <stack>
6 #include <unordered_map>
7 #include <unordered_set>
8 
9 #include "caffe2/core/operator.h"
10 #include "caffe2/core/static_tracepoint.h"
11 #include "caffe2/core/timer.h"
12 #include "caffe2/proto/caffe2.pb.h"
13 #include "caffe2/utils/proto_utils.h"
14 
15 CAFFE2_DEFINE_bool(
16  caffe2_disable_chaining,
17  false,
18  "Disable chaining logic (some latent multi-device issues).");
19 
20 CAFFE2_DEFINE_bool(
21  caffe2_dag_net_collect_stats,
22  false,
23  "Collect time stats in DAG net");
24 
25 namespace caffe2 {
26 
27 DAGNetBase::DAGNetBase(
28  const std::shared_ptr<const NetDef>& net_def,
29  Workspace* ws)
30  : NetBase(net_def, ws), caught_exception_yet_(false), iter_(0) {
31  // Blob creator allows us to track which operator created which blob.
32  VLOG(1) << "Constructing DAGNet " << net_def->name();
33 
34  operator_nodes_ = dag_utils::prepareOperatorNodes(net_def, ws);
35 
36  execution_chains_ =
37  (FLAGS_caffe2_disable_chaining
38  ? dag_utils::singleChains(operator_nodes_)
39  : dag_utils::computeChains(operator_nodes_));
40 
41  operators_.reserve(operator_nodes_.size());
42  for (const auto& node : operator_nodes_) {
43  operators_.push_back(node.operator_.get());
44  }
45 
46  LOG(INFO) << "Number of parallel execution chains "
47  << execution_chains_.size()
48  << " Number of operators = " << net_def->op_size();
49  // TODO: do we want to make sure that there are no loops in the
50  // dependency graph?
51 
52  // Figure out the initial frontier - this is the one we will feed into the job
53  // queue to start a run.
54  for (int idx = 0; idx < operator_nodes_.size(); ++idx) {
55  if (operator_nodes_[idx].parents_.size() == 0) {
56  initial_frontier_.push_back(idx);
57  }
58  }
59  // Finally, start the workers.
60  int num_workers = net_def->has_num_workers() ? net_def->num_workers() : 1;
61  CAFFE_ENFORCE(num_workers > 0, "Must have a positive number of workers.");
62  if (num_workers == 1) {
63  LOG(WARNING) << "Number of workers is 1: this means that all operators "
64  << "will be executed sequentially. Did you forget to set "
65  << "num_workers in the NetDef?";
66  }
67  num_workers_ = num_workers;
68 
69  for (int idx = 0; idx < operator_nodes_.size(); ++idx) {
70  if (operator_nodes_[idx].is_chain_start_) {
71  task_timers_[idx] = caffe2::make_unique<Timer>();
72  }
73  }
74  stats_.reserve(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES);
75  for (auto device_idx = 0;
76  device_idx < DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES;
77  ++device_idx) {
78  stats_.emplace_back(
79  "dag_net/stats/" + net_def->name() + "/" +
80  caffe2::DeviceTypeName(device_idx));
81  }
82 }
83 
84 DAGNetBase::~DAGNetBase() {
85  if (job_queue_) {
86  job_queue_->NoMoreJobs();
87  VLOG(1) << "Joining workers.";
88  for (auto& worker : workers_) {
89  worker.join();
90  }
91  }
92 }
93 
94 bool DAGNetBase::DoRunAsync() {
95  StartAllObservers();
96 
97  // Lock run_in_progress_ to prevent concurrent Run()s.
98  std::unique_lock<std::mutex> run_lock(run_in_progress_);
99  VLOG(1) << "Running parallel net.";
100  // First, set up job queue.
101  remaining_ops_ = operator_nodes_.size();
102  success_ = true;
103  iter_++;
104  if (!job_queue_) {
105  job_queue_ = caffe2::make_unique<SimpleQueue<int>>();
106  }
107  // Figure out number of workers to start.
108  auto num_workers_to_start = num_workers_ - workers_.size();
109 
110  // Ensure the number of workers matches the defined in case
111  // any of the previously started threads terminated.
112  for (auto i = 0; i < num_workers_to_start; i++) {
113  VLOG(1) << "Start worker #" << workers_.size();
114  workers_.push_back(std::thread(&DAGNetBase::WorkerFunction, this));
115  }
116  // Initialize the runtime parent count.
117  for (auto& node : operator_nodes_) {
118  node.runtime_parent_count_ = node.parents_.size();
119  }
120  // Kickstart the job queue.
121  for (auto& value : initial_frontier_) {
122  if (FLAGS_caffe2_dag_net_collect_stats) {
123  task_timers_[value]->Start();
124  }
125  job_queue_->Push(value);
126  }
127  // Wait for failure or completed execution.
128  {
129  std::unique_lock<std::mutex> mutex_lock(remaining_ops_mutex_);
130  for (;;) {
131  if (remaining_ops_ == 0 || !success_) {
132  break;
133  }
134  cv_.wait(mutex_lock);
135  }
136  }
137  // Wait for all workers to terminate after failure.
138  // If there is a failure, it is unlikely that the net is executed
139  // again without modifications. Therefore it's easier to let the
140  // workers terminate here, versus adding a drain state to make the
141  // sure the job queue is cleared.
142  if (!success_) {
143  for (auto& worker : workers_) {
144  worker.join();
145  }
146  workers_.clear();
147  job_queue_.reset(nullptr);
148 #ifdef CAFFE2_USE_EXCEPTION_PTR
149  if (caught_exception_) {
150  // Reset flag here in case Net gets run again
151  caught_exception_yet_ = false;
152  std::rethrow_exception(caught_exception_);
153  }
154 #endif // CAFFE2_USE_EXCEPTION_PTR
155  return success_;
156  }
157  VLOG(2) << "All ops finished running.";
158  for (const auto& op : operator_nodes_) {
159  CAFFE_ENFORCE(
160  op.runtime_parent_count_ == 0,
161  "Operator ",
162  op.operator_->debug_def().name(),
163  "(",
164  op.operator_->debug_def().type(),
165  ") has some runtime parents left.");
166  }
167 
168  StopAllObservers();
169  // If the above while loop finished, we know that the current run finished.
170  return success_;
171 }
172 
173 void DAGNetBase::HandleException(
174  int operator_idx,
175  const std::string& exception_str) {
176  const std::string& operator_name =
177  operator_nodes_[operator_idx].operator_->debug_def().name();
178  const std::string& operator_type =
179  operator_nodes_[operator_idx].operator_->debug_def().type();
180  const char* prefix = "Exception from operator chain starting at '";
181 #ifdef CAFFE2_USE_EXCEPTION_PTR
182  if (!caught_exception_yet_.exchange(true)) {
183  caught_exception_ = std::current_exception();
184  } else {
185  prefix = "Secondary exception from operator chain starting at '";
186  }
187 #endif // CAFFE2_USE_EXCEPTION_PTR
188  LOG(ERROR) << prefix << operator_name << "' (type '" << operator_type
189  << "'): " << exception_str << "\n";
190 #ifndef CAFFE2_USE_EXCEPTION_PTR
191  throw; // Can't capture for dispatch to other thread, re-throw here
192 #endif // CAFFE2_USE_EXCEPTION_PTR
193 }
194 
195 void DAGNetBase::WorkerFunction() {
196  // WorkerFunctions() is an infinite loop until there are no more jobs to run.
197  while (true) {
198  int idx = 0;
199 
200  // Return if there are no more operators to run (e.g. the
201  // DAGNetBase is destructing, or there was an error on another
202  // worker and we're cleaning up).
203  if (!job_queue_->Pop(&idx)) {
204  return;
205  }
206  if (FLAGS_caffe2_dag_net_collect_stats) {
207  auto device_option =
208  operator_nodes_[idx].operator_->event().GetDeviceOption();
209  CAFFE_EVENT(
210  stats_[device_option.device_type()],
211  task_pool_wait_time_us,
212  task_timers_[idx]->MicroSeconds());
213  }
214 
215  VLOG(1) << "Running chain starting at operator #" << idx << " "
216  << operator_nodes_[idx].operator_->debug_def().name() << "("
217  << operator_nodes_[idx].operator_->debug_def().type() << ").";
218  CAFFE_ENFORCE(
219  execution_chains_.find(idx) != execution_chains_.end(),
220  "Can't find chain ",
221  idx,
222  ".");
223  const auto& chain = execution_chains_[idx];
224  bool this_success = false;
225  try {
226  this_success = RunAt(idx, execution_chains_[idx]);
227 
228  if (!this_success) {
229  // If an exception was thrown, the operator def will get printed
230  // by Operator::Run[Async], but if no exception occurs we print it here.
231  LOG(ERROR) << "Operator chain failed starting at: "
232  << ProtoDebugString(
233  operator_nodes_[idx].operator_->debug_def());
234  }
235  } catch (std::exception& e) {
236  std::string exception_str = GetExceptionString(e);
237  HandleException(idx, exception_str);
238  } catch (...) {
239  std::string exception_str = "Unknown exception";
240  HandleException(idx, exception_str);
241  }
242 
243  // Do book-keeping
244  std::vector<int> chains_to_queue;
245  for (const auto idx : chain) {
246  for (const auto child : operator_nodes_[idx].children_) {
247  const int count = --operator_nodes_[child].runtime_parent_count_;
248  CAFFE_ENFORCE(
249  count >= 0,
250  "Found runtime parent count smaller than zero for ",
251  "operator node ",
252  operator_nodes_[child].operator_->debug_def().name(),
253  "(",
254  operator_nodes_[child].operator_->debug_def().type(),
255  ").");
256 
257  if (count != 0) {
258  continue;
259  }
260 
261  if (operator_nodes_[child].is_chain_start_) {
262  VLOG(2) << "Pushing chain #" << child << " to queue.";
263  chains_to_queue.push_back(child);
264  }
265  }
266  }
267 
268  // Notify the caller of Run
269  {
270  std::unique_lock<std::mutex> mutex_lock(remaining_ops_mutex_);
271  remaining_ops_ -= chain.size();
272  CAFFE_ENFORCE(remaining_ops_ >= 0);
273  success_ &= this_success;
274  if (remaining_ops_ == 0 || !success_) {
275  cv_.notify_one();
276  }
277 
278  // Terminate thread if this or any other operator chain failed.
279  if (!success_) {
280  job_queue_->NoMoreJobs();
281  return;
282  }
283 
284  // Queue follow up operator chains.
285  // Can't do this inline because it can race with another thread
286  // calling NoMoreJobs(). So the lock needs to be held on push.
287  for (const auto idx : chains_to_queue) {
288  if (FLAGS_caffe2_dag_net_collect_stats) {
289  task_timers_[idx]->Start();
290  }
291  job_queue_->Push(idx);
292  }
293  }
294 
295  VLOG(2) << "Finished executing operator #" << idx;
296  }
297 }
298 
300  const int warmup_runs,
301  const int main_runs,
302  const bool run_individual) {
303  std::cout << "Starting benchmark." << std::endl;
304  std::cout << "Running warmup runs." << std::endl;
305  CAFFE_ENFORCE(
306  warmup_runs >= 0,
307  "Number of warm up runs should be non negative, provided ",
308  warmup_runs,
309  ".");
310  for (int i = 0; i < warmup_runs; ++i) {
311  CAFFE_ENFORCE(Run(), "Warmup run ", i, " has failed.");
312  }
313 
314  std::cout << "Main runs." << std::endl;
315  CAFFE_ENFORCE(
316  main_runs >= 0,
317  "Number of main runs should be non negative, provided ",
318  main_runs,
319  ".");
320  Timer timer;
321  for (int i = 0; i < main_runs; ++i) {
322  CAFFE_ENFORCE(Run(), "Main run ", i, " has failed.");
323  }
324  auto millis = timer.MilliSeconds();
325  std::cout << "Main run finished. Milliseconds per iter: "
326  << millis / main_runs
327  << ". Iters per second: " << 1000.0 * main_runs / millis << std::endl;
328 
329  if (run_individual) {
330  std::cout << "DAGNet does not do per-op benchmark. To do so, "
331  "switch to a simple net type." << std::endl;
332  }
333  return vector<float>{millis / main_runs};
334 }
335 
336 bool DAGNet::RunAt(int chain_id, const std::vector<int>& chain) {
337  for (const auto i : chain) {
338 #ifdef CAFFE2_ENABLE_SDT
339  const auto& op_name =
340  operator_nodes_[i].operator_->debug_def().name().c_str();
341  const auto& op_type =
342  operator_nodes_[i].operator_->debug_def().type().c_str();
343  auto* op_ptr = operator_nodes_[i].operator_.get();
344  const auto& net_name = name_.c_str();
345  CAFFE_SDT(operator_start, net_name, op_name, op_type, op_ptr);
346 #endif
347  const auto success = operator_nodes_[i].operator_->Run();
348 #ifdef CAFFE2_ENABLE_SDT
349  CAFFE_SDT(operator_done, net_name, op_name, op_type, op_ptr);
350 #endif
351  if (!success) {
352  return false;
353  }
354  }
355  if (FLAGS_caffe2_dag_net_collect_stats) {
356  auto device_option =
357  operator_nodes_[chain_id].operator_->event().GetDeviceOption();
358  CAFFE_EVENT(
359  stats_[device_option.device_type()],
360  task_time_to_succeeded_ms,
361  task_timers_[chain_id]->MilliSeconds());
362  }
363  return true;
364 }
365 
366 REGISTER_NET(dag, DAGNet);
367 
368 } // namespace caffe2
vector< float > TEST_Benchmark(const int warmup_runs, const int main_runs, const bool run_individual) override
Benchmarks a network.
Definition: net_dag.cc:299
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
float MilliSeconds()
Returns the elapsed time in milliseconds.
Definition: timer.h:32
A simple timer object for measuring time.
Definition: timer.h:16