1 #include "caffe2/core/net_dag_utils.h" 5 #include <unordered_map> 6 #include <unordered_set> 8 #include "caffe2/core/operator.h" 9 #include "caffe2/core/static_tracepoint.h" 10 #include "caffe2/core/timer.h" 11 #include "caffe2/proto/caffe2.pb.h" 12 #include "caffe2/utils/proto_utils.h" 18 void prune(
int node_idx, std::vector<OpGraphNode>& nodes) {
// prune: removes redundant parent edges from the graph in `nodes`, walking
// depth-first from node_idx with an explicit stack instead of recursion.
//
// Parameters:
//   node_idx - index into `nodes` where the walk starts.
//   nodes    - graph modified in place: parents_ / children_ of visited
//              nodes may shrink and visited_inputs counters are incremented.
//
// A parent of the current node is dropped when it is not the node the walk
// arrived from (`prev`) but is already flagged in `ancestors` — presumably
// because the current node is then also reachable from that parent through
// `prev`, making the direct edge transitive/redundant (confirm).
// ancestors[i] becomes true once node i has been entered by the walk.
20 std::vector<bool> ancestors(nodes.size(),
false);
// Each stack entry is (node to visit, node it was reached from).
22 std::stack<std::pair<int, int>> nodes_stack;
// The start node has no predecessor, so -1 is used as a sentinel.
24 nodes_stack.push(std::make_pair(node_idx, -1));
26 while (!nodes_stack.empty()) {
27 const auto& node_pair = nodes_stack.top();
28 int curr = node_pair.
first;
29 int prev = node_pair.second;
// Reject indices outside the ancestor table before indexing it.
33 CAFFE_ENFORCE(curr < ancestors.size(),
"Out of bound access");
// A node that is already flagged is being seen again; its flag is cleared.
34 if (ancestors[curr]) {
35 ancestors[curr] =
false;
// Rebuild curr's parent list. A parent other than `prev` that is flagged in
// `ancestors` has entries erased from its children_ (presumably the entry
// for curr); the collected new_parents then replace nodes[curr].parents_.
45 std::vector<int> new_parents;
46 for (
auto parent : nodes[curr].parents_) {
47 if (parent != prev && ancestors[parent]) {
49 nodes[parent].children_.erase(
51 nodes[parent].children_.begin(),
52 nodes[parent].children_.end(),
54 nodes[parent].children_.end());
56 new_parents.push_back(parent);
59 nodes[curr].parents_ = new_parents;
// Flag curr so that later parent checks and revisits see it.
62 ancestors[curr] =
true;
// Descend only when visited_inputs has reached the original parent count;
// visited_inputs is incremented below each time a node is pushed as a child.
65 if (nodes[curr].visited_inputs == nodes[curr].num_orig_parents) {
66 const auto& children = nodes[curr].children_;
67 for (
auto child : children) {
68 nodes[child].visited_inputs++;
69 nodes_stack.push(std::make_pair(child, curr));
// pruneOpNodeGraph: builds an OpGraphNode graph mirroring `nodes` (copying
// children_ and parents_, and recording num_orig_parents as the initial
// parent count), then handles every node that has no parents — presumably by
// calling prune() from each such root — and logs how long the step took.
// The input is taken by const reference and is not modified; the pruned copy
// is presumably the return value (confirm against the full body).
79 std::vector<OpGraphNode> pruneOpNodeGraph(
80 const std::vector<OperatorNode>& nodes) {
82 std::vector<OpGraphNode> pruned;
// Copy the adjacency lists of each operator node.
88 for (
auto& node : nodes) {
90 nd.children_ = node.children_;
91 nd.parents_ = node.parents_;
// Remember the original in-degree; prune() compares visited_inputs against it.
92 nd.num_orig_parents = nd.parents_.size();
// Nodes without parents are the roots the pruning walk starts from.
96 for (
int i = 0; i < pruned.size(); ++i) {
97 if (pruned[i].parents_.size() == 0) {
102 LOG(INFO) <<
"Operator graph pruning prior to chain compute took: " 103 << t.Seconds() <<
" secs";
// updateOperatorNodes: records chain membership on the operator nodes.
// For every node index i, is_chain_start_ is set to true when i is a key of
// `chains` and to false otherwise, and runtime_parent_count_ is reset to 0.
// NOTE(review): `int i` is compared against the unsigned size(); harmless for
// realistic graph sizes but triggers -Wsign-compare.
107 void updateOperatorNodes(
108 std::vector<OperatorNode>& nodes,
109 const ExecutionChains& chains) {
110 for (
int i = 0; i < nodes.size(); ++i) {
111 auto& node = nodes[i];
112 if (chains.find(i) != chains.end()) {
113 node.is_chain_start_ =
true;
115 node.is_chain_start_ =
false;
117 node.runtime_parent_count_ = 0;
// computeChains: partitions the operator graph into execution chains and
// returns them keyed by each chain's first operator index.
//
// Steps:
//   1. Work on a pruned copy of the graph (pruneOpNodeGraph).
//   2. Collect the roots (nodes with no parents) as the initial frontier.
//   3. A first depth-first walk from each root counts how many times each
//      node is reached (node_seen_count).
//   4. A second depth-first walk grows `chain`, extending it while
//      check_current_for_chaining() holds and storing finished chains via
//      commit_chain().
//   5. Enforce that every node of the pruned graph was seen, then update
//      orig_nodes through updateOperatorNodes().
122 ExecutionChains computeChains(std::vector<OperatorNode>& orig_nodes) {
123 const std::vector<OpGraphNode> nodes = pruneOpNodeGraph(orig_nodes);
124 vector<int> initial_frontier;
125 for (
int idx = 0; idx < nodes.size(); ++idx) {
126 if (nodes[idx].parents_.size() == 0) {
127 initial_frontier.push_back(idx);
// Number of times each node is reached by the first traversal.
133 std::unordered_map<int, int> node_seen_count;
135 for (
int root_index : initial_frontier) {
136 const auto& root = nodes[root_index];
// Stack entries pair a node with an iterator to its next unvisited child.
137 std::stack<std::pair<int, std::vector<int>::const_iterator>> depth_stack;
138 depth_stack.push(make_pair(root_index, root.children_.begin()));
// Roots are counted here; the enforce that follows requires exactly 1.
139 node_seen_count[root_index]++;
141 node_seen_count[root_index] == 1,
144 " visit count must be == 1");
146 while (depth_stack.size() > 0) {
147 auto cur = depth_stack.top();
149 if (cur.second != nodes[cur.first].children_.end()) {
150 int node_index = *cur.second;
151 node_seen_count[node_index]++;
153 depth_stack.push(cur);
// Only expand a child's own children on its first visit.
154 if (node_seen_count[node_index] == 1) {
157 make_pair(node_index, nodes[node_index].children_.begin()));
165 ExecutionChains chains;
// Nodes already handled by the chain-building traversal.
166 std::unordered_set<int> seen_nodes;
// Chain currently being grown.
167 std::vector<int> chain;
168 std::pair<int, std::vector<int>::const_iterator> cur;
169 std::stack<std::pair<int, std::vector<int>::const_iterator>> depth_stack;
// True when `cur` may be appended to `chain`. The conditions shown here:
// cur was reached exactly once in the first walk, and either the chain is
// empty or (cur passes a device-option comparison with the chain's last op
// and (the last op has no async part or cur supports async scheduling)).
170 auto check_current_for_chaining = [&]() ->
bool {
172 node_seen_count[cur.first] == 1 &&
173 (chain.size() == 0 ||
186 orig_nodes[cur.first].operator_->device_option(),
187 orig_nodes[chain.back()].operator_->device_option()) &&
188 (!orig_nodes[chain.back()].operator_->HasAsyncPart() ||
189 orig_nodes[cur.first].operator_->SupportsAsyncScheduling()))));
// Stores a non-empty `chain` under its first element; inserting the same
// starting index twice fails the enforce.
191 auto commit_chain = [&]() {
192 if (chain.size() > 0) {
194 chains.insert({chain.front(), chain}).second,
197 " was already added.");
// NOTE(review): the message lacks a space before "with elements", so the log
// reads e.g. "Added chain: 3with elements".
198 VLOG(2) <<
"Added chain: " << chain.front() <<
"with elements";
199 for (
auto ch : chain) {
200 VLOG(2) << ch <<
", ";
// Skips (presumably by advancing cur.second) children already in seen_nodes;
// if an unseen child remains, pushes cur back followed by that child so the
// child is processed next.
205 auto depth_traverse = [&]() {
206 while (cur.second != nodes[cur.first].children_.end() &&
207 seen_nodes.find(*cur.second) != seen_nodes.end()) {
211 if (cur.second != nodes[cur.first].children_.end()) {
212 auto next = make_pair(*cur.second, nodes[*cur.second].children_.begin());
213 depth_stack.push(cur);
214 depth_stack.push(next);
// Second traversal: build chains starting from each root.
217 for (
int root_index : initial_frontier) {
219 make_pair(root_index, nodes[root_index].children_.begin()));
220 while (depth_stack.size() > 0) {
221 cur = depth_stack.top();
// First time this node is processed by the chain-building walk.
223 if (seen_nodes.find(cur.first) == seen_nodes.end()) {
224 seen_nodes.insert(cur.first);
// Single-child node: it is appended to a chain and its only child is pushed
// next, either extending the current chain or (presumably after committing
// the current one) starting a new chain.
227 if (nodes[cur.first].children_.size() == 1) {
228 if (check_current_for_chaining()) {
// NOTE(review): no space between this message and the node index.
230 VLOG(1) <<
"Adding to existing chain" << cur.first;
231 chain.push_back(cur.first);
232 int index = *nodes[cur.first].children_.begin();
233 depth_stack.push(make_pair(index, nodes[index].children_.begin()));
238 chain.push_back(cur.first);
239 int index = *nodes[cur.first].children_.begin();
240 depth_stack.push(make_pair(index, nodes[index].children_.begin()));
243 nodes[cur.first].children_.size() == 0 &&
244 check_current_for_chaining()) {
// Leaf that can join the current chain.
246 chain.push_back(cur.first);
253 chain.push_back(cur.first);
269 seen_nodes.size() == nodes.size(),
270 "Haven't seen all the nodes, expected number of nodes ",
// Record chain starts (and reset runtime parent counts) on the original nodes.
276 updateOperatorNodes(orig_nodes, chains);
// singleChains: the trivial chaining. It iterates over every node index —
// presumably placing each operator in its own one-element chain (confirm) —
// and then refreshes the chain-start flags on `nodes` via
// updateOperatorNodes().
280 ExecutionChains singleChains(std::vector<OperatorNode>& nodes) {
281 ExecutionChains chains;
282 for (
auto i = 0; i < nodes.size(); ++i) {
285 updateOperatorNodes(nodes, chains);
// prepareOperatorNodes: instantiates one operator per op in net_def and
// derives the dependency graph between them.
//
// For each op, in order:
//   - if the op has no device option but the net does, the operator is
//     created from a copy of the op def carrying the net's device option;
//   - each input and control input produced earlier in this net adds a
//     read-after-write (RaW) edge from the blob's creator; inputs not produced
//     here are logged and assumed to be pre-existing; the op is recorded as a
//     reader of each input;
//   - each output adds a write-after-write (WaW) edge from the blob's previous
//     creator and write-after-read (WaR) edges from every recorded reader; the
//     op then becomes the blob's creator and the reader set is cleared.
// Finally, the parents_ and children_ lists of every node are sorted,
// de-duplicated, and stripped of self references.
//
// Returns the operator nodes, indexed like net_def->op().
289 std::vector<OperatorNode> prepareOperatorNodes(
290 const std::shared_ptr<const NetDef>& net_def,
292 std::vector<OperatorNode> operator_nodes(net_def->op_size());
293 std::map<string, int> blob_creator;
294 std::map<string, std::set<int>> blob_readers;
295 bool net_def_has_device_option = net_def->has_device_option();
297 for (
int idx = 0; idx < net_def->op_size(); ++idx) {
298 const OperatorDef& op_def = net_def->op(idx);
299 VLOG(1) <<
"Creating operator #" << idx <<
": " << op_def.name() <<
": " 301 if (!op_def.has_device_option() && net_def_has_device_option) {
// Copy the op def so the net-level device option can be filled in.
302 OperatorDef temp_def(op_def);
303 temp_def.mutable_device_option()->CopyFrom(net_def->device_option());
304 operator_nodes[idx].operator_ = CreateOperator(temp_def, ws, idx);
// The aliasing shared_ptr below shares ownership of net_def while pointing at
// this op's definition inside it.
306 auto op = CreateOperator(op_def, ws, idx);
308 std::shared_ptr<const OperatorDef>{net_def, &(net_def->op(idx))});
309 operator_nodes[idx].operator_ = std::move(op);
314 [&](
const google::protobuf::RepeatedPtrField<std::string>& inputs) {
315 for (
const string& input : inputs) {
316 if (blob_creator.count(input) == 0) {
317 VLOG(1) <<
"Input " << input <<
" not produced by this net. " 318 <<
"Assuming it is pre-existing.";
320 int parent = blob_creator[input];
321 VLOG(1) <<
"op dependency (RaW " << input <<
"): " << parent
323 operator_nodes[idx].parents_.push_back(parent);
324 operator_nodes[parent].children_.push_back(idx);
327 blob_readers[input].insert(idx);
// Data inputs and control inputs both create RaW dependencies.
330 checkInputs(op_def.input());
331 checkInputs(op_def.control_input());
// Outputs: WaW edge from the previous writer, WaR edges from recorded readers.
334 for (
const string& output : op_def.output()) {
335 if (blob_creator.count(output) != 0) {
338 int waw_parent = blob_creator[output];
339 VLOG(1) <<
"op dependency (WaW " << output <<
"): " << waw_parent
341 operator_nodes[idx].parents_.push_back(waw_parent);
342 operator_nodes[waw_parent].children_.push_back(idx);
// Every op recorded as reading this blob must run before this op overwrites it.
346 for (
const int war_parent : blob_readers[output]) {
347 VLOG(1) <<
"op dependency (WaR " << output <<
"): " << war_parent
349 operator_nodes[idx].parents_.push_back(war_parent);
350 operator_nodes[war_parent].children_.push_back(idx);
// This op is now the latest writer of `output`.
353 blob_creator[output] = idx;
// Readers of the old value are already ordered before this write.
358 blob_readers[output].clear();
// Normalize adjacency lists: sort, drop duplicates, and drop self-edges.
364 for (
int i = 0; i < operator_nodes.size(); ++i) {
365 auto& node = operator_nodes[i];
367 auto& p = node.parents_;
368 std::sort(p.begin(), p.end());
369 p.erase(std::unique(p.begin(), p.end()), p.end());
370 p.erase(std::remove(p.begin(), p.end(), i), p.end());
372 auto& c = node.children_;
373 std::sort(c.begin(), c.end());
374 c.erase(std::unique(c.begin(), c.end()), c.end());
375 c.erase(std::remove(c.begin(), c.end(), i), c.end());
378 return operator_nodes;
// prepareChainGraphNodes: collapses the operator graph into a graph whose
// nodes are execution chains.
//
// Parameters:
//   operator_nodes   - operator graph (parents_/children_ by operator index).
//   execution_chains - execution_chains[c] lists the operator indices of
//                      chain c. No operator may appear in two chains, and
//                      every operator (including every referenced parent
//                      and child) must appear in some chain; both are
//                      enforced.
//
// An edge between operators in different chains becomes an edge between those
// chains, and each chain-level child/parent is recorded at most once. Edges
// between operators of the same chain are ignored.
381 std::vector<OpGraphNode> prepareChainGraphNodes(
382 const std::vector<dag_utils::OperatorNode>& operator_nodes,
383 const std::vector<std::vector<int>>& execution_chains) {
// Operator index -> index of the chain containing it.
384 std::unordered_map<int, int> op_to_chain_idx;
385 for (
int chain_idx = 0; chain_idx < execution_chains.size(); ++chain_idx) {
386 const auto& chain_indices = execution_chains[chain_idx];
387 for (
const auto& chain_op_idx : chain_indices) {
388 CAFFE_ENFORCE(!op_to_chain_idx.count(chain_op_idx));
389 op_to_chain_idx[chain_op_idx] = chain_idx;
// One graph node per chain.
393 std::vector<OpGraphNode> chain_nodes(execution_chains.size());
394 for (
int op_idx = 0; op_idx < operator_nodes.size(); ++op_idx) {
// Every operator must belong to some chain.
395 CAFFE_ENFORCE(op_to_chain_idx.count(op_idx));
396 auto chain_idx = op_to_chain_idx[op_idx];
397 auto& chain = chain_nodes[chain_idx];
398 auto& op_node = operator_nodes[op_idx];
// Child operators in other chains become chain-level children (no duplicates).
400 for (
const auto& child_idx : op_node.children_) {
401 CAFFE_ENFORCE(op_to_chain_idx.count(child_idx));
402 auto child_chain_idx = op_to_chain_idx[child_idx];
403 if (child_chain_idx != chain_idx) {
405 chain.children_.begin(), chain.children_.end(), child_chain_idx);
406 if (it == chain.children_.end()) {
407 chain.children_.push_back(child_chain_idx);
// Parent operators in other chains become chain-level parents (no duplicates).
412 for (
const auto& parent_idx : op_node.parents_) {
413 CAFFE_ENFORCE(op_to_chain_idx.count(parent_idx));
414 auto parent_chain_idx = op_to_chain_idx[parent_idx];
415 if (parent_chain_idx != chain_idx) {
417 chain.parents_.begin(), chain.parents_.end(), parent_chain_idx);
418 if (it == chain.parents_.end()) {
419 chain.parents_.push_back(parent_chain_idx);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...