Caffe2 - C++ API
A deep learning, cross platform ML framework
net_simple_async.cc
1 #include "caffe2/core/net_simple_async.h"
2 #include "caffe2/core/net.h"
3 
4 #include <iostream>
5 #include <set>
6 #include <unordered_map>
7 #include <unordered_set>
8 
9 #include "caffe2/core/operator.h"
10 #include "caffe2/core/static_tracepoint.h"
11 #include "caffe2/core/timer.h"
12 #include "caffe2/proto/caffe2.pb.h"
13 #include "caffe2/utils/proto_utils.h"
14 
15 namespace caffe2 {
16 
17 AsyncSimpleNet::AsyncSimpleNet(
18  const std::shared_ptr<const NetDef>& net_def,
19  Workspace* ws)
20  : NetBase(net_def, ws) {
21  VLOG(1) << "Constructing AsyncSimpleNet " << net_def->name();
22  const bool net_def_has_device_option = net_def->has_device_option();
23  // Initialize the operators
24  const DeviceOption* first_device_option = nullptr;
25  const DeviceOption* current_device_option;
26  for (int idx = 0; idx < net_def->op_size(); ++idx) {
27  const auto& operator_def = net_def->op(idx);
28  VLOG(1) << "Creating operator " << operator_def.name() << ": "
29  << operator_def.type();
30  std::unique_ptr<OperatorBase> op{nullptr};
31  if (!operator_def.has_device_option() && net_def_has_device_option) {
32  // In the case that the operator def does not specify a device option but
33  // the net def has a default option, we copy the device option over to the
34  // operator def.
35  OperatorDef temp_def(operator_def);
36  temp_def.mutable_device_option()->CopyFrom(net_def->device_option());
37  op = CreateOperator(temp_def, ws, idx);
38  current_device_option = &net_def->device_option();
39  } else {
40  op = CreateOperator(operator_def, ws, idx);
41  op->set_debug_def(
42  std::shared_ptr<const OperatorDef>{net_def, &(net_def->op(idx))});
43  current_device_option = &operator_def.device_option();
44  }
45  if (!first_device_option) {
46  first_device_option = current_device_option;
47  } else {
48  CAFFE_ENFORCE(
49  IsSameDevice(*first_device_option, *current_device_option),
50  "AsyncSimpleNet supports only single device networks");
51  }
52  operators_.emplace_back(std::move(op));
53  }
54  events_ = {&operators_.back()->event()};
55 }
56 
57 bool AsyncSimpleNet::DoRunAsync() {
58  StartAllObservers();
59 
60  VLOG(1) << "Running net " << name_;
61  for (auto& op : operators_) {
62  VLOG(1) << "Running operator " << op->debug_def().name() << "("
63  << op->debug_def().type() << ").";
64 #ifdef CAFFE2_ENABLE_SDT
65  const auto& op_name = op->debug_def().name().c_str();
66  const auto& op_type = op->debug_def().type().c_str();
67  auto* op_ptr = op.get();
68  const auto& net_name = name_.c_str();
69  CAFFE_SDT(operator_start, net_name, op_name, op_type, op_ptr);
70 #endif
71  bool res = op->RunAsync();
72 #ifdef CAFFE2_ENABLE_SDT
73  CAFFE_SDT(operator_done, net_name, op_name, op_type, op_ptr);
74 #endif
75  if (!res) {
76  LOG(ERROR) << "Operator failed: " << ProtoDebugString(op->debug_def());
77  return false;
78  }
79  }
80  StopAllObservers();
81  return true;
82 }
83 
85  const int warmup_runs,
86  const int main_runs,
87  const bool run_individual) {
88  std::cout << "Starting benchmark." << std::endl;
89  std::cout << "Running warmup runs." << std::endl;
90  CAFFE_ENFORCE(
91  warmup_runs >= 0,
92  "Number of warm up runs should be non negative, provided ",
93  warmup_runs,
94  ".");
95  for (int i = 0; i < warmup_runs; ++i) {
96  CAFFE_ENFORCE(Run(), "Warmup run ", i, " has failed.");
97  }
98 
99  std::cout << "Main runs." << std::endl;
100  CAFFE_ENFORCE(
101  main_runs >= 0,
102  "Number of main runs should be non negative, provided ",
103  main_runs,
104  ".");
105  Timer timer;
106  for (int i = 0; i < main_runs; ++i) {
107  CAFFE_ENFORCE(Run(), "Main run ", i, " has failed.");
108  }
109  auto millis = timer.MilliSeconds();
110  std::cout << "Main run finished. Milliseconds per iter: "
111  << millis / main_runs
112  << ". Iters per second: " << 1000.0 * main_runs / millis << std::endl;
113 
114  if (run_individual) {
115  std::cout << "AsyncSimpleNet does not do per-op benchmark. To do so, "
116  "switch to a simple net type." << std::endl;
117  }
118  return vector<float>{millis / main_runs};
119 }
120 
121 REGISTER_NET(async_simple, AsyncSimpleNet);  // Registers AsyncSimpleNet under the key `async_simple`; presumably selected when a NetDef's type is "async_simple" — confirm against the net registry in net.h.
122 
123 } // namespace caffe2
vector< float > TEST_Benchmark(const int warmup_runs, const int main_runs, const bool run_individual) override
Benchmarks a network.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
float MilliSeconds()
Returns the elapsed time in milliseconds.
Definition: timer.h:32
A simple timer object for measuring time.
Definition: timer.h:16