1 #include "caffe2/core/net_async_polling.h" 3 #include "caffe2/core/operator.h" 4 #include "caffe2/core/timer.h" 6 CAFFE2_DECLARE_bool(caffe2_dag_net_collect_stats);
// Constructor: forwards net_def and ws to AsyncNetBase and starts with
// running_ set to false, then prepares per-task timers and per-device stats.
// NOTE(review): this listing omits some lines of the definition (the embedded
// numbering skips 12, 17-18, 22-23 and everything after 25); the comments
// below describe only the statements that are visible.
10 AsyncPollingNet::AsyncPollingNet(
11 const std::shared_ptr<const NetDef>& net_def,
13 : AsyncNetBase(net_def, ws), running_(false) {
// Allocate one Timer per task; task_timers_ is indexed by task id.
14 task_timers_.resize(tasksNum());
15 for (
auto task_id = 0; task_id < tasksNum(); ++task_id) {
16 task_timers_[task_id] = caffe2::make_unique<Timer>();
// Reserve room in stats_ for COMPILE_TIME_MAX_DEVICE_TYPES entries and loop
// over every device type index below that bound.
19 stats_.reserve(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES);
20 for (
auto device_idx = 0;
21 device_idx < DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES;
// String of the form "async_net/stats/<net name>/<device type name>".
// NOTE(review): the call receiving this string is not shown here; presumably
// it constructs the stats_ entry for device_idx -- confirm against the file.
24 "async_net/stats/" + net_def->name() +
"/" +
25 caffe2::DeviceTypeName(device_idx));
// Entry point for a run: refuses to start while running_ is set, then drives
// the whole net through pollAndSchedule() and keeps its boolean result in
// `success`.
// NOTE(review): listing lines 33-38 and everything after 41 are absent here
// (including the declaration of `timer` and the return statement); only the
// visible statements are described.
31 bool AsyncPollingNet::DoRunAsync() {
// Fails with "Concurrent RunAsync calls" if running_ is already true.
32 CAFFE_ENFORCE(!running_,
"Concurrent RunAsync calls");
39 bool success = pollAndSchedule();
// With stats collection enabled, report timer.MilliSeconds() as poll_time_ms
// on the stats entry indexed by CPU.
40 if (FLAGS_caffe2_dag_net_collect_stats) {
41 CAFFE_EVENT(stats_[CPU], poll_time_ms, timer.MilliSeconds());
// Submits task `task_id` to the pool selected by the device option of the
// task's event. The submitted lambda waits on the task's parents, runs the
// task, and sets has_chain_failed_ if a std::exception is caught.
// NOTE(review): several lines are absent from this listing (the `try`, the
// names of the stats macros, the declaration of `run_time`, closing braces);
// only the visible statements are described.
52 void AsyncPollingNet::schedule(
int task_id) {
// With stats enabled, start this task's timer before handing it to the pool;
// the timer is read again inside the lambda below.
53 if (FLAGS_caffe2_dag_net_collect_stats) {
54 task_timers_[task_id]->Start();
56 const auto& device_option = event(task_id).GetDeviceOption();
// The capture list names device_option without '&', so the lambda holds its
// own copy rather than the reference bound above.
57 pool(device_option)->run([
this, task_id, device_option]() {
58 int stream_id = stream(task_id);
// Arguments of a stats call whose opening line is not shown: the task timer's
// elapsed microseconds recorded as task_pool_wait_time_us for this device type.
60 if (FLAGS_caffe2_dag_net_collect_stats) {
62 stats_[device_option.device_type()],
63 task_pool_wait_time_us,
64 task_timers_[task_id]->MicroSeconds());
// Wait on the parent tasks of task_id using the stream chosen above.
71 asyncWait(task_id, stream_id, parents(task_id));
// With stats enabled, run() is followed by a stats call that records
// run_time.MicroSeconds(). NOTE(review): the second run() call below is
// presumably the stats-disabled branch (the `else` is not visible) -- confirm.
73 if (FLAGS_caffe2_dag_net_collect_stats) {
75 run(task_id, stream_id);
77 stats_[device_option.device_type()],
79 run_time.MicroSeconds());
81 run(task_id, stream_id);
83 }
catch (
const std::exception&) {
// Record the failure in has_chain_failed_, which pollAndSchedule() checks.
84 has_chain_failed_ =
true;
// Prepares the per-run state: sizes status_ to tasksNum() entries, passing
// EventStatus::EVENT_INITIALIZED as the fill value, and clears
// has_chain_failed_.
// NOTE(review): listing line 90 is absent here; if it does not empty status_
// first, resize() would only fill newly added entries -- confirm.
89 void AsyncPollingNet::reset() {
91 status_.resize(tasksNum(), EventStatus::EVENT_INITIALIZED);
92 has_chain_failed_ =
false;
// Polling loop over the task graph. Starts from the tasks that have no
// parents, then repeatedly queries the status of the tasks in current_tasks
// and adds children that canSchedule() accepts, until current_tasks is empty.
// NOTE(review): many lines are absent from this listing (calls to schedule(),
// return statements, timer declarations, closing braces); the comments below
// describe only the visible statements.
95 bool AsyncPollingNet::pollAndSchedule() {
// scheduled_tasks: every task id ever added to the polling set (used to avoid
// adding a child twice). current_tasks: ids polled in the current iteration.
96 std::unordered_set<int> scheduled_tasks;
97 std::unordered_set<int> current_tasks;
// Seed both sets with the tasks that have no parents.
99 for (
auto task_id = 0; task_id < tasksNum(); ++task_id) {
100 if (parents(task_id).empty()) {
101 current_tasks.insert(task_id);
102 scheduled_tasks.insert(task_id);
108 while (!current_tasks.empty()) {
// updated_tasks: tasks whose status changed in this iteration.
// next_tasks: tasks to poll in the next iteration.
109 std::unordered_set<int> updated_tasks;
110 std::unordered_set<int> next_tasks;
111 updated_tasks.reserve(current_tasks.size());
113 if (FLAGS_caffe2_dag_net_collect_stats) {
// A failure flagged by a scheduled task's lambda: finish the tasks currently
// being polled. NOTE(review): what follows (presumably an early return) is
// not shown in this listing.
116 if (has_chain_failed_) {
117 finishTasks(current_tasks);
// Refresh each polled task's status and compare it with the previous value.
120 for (
auto& task_id : current_tasks) {
121 auto prev_status = status_[task_id];
122 status_[task_id] = query(task_id);
// A failed task also triggers finishTasks() on the current set.
123 if (status_[task_id] == EventStatus::EVENT_FAILED) {
124 finishTasks(current_tasks);
// Only tasks whose status changed are recorded in updated_tasks (and, with
// stats enabled, reported through updateTaskStats).
128 if (prev_status != status_[task_id]) {
129 updated_tasks.insert(task_id);
130 if (FLAGS_caffe2_dag_net_collect_stats) {
131 updateTaskStats(task_id);
// Tasks that have not reached EVENT_SUCCESS are polled again next iteration.
135 if (status_[task_id] != EventStatus::EVENT_SUCCESS) {
136 next_tasks.insert(task_id);
139 if (FLAGS_caffe2_dag_net_collect_stats) {
141 stats_[CPU], poll_status_update_time_us, timer.MicroSeconds());
// visited_children ensures each child is examined at most once per iteration,
// even when several updated tasks share it.
144 std::unordered_set<int> visited_children;
145 for (
auto& task_id : updated_tasks) {
// Condition that an updated task is either SCHEDULED or SUCCESS.
// NOTE(review): the enclosing call is on a line not shown here; presumably an
// enforce/assert -- confirm.
147 status_[task_id] == EventStatus::EVENT_SCHEDULED ||
148 status_[task_id] == EventStatus::EVENT_SUCCESS);
150 for (
auto& child_id : children(task_id)) {
151 if (!visited_children.count(child_id)) {
152 visited_children.insert(child_id);
// A child joins the polling set only if it was never added before and
// canSchedule() accepts it given the current status_ values.
157 if (!scheduled_tasks.count(child_id) &&
158 canSchedule(child_id, &status_)) {
159 next_tasks.insert(child_id);
160 scheduled_tasks.insert(child_id);
// next_tasks becomes the set polled by the next loop iteration.
167 current_tasks.swap(next_tasks);
// Reports how long task `task_id` took to reach its current status, read from
// the task's timer and attributed to the device type of the task's event.
// NOTE(review): the stats macro names (listing lines 175 and 181) are absent
// from this listing; only their arguments are visible.
172 void AsyncPollingNet::updateTaskStats(
int task_id) {
173 const auto& device_option = event(task_id).GetDeviceOption();
// Status SCHEDULED: elapsed time in microseconds as task_time_to_scheduled_us.
174 if (status_[task_id] == EventStatus::EVENT_SCHEDULED) {
176 stats_[device_option.device_type()],
177 task_time_to_scheduled_us,
178 task_timers_[task_id]->MicroSeconds());
// Status SUCCESS: elapsed time in milliseconds (not microseconds) as
// task_time_to_succeeded_ms.
180 if (status_[task_id] == EventStatus::EVENT_SUCCESS) {
182 stats_[device_option.device_type()],
183 task_time_to_succeeded_ms,
184 task_timers_[task_id]->MilliSeconds());
// Destructor with an empty body; no explicit cleanup is performed here.
188 AsyncPollingNet::~AsyncPollingNet() {}
// Registers AsyncPollingNet under the name `async_polling`.
// NOTE(review): presumably this makes the class selectable by that net type
// name -- confirm against the REGISTER_NET macro definition.
190 REGISTER_NET(async_polling, AsyncPollingNet);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...