1 #include "caffe2/core/net_dag.h" 6 #include <unordered_map> 7 #include <unordered_set> 9 #include "caffe2/core/operator.h" 10 #include "caffe2/core/static_tracepoint.h" 11 #include "caffe2/core/timer.h" 12 #include "caffe2/proto/caffe2.pb.h" 13 #include "caffe2/utils/proto_utils.h" 16 caffe2_disable_chaining,
18 "Disable chaining logic (some latent multi-device issues).");
21 caffe2_dag_net_collect_stats,
23 "Collect time stats in DAG net");
27 DAGNetBase::DAGNetBase(
28 const std::shared_ptr<const NetDef>& net_def,
30 : NetBase(net_def, ws), caught_exception_yet_(false), iter_(0) {
32 VLOG(1) <<
"Constructing DAGNet " << net_def->name();
34 operator_nodes_ = dag_utils::prepareOperatorNodes(net_def, ws);
37 (FLAGS_caffe2_disable_chaining
38 ? dag_utils::singleChains(operator_nodes_)
39 : dag_utils::computeChains(operator_nodes_));
41 operators_.reserve(operator_nodes_.size());
42 for (
const auto& node : operator_nodes_) {
43 operators_.push_back(node.operator_.get());
46 LOG(INFO) <<
"Number of parallel execution chains " 47 << execution_chains_.size()
48 <<
" Number of operators = " << net_def->op_size();
54 for (
int idx = 0; idx < operator_nodes_.size(); ++idx) {
55 if (operator_nodes_[idx].parents_.size() == 0) {
56 initial_frontier_.push_back(idx);
60 int num_workers = net_def->has_num_workers() ? net_def->num_workers() : 1;
61 CAFFE_ENFORCE(num_workers > 0,
"Must have a positive number of workers.");
62 if (num_workers == 1) {
63 LOG(WARNING) <<
"Number of workers is 1: this means that all operators " 64 <<
"will be executed sequentially. Did you forget to set " 65 <<
"num_workers in the NetDef?";
67 num_workers_ = num_workers;
69 for (
int idx = 0; idx < operator_nodes_.size(); ++idx) {
70 if (operator_nodes_[idx].is_chain_start_) {
71 task_timers_[idx] = caffe2::make_unique<Timer>();
74 stats_.reserve(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES);
75 for (
auto device_idx = 0;
76 device_idx < DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES;
79 "dag_net/stats/" + net_def->name() +
"/" +
80 caffe2::DeviceTypeName(device_idx));
84 DAGNetBase::~DAGNetBase() {
86 job_queue_->NoMoreJobs();
87 VLOG(1) <<
"Joining workers.";
88 for (
auto& worker : workers_) {
94 bool DAGNetBase::DoRunAsync() {
98 std::unique_lock<std::mutex> run_lock(run_in_progress_);
99 VLOG(1) <<
"Running parallel net.";
101 remaining_ops_ = operator_nodes_.size();
105 job_queue_ = caffe2::make_unique<SimpleQueue<int>>();
108 auto num_workers_to_start = num_workers_ - workers_.size();
112 for (
auto i = 0; i < num_workers_to_start; i++) {
113 VLOG(1) <<
"Start worker #" << workers_.size();
114 workers_.push_back(std::thread(&DAGNetBase::WorkerFunction,
this));
117 for (
auto& node : operator_nodes_) {
118 node.runtime_parent_count_ = node.parents_.size();
121 for (
auto& value : initial_frontier_) {
122 if (FLAGS_caffe2_dag_net_collect_stats) {
123 task_timers_[value]->Start();
125 job_queue_->Push(value);
129 std::unique_lock<std::mutex> mutex_lock(remaining_ops_mutex_);
131 if (remaining_ops_ == 0 || !success_) {
134 cv_.wait(mutex_lock);
143 for (
auto& worker : workers_) {
147 job_queue_.reset(
nullptr);
148 #ifdef CAFFE2_USE_EXCEPTION_PTR 149 if (caught_exception_) {
151 caught_exception_yet_ =
false;
152 std::rethrow_exception(caught_exception_);
154 #endif // CAFFE2_USE_EXCEPTION_PTR 157 VLOG(2) <<
"All ops finished running.";
158 for (
const auto& op : operator_nodes_) {
160 op.runtime_parent_count_ == 0,
162 op.operator_->debug_def().name(),
164 op.operator_->debug_def().type(),
165 ") has some runtime parents left.");
173 void DAGNetBase::HandleException(
175 const std::string& exception_str) {
176 const std::string& operator_name =
177 operator_nodes_[operator_idx].operator_->debug_def().name();
178 const std::string& operator_type =
179 operator_nodes_[operator_idx].operator_->debug_def().type();
180 const char* prefix =
"Exception from operator chain starting at '";
181 #ifdef CAFFE2_USE_EXCEPTION_PTR 182 if (!caught_exception_yet_.exchange(
true)) {
183 caught_exception_ = std::current_exception();
185 prefix =
"Secondary exception from operator chain starting at '";
187 #endif // CAFFE2_USE_EXCEPTION_PTR 188 LOG(ERROR) << prefix << operator_name <<
"' (type '" << operator_type
189 <<
"'): " << exception_str <<
"\n";
190 #ifndef CAFFE2_USE_EXCEPTION_PTR 192 #endif // CAFFE2_USE_EXCEPTION_PTR 195 void DAGNetBase::WorkerFunction() {
203 if (!job_queue_->Pop(&idx)) {
206 if (FLAGS_caffe2_dag_net_collect_stats) {
208 operator_nodes_[idx].operator_->event().GetDeviceOption();
210 stats_[device_option.device_type()],
211 task_pool_wait_time_us,
212 task_timers_[idx]->MicroSeconds());
215 VLOG(1) <<
"Running chain starting at operator #" << idx <<
" " 216 << operator_nodes_[idx].operator_->debug_def().name() <<
"(" 217 << operator_nodes_[idx].operator_->debug_def().type() <<
").";
219 execution_chains_.find(idx) != execution_chains_.end(),
223 const auto& chain = execution_chains_[idx];
224 bool this_success =
false;
226 this_success = RunAt(idx, execution_chains_[idx]);
231 LOG(ERROR) <<
"Operator chain failed starting at: " 233 operator_nodes_[idx].operator_->debug_def());
235 }
catch (std::exception& e) {
236 std::string exception_str = GetExceptionString(e);
237 HandleException(idx, exception_str);
239 std::string exception_str =
"Unknown exception";
240 HandleException(idx, exception_str);
244 std::vector<int> chains_to_queue;
245 for (
const auto idx : chain) {
246 for (
const auto child : operator_nodes_[idx].children_) {
247 const int count = --operator_nodes_[child].runtime_parent_count_;
250 "Found runtime parent count smaller than zero for ",
252 operator_nodes_[child].operator_->debug_def().name(),
254 operator_nodes_[child].operator_->debug_def().type(),
261 if (operator_nodes_[child].is_chain_start_) {
262 VLOG(2) <<
"Pushing chain #" << child <<
" to queue.";
263 chains_to_queue.push_back(child);
270 std::unique_lock<std::mutex> mutex_lock(remaining_ops_mutex_);
271 remaining_ops_ -= chain.size();
272 CAFFE_ENFORCE(remaining_ops_ >= 0);
273 success_ &= this_success;
274 if (remaining_ops_ == 0 || !success_) {
280 job_queue_->NoMoreJobs();
287 for (
const auto idx : chains_to_queue) {
288 if (FLAGS_caffe2_dag_net_collect_stats) {
289 task_timers_[idx]->Start();
291 job_queue_->Push(idx);
295 VLOG(2) <<
"Finished executing operator #" << idx;
300 const int warmup_runs,
302 const bool run_individual) {
303 std::cout <<
"Starting benchmark." << std::endl;
304 std::cout <<
"Running warmup runs." << std::endl;
307 "Number of warm up runs should be non negative, provided ",
310 for (
int i = 0; i < warmup_runs; ++i) {
311 CAFFE_ENFORCE(Run(),
"Warmup run ", i,
" has failed.");
314 std::cout <<
"Main runs." << std::endl;
317 "Number of main runs should be non negative, provided ",
321 for (
int i = 0; i < main_runs; ++i) {
322 CAFFE_ENFORCE(Run(),
"Main run ", i,
" has failed.");
325 std::cout <<
"Main run finished. Milliseconds per iter: " 326 << millis / main_runs
327 <<
". Iters per second: " << 1000.0 * main_runs / millis << std::endl;
329 if (run_individual) {
330 std::cout <<
"DAGNet does not do per-op benchmark. To do so, " 331 "switch to a simple net type." << std::endl;
333 return vector<float>{millis / main_runs};
336 bool DAGNet::RunAt(
int chain_id,
const std::vector<int>& chain) {
337 for (
const auto i : chain) {
338 #ifdef CAFFE2_ENABLE_SDT 339 const auto& op_name =
340 operator_nodes_[i].operator_->debug_def().name().c_str();
341 const auto& op_type =
342 operator_nodes_[i].operator_->debug_def().type().c_str();
343 auto* op_ptr = operator_nodes_[i].operator_.get();
344 const auto& net_name = name_.c_str();
345 CAFFE_SDT(operator_start, net_name, op_name, op_type, op_ptr);
347 const auto success = operator_nodes_[i].operator_->Run();
348 #ifdef CAFFE2_ENABLE_SDT 349 CAFFE_SDT(operator_done, net_name, op_name, op_type, op_ptr);
355 if (FLAGS_caffe2_dag_net_collect_stats) {
357 operator_nodes_[chain_id].operator_->event().GetDeviceOption();
359 stats_[device_option.device_type()],
360 task_time_to_succeeded_ms,
361 task_timers_[chain_id]->MilliSeconds());
366 REGISTER_NET(dag,
DAGNet);
vector< float > TEST_Benchmark(const int warmup_runs, const int main_runs, const bool run_individual) override
Benchmarks a network.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
float MilliSeconds()
Returns the elapsed time in milliseconds.
A simple timer object for measuring time.