1 #include "caffe2/core/net_async_scheduling.h" 4 caffe2_net_async_always_schedule_child,
6 "Always schedule child chains from parent chain");
10 AsyncSchedulingNet::AsyncSchedulingNet(
11 const std::shared_ptr<const NetDef>& net_def,
13 : AsyncNetBase(net_def, ws), running_(false) {
17 void AsyncSchedulingNet::reset() {
18 processed_tasks_num_ = 0;
22 for (
auto task_id = 0; task_id < tasksNum(); ++task_id) {
23 auto& task_ops = chains_[task_id];
24 auto& task_op_node = operator_nodes_[task_ops.front()];
25 task_op_node.runtime_parent_count_ = parents(task_id).size();
27 exception_messages_.clear();
30 void AsyncSchedulingNet::Wait() {
31 std::unique_lock<std::mutex> lock(running_mutex_);
33 running_cv_.wait(lock);
37 void AsyncSchedulingNet::schedule(
int task_id) {
38 const auto& device_option = event(task_id).GetDeviceOption();
39 pool(device_option)->run([
this, task_id]() {
41 int stream_id = stream(task_id);
42 asyncWait(task_id, stream_id, parents(task_id));
44 run(task_id, stream_id);
45 }
catch (
const std::exception& e) {
46 std::unique_lock<std::mutex> lock(exception_mutex_);
47 exception_messages_.push_back(e.what());
52 auto task_count = ++processed_tasks_num_;
54 for (
auto child_id : children(task_id)) {
55 int parent_count = updateParentCount(child_id);
56 if (parent_count == 0) {
57 if (cleanup_ || FLAGS_caffe2_net_async_always_schedule_child ||
58 canSchedule(child_id)) {
61 const auto& device_option = event(child_id).GetDeviceOption();
64 &AsyncSchedulingNet::pollAndSchedule,
this, child_id));
70 if (task_count == tasksNum()) {
84 std::unique_lock<std::mutex> cleanup_lock(cleanup_mutex_);
93 while (processed_tasks_num_ != tasksNum()) {
105 void AsyncSchedulingNet::pollAndSchedule(
int task_id) {
106 if (canSchedule(task_id) || cleanup_) {
110 const auto& device_option = event(task_id).GetDeviceOption();
112 ->run(std::bind(&AsyncSchedulingNet::pollAndSchedule,
this, task_id));
116 int AsyncSchedulingNet::updateParentCount(
int child_id) {
117 auto& child_ops = chains_[child_id];
118 auto& child_node = operator_nodes_[child_ops.front()];
119 int parent_count = --child_node.runtime_parent_count_;
120 CAFFE_ENFORCE_GE(parent_count, 0);
124 void AsyncSchedulingNet::finishRun() {
128 running_cv_.notify_all();
131 bool AsyncSchedulingNet::DoRunAsync() {
132 std::unique_lock<std::mutex> lock(running_mutex_);
133 CAFFE_ENFORCE(!running_,
"Concurrent RunAsync calls");
139 for (
auto task_id = 0; task_id < tasksNum(); ++task_id) {
140 if (parents(task_id).empty()) {
148 AsyncSchedulingNet::~AsyncSchedulingNet() {}
150 REGISTER_NET(async_scheduling, AsyncSchedulingNet);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...