Caffe2 - C++ API
A deep learning, cross-platform ML framework
test.cc
1 #include "nomnigraph/Graph/Algorithms.h"
2 #include "nomnigraph/Graph/Graph.h"
3 
4 #include "nomnigraph/Converters/Caffe2.h"
5 #include "nomnigraph/Converters/Dot.h"
6 
7 #include "nomnigraph/Transformations/ConnectNet.h"
8 #include "nomnigraph/Transformations/OperatorFusion.h"
9 #include "nomnigraph/Transformations/Match.h"
10 
11 #include "nomnigraph/Support/Casting.h"
12 
13 #include <fstream>
14 #include <iomanip>
15 #include <stdio.h>
16 
// Appends a new argument to operator `_op` (a pointer exposing `add_arg()`),
// names it `_name`, and stores `_val` through the setter `set_<_type>()`,
// e.g. ADD_ARG(def, "kernel", i, 3) calls `set_i(3)`.
//
// The body is wrapped in do { } while (0) so that `ADD_ARG(...);` is a single
// statement and stays valid inside an unbraced if/else. `_op` is parenthesized
// so that arbitrary pointer expressions can be passed.
#define ADD_ARG(_op, _name, _type, _val)          \
  do {                                            \
    auto *nom_test_arg_ = (_op)->add_arg();       \
    nom_test_arg_->set_name(_name);               \
    nom_test_arg_->set_##_type(_val);             \
  } while (0)
23 
// Empty payload type used as node data for the plain nom::Graph tests below.
struct TestClass {
  TestClass() = default;
  ~TestClass() = default;
};
29 
30 struct NNEquality {
31  static bool equal(
32  const typename nom::repr::NNGraph::NodeRef& a,
33  const typename nom::repr::NNGraph::NodeRef& b) {
34  if (
35  !nom::repr::nn::is<nom::repr::NeuralNetOperator>(a) ||
36  !nom::repr::nn::is<nom::repr::NeuralNetOperator>(b)) {
37  return false;
38  }
39  auto a_ = nom::repr::nn::get<nom::repr::NeuralNetOperator>(a);
40  auto b_ = nom::repr::nn::get<nom::repr::NeuralNetOperator>(b);
41 
42  bool sameKind = a_->getKind() == b_->getKind();
43  if (sameKind && a_->getKind() == nom::repr::NeuralNetOperator::NNKind::GenericOperator) {
44  return a_->getName() == b_->getName();
45  }
46  return sameKind;
47  }
48 };
49 
50 
51 auto bbprinter = [](typename nom::repr::NNCFGraph::NodeRef node) {
52  std::map<std::string, std::string> labelMap;
53  assert(node->data() && "Node doesn't have data, can't render it");
54  auto *bb = dyn_cast<nom::repr::BasicBlockType<nom::repr::NNGraph>>(
55  node->data().get());
56  labelMap["label"] = std::to_string((unsigned long long)node) + "\\n";
57  for (const auto &instr : bb->getInstructions()) {
58  assert(isa<nom::repr::NeuralNetOperator>(instr->data()) &&
59  "Invalid instruction.");
60  auto *op = dyn_cast<nom::repr::NeuralNetOperator>(instr->data().get());
61  bool hasOutput = false;
62  for (const auto &outEdge : instr->getOutEdges()) {
63  auto *output =
64  dyn_cast<nom::repr::NeuralNetData>(outEdge->head()->data().get());
65  labelMap["label"] += " " + output->getName();
66  hasOutput = true;
67  }
68  if (hasOutput) {
69  labelMap["label"] += " = ";
70  }
71  labelMap["label"] += op->getName();
72  for (const auto &inEdge : instr->getInEdges()) {
73  auto *arg =
74  dyn_cast<nom::repr::NeuralNetData>(inEdge->tail()->data().get());
75  labelMap["label"] += " " + arg->getName();
76  }
77  labelMap["label"] += "\\l";
78  }
79  labelMap["shape"] = "box";
80  return labelMap;
81 };
82 
83 auto cfgedgeprinter = [](typename nom::repr::NNCFGraph::EdgeRef edge) {
84  std::map<std::string, std::string> labelMap;
85  if (edge->data() == -1) {
86  labelMap["label"] = "F";
87  } else if (edge->data() == 1) {
88  labelMap["label"] = "T";
89  }
90  return labelMap;
91 };
92 
93 auto nnprinter = [](typename nom::repr::NNGraph::NodeRef node) {
94  std::map<std::string, std::string> labelMap;
95  assert(node->data() && "Node doesn't have data, can't render it");
96  if (isa<nom::repr::NeuralNetOperator>(node->data())) {
97  auto *op = dyn_cast<nom::repr::NeuralNetOperator>(node->data().get());
98  labelMap["label"] =
99  op->getName() + " (" + std::to_string((unsigned long long)node) + ")";
100  auto *annotation = op->getAnnotation();
101  if (annotation && isa<nom::repr::DeviceAnnotation>(annotation)) {
102  auto device_annotation =
103  dyn_cast<nom::repr::DeviceAnnotation>(annotation);
104  labelMap["label"] += "\\n[" + device_annotation->getDevice() + "]";
105  auto hash = std::hash<std::string>{}(device_annotation->getDevice());
106  std::stringstream hex_stream;
107  hex_stream << std::hex << hash;
108  labelMap["color"] = "#" + hex_stream.str().substr(0, 6);
109  labelMap["fontcolor"] = labelMap["color"];
110  }
111  labelMap["shape"] = "box";
112  } else if (isa<nom::repr::Data>(node->data())) {
113  auto tensor = dyn_cast<nom::repr::NeuralNetData>(node->data().get());
114  labelMap["label"] = tensor->getName();
115  labelMap["label"] += "_" + std::to_string(tensor->getVersion()) + " " + std::to_string((unsigned long long)node);
116  }
117  return labelMap;
118 };
119 
120 int main(int argc, char *argv[]) {
121  {
122  TestClass t1;
123  TestClass t2;
125  nom::Graph<TestClass>::NodeRef n1 = g.createNode(std::move(t1));
126  nom::Graph<TestClass>::NodeRef n2 = g.createNode(std::move(t2));
127  g.createEdge(n1, n2);
128  }
129 
130  {
131  TestClass t1;
132  TestClass t2;
134  nom::Graph<TestClass>::NodeRef n1 = g.createNode(std::move(t1));
135  nom::Graph<TestClass>::NodeRef n2 = g.createNode(std::move(t2));
136  g.createEdge(n1, n2);
137  g.deleteNode(n1);
138  }
139 
140  {
141  TestClass t1;
142  TestClass t2;
144  nom::Graph<TestClass>::NodeRef n1 = g.createNode(std::move(t1));
145  nom::Graph<TestClass>::NodeRef n2 = g.createNode(std::move(t2));
147  g.deleteEdge(e);
148  }
149 
150  {
151  TestClass t1;
152  TestClass t2;
154  nom::Graph<TestClass, int>::NodeRef n1 = g.createNode(std::move(t1));
155  nom::Graph<TestClass, int>::NodeRef n2 = g.createNode(std::move(t2));
156  g.createEdge(n1, n2);
157  g.createEdge(n2, n1);
158  auto tarjans = nom::algorithm::Tarjans<TestClass, int>(&g);
159  auto sccs = tarjans.run();
160  }
161 
162  {
164  std::vector<nom::Graph<TestClass, int>::NodeRef> nodes;
165  for (auto i = 0; i < 10; ++i) {
166  TestClass t;
167  nodes.emplace_back(g.createNode(std::move(t)));
168  }
169  for (auto i = 0; i < 30; ++i) {
170  int ri1 = rand() % nodes.size();
171  int ri2 = rand() % nodes.size();
172  g.createEdge(nodes[ri1], nodes[ri2]);
173  }
174 
175  auto tarjans = nom::algorithm::Tarjans<TestClass, int>(&g);
176  auto sccs = tarjans.run();
177  }
178 
179  {
180  caffe2::NetDef net;
181  for (auto i = 0; i < 10; ++i) {
182  if (rand() % 2) {
183  caffe2::OperatorDef *def = net.add_op();
184  def->set_type("Conv");
185  def->add_input("X");
186  def->add_input("W" + std::to_string(i)); // different weights
187  ADD_ARG(def, "kernel", i, 3);
188  ADD_ARG(def, "stride", i, 1);
189  ADD_ARG(def, "pad", i, 0);
190  ADD_ARG(def, "order", s, "NCHW");
191  def->add_output("X");
192  def->mutable_device_option()->set_node_name("conv_runner");
193  } else {
194  caffe2::OperatorDef *def = net.add_op();
195  def->set_type("Relu");
196  def->add_input("X");
197  def->add_output("X");
198  def->mutable_device_option()->set_node_name("relu_runner");
199  }
200  }
201  auto nn = nom::converters::convertFromCaffe2Proto(net);
202  nom::repr::NNGraph g = std::move(nn.dataFlow);
203  nom::repr::NNCFGraph cfg = std::move(nn.controlFlow);
204 
205  std::ofstream out("unfusedNet.dot");
206  out << nom::converters::convertToDotString(&g, nnprinter);
207  out.close();
208 
209  while (nom::transformations::fuseConvRelu(&g))
210  ;
211 
212  std::ofstream out2("fusedNet.dot");
213  out2 << nom::converters::convertToDotString(&g, nnprinter);
214  out2.close();
215  }
216  {
217  caffe2::NetDef net;
218  for (auto i = 0; i < 10; ++i) {
219  if (i % 2) {
220  caffe2::OperatorDef *def = net.add_op();
221  def->set_type("Conv");
222  def->add_input("X" + std::to_string(i));
223  def->add_input("W" + std::to_string(i)); // different weights
224  def->add_input("b" + std::to_string(i)); // different biases
225  ADD_ARG(def, "kernel", i, 3);
226  ADD_ARG(def, "stride", i, 1);
227  ADD_ARG(def, "pad", i, 0);
228  ADD_ARG(def, "order", s, "NCHW");
229  def->add_output("X" + std::to_string(i+1));
230  def->mutable_device_option()->set_node_name("device_" +
231  std::to_string(rand() % 2));
232  } else {
233  caffe2::OperatorDef *def = net.add_op();
234  def->set_type("Relu");
235  def->add_input("X" + std::to_string(i));
236  def->add_output("X" + std::to_string(i+1));
237  def->mutable_device_option()->set_node_name("device_" +
238  std::to_string(rand() % 2));
239  }
240  }
241  auto nn = nom::converters::convertFromCaffe2Proto(net);
242 
243  std::string dot1 = nom::converters::convertToDotString(&nn.dataFlow, nnprinter);
244  std::ofstream out1("disconnectedNet.dot");
245  out1 << dot1;
246  out1.close();
247 
248  assert(nom::transformations::connectNet(&nn.dataFlow));
249  nom::repr::nn::coalesceInsertedDataDependencies(&nn);
250  {
251  std::string dot = nom::converters::convertToDotString(&nn.dataFlow, nnprinter);
252  std::ofstream out("connectedNet.dot");
253  out << dot;
254  out.close();
255  }
256  {
257  std::string dot = nom::converters::convertToDotString(&nn.controlFlow, bbprinter);
258  std::ofstream out("connectedNet_cfg.dot");
259  out << dot;
260  out.close();
261  }
262  }
263  {
264  caffe2::NetDef net;
265 
266  caffe2::OperatorDef *def = net.add_op();
267  def->set_type("NeverSeen");
268  def->add_input("X");
269  def->add_output("X");
270  def->mutable_device_option()->set_node_name("device_" +
271  std::to_string(rand() % 2));
272  auto nn = nom::converters::convertFromCaffe2Proto(net);
273 
274  auto dot_str =
275  nom::converters::convertToDotString(&nn.dataFlow, nnprinter).c_str();
276  auto new_netdef = nom::converters::convertToCaffe2Proto(nn);
277  }
278 
279  {
281  std::vector<nom::Graph<TestClass, int>::NodeRef> nodes;
282  for (auto i = 0; i < 100; ++i) {
283  TestClass t;
284  nodes.emplace_back(g.createNode(std::move(t)));
285  }
286  for (auto i = 0; i < 200; ++i) {
287  int ri1 = rand() % nodes.size();
288  int ri2 = rand() % nodes.size();
289  g.createEdge(nodes[ri1], nodes[ri2]);
290  }
291 
292  auto sccs = nom::algorithm::tarjans(&g);
293 
294  std::string dot = nom::converters::convertToDotString(
295  &g, sccs, [](typename nom::Graph<TestClass, int>::NodeRef node) {
296  std::map<std::string, std::string> labelMap;
297  labelMap["label"] = std::to_string((unsigned long long)node);
298  return labelMap;
299  });
300 
301  std::ofstream out("sccs.dot");
302  out << dot;
303  out.close();
304  }
305 
306  {
307  caffe2::NetDef net;
308 
309  caffe2::OperatorDef *def = net.add_op();
310  def->set_type("While");
311  def->add_input("X");
312 
313  caffe2::NetDef body_net;
314  {
315  caffe2::OperatorDef *rdef = body_net.add_op();
316  rdef->set_type("Relu");
317  rdef->add_input("X");
318  rdef->add_output("X");
319  }
320  std::string body_net_serialized;
321  assert(body_net.SerializeToString(&body_net_serialized));
322  ADD_ARG(def, "body", s, body_net_serialized);
323 
324  auto nn = nom::converters::convertFromCaffe2Proto(net);
325  nom::repr::NNGraph g = std::move(nn.dataFlow);
326  nom::repr::NNCFGraph cfg = std::move(nn.controlFlow);
327  auto dot = nom::converters::convertToDotString(&g, nnprinter);
328  std::ofstream out("while.dot");
329  out << dot;
330  out.close();
331  }
332  {
333  caffe2::NetDef net;
334 
335  {
336  caffe2::OperatorDef *rdef = net.add_op();
337  rdef->set_type("Relu");
338  rdef->add_input("X");
339  rdef->add_output("X");
340  }
341 
342  caffe2::OperatorDef *def = net.add_op();
343  def->set_type("While");
344  def->add_input("X");
345 
346  caffe2::NetDef body_net;
347  {
348  caffe2::OperatorDef *rdef = body_net.add_op();
349  rdef->set_type("Instr1");
350  rdef->add_input("X");
351  rdef->add_output("X");
352  }
353  {
354  caffe2::OperatorDef *rdef = body_net.add_op();
355  rdef->set_type("Instr2");
356  rdef->add_input("X");
357  rdef->add_output("X");
358  }
359  {
360  caffe2::OperatorDef *rdef = body_net.add_op();
361  rdef->set_type("Instr3");
362  rdef->add_input("X");
363  rdef->add_output("X");
364  }
365  std::string body_net_serialized;
366  assert(body_net.SerializeToString(&body_net_serialized));
367  ADD_ARG(def, "body", s, body_net_serialized);
368 
369  auto nn = nom::converters::convertFromCaffe2Proto(net);
370  nom::repr::NNGraph g = std::move(nn.dataFlow);
371  nom::repr::NNCFGraph cfg = std::move(nn.controlFlow);
372 
373  }
374  do {
375  if (argc < 2) {
376  printf("Try out ./nomnigraph_test tests/distrib_ads_trainer.pb\n");
377  break;
378  }
379  caffe2::NetDef net;
380  std::fstream input(argv[1]);
381  std::string s(std::istreambuf_iterator<char>(input), {});
382  assert(net.ParseFromString(s) && "Couldn't parse network\n");
383 
384  auto nn = nom::converters::convertFromCaffe2Proto(net);
385  {
386  auto dot = nom::converters::convertToDotString(&nn.dataFlow, nnprinter);
387  std::ofstream out("in.dot");
388  out << dot;
389  out.close();
390  }
391  assert(nom::transformations::connectNet(&nn.dataFlow));
392 
393  {
394  auto dot = nom::converters::convertToDotString(&nn.dataFlow, nnprinter);
395  std::ofstream out("out.dot");
396  out << dot;
397  out.close();
398  }
399  } while (0);
400 
401  {
402  caffe2::NetDef net;
403 
404  {
405  caffe2::OperatorDef *rdef = net.add_op();
406  rdef->set_type("Relu");
407  rdef->add_input("X");
408  rdef->add_output("X");
409  }
410 
411  caffe2::OperatorDef *def = net.add_op();
412  def->set_type("While");
413  def->add_input("X");
414 
415  caffe2::NetDef body_net;
416  {
417  caffe2::OperatorDef *rdef = body_net.add_op();
418  rdef->set_type("Relu");
419  rdef->add_input("X");
420  rdef->add_output("X");
421  }
422  {
423  caffe2::OperatorDef *rdef = body_net.add_op();
424  rdef->set_type("Instr2");
425  rdef->add_input("X");
426  rdef->add_output("X");
427  }
428  {
429  caffe2::OperatorDef *rdef = body_net.add_op();
430  rdef->set_type("Instr3");
431  rdef->add_input("X");
432  rdef->add_output("X");
433  }
434  {
435  caffe2::OperatorDef *rdef = body_net.add_op();
436  rdef->set_type("Instr4");
437  rdef->add_input("X");
438  rdef->add_output("Y");
439  }
440  std::string body_net_serialized;
441  assert(body_net.SerializeToString(&body_net_serialized));
442  ADD_ARG(def, "body", s, body_net_serialized);
443 
444  auto nn = nom::converters::convertFromCaffe2Proto(net);
445 
446  auto sccs = nom::algorithm::tarjans(&nn.dataFlow);
447  auto cfgsccs = nom::algorithm::tarjans(&nn.controlFlow);
448  {
449  std::string dot =
450  nom::converters::convertToDotString(&nn.dataFlow, sccs, nnprinter);
451  std::ofstream out("while2.dot");
452  out << dot;
453  out.close();
454  }
455  {
456  std::string dot =
457  nom::converters::convertToDotString(&nn.controlFlow, cfgsccs, bbprinter);
458  std::ofstream out("while_cfg.dot");
459  out << dot;
460  out.close();
461  }
462  for (auto node : nn.controlFlow.getMutableNodes()) {
463  printf("node addr %llu\n", (unsigned long long)node);
464  }
465  auto domFrontMap = nom::algorithm::dominanceFrontierMap(&nn.controlFlow);
466  for (auto pair : domFrontMap) {
467  for (auto node : pair.second) {
468  printf("%llu - %llu\n", (unsigned long long)pair.first, (unsigned long long)node);
469  }
470  }
471  }
472  {
474  auto r = graph.createNode(std::string("r"));
475  auto a = graph.createNode(std::string("a"));
476  auto b = graph.createNode(std::string("b"));
477  auto c = graph.createNode(std::string("c"));
478  auto d = graph.createNode(std::string("d"));
479  auto e = graph.createNode(std::string("e"));
480  auto f = graph.createNode(std::string("f"));
481  auto g = graph.createNode(std::string("g"));
482  auto l = graph.createNode(std::string("l"));
483  auto h = graph.createNode(std::string("h"));
484  auto i = graph.createNode(std::string("i"));
485  auto j = graph.createNode(std::string("j"));
486  auto k = graph.createNode(std::string("k"));
487  graph.createEdge(r, a);
488  graph.createEdge(r, b);
489  graph.createEdge(r, c);
490  graph.createEdge(c, f);
491  graph.createEdge(c, g);
492  graph.createEdge(g, j);
493  graph.createEdge(g, i);
494  graph.createEdge(f, i);
495  graph.createEdge(i, k);
496  graph.createEdge(k, i);
497  graph.createEdge(k, r);
498  graph.createEdge(a, d);
499  graph.createEdge(b, d);
500  graph.createEdge(b, a);
501  graph.createEdge(b, e);
502  graph.createEdge(d, l);
503  graph.createEdge(l, h);
504  graph.createEdge(h, k);
505  graph.createEdge(h, e);
506  graph.createEdge(e, h);
507 
508  {
509  std::ofstream out("dominatorinput.dot");
510  out << nom::converters::convertToDotString(
511  &graph, [](nom::Graph<std::string>::NodeRef node) {
512  std::map<std::string, std::string> labelMap;
513  labelMap["label"] = node->data();
514  return labelMap;
515  });
516  out.close();
517  }
518 
519  auto tree = nom::algorithm::dominatorTree(&graph, r);
520  {
521  std::ofstream out("dominatoroutput.dot");
522  out << nom::converters::convertToDotString(
523  &tree,
524  [](nom::Graph<nom::Graph<std::string>::NodeRef, int>::NodeRef node) {
525  std::map<std::string, std::string> labelMap;
526  labelMap["label"] = node->data()->data();
527  return labelMap;
528  });
529  out.close();
530  }
531  auto map = nom::algorithm::immediateDominatorMap(&graph, r);
532  assert(map[j] == g);
533  assert(map[g] == c);
534  assert(map[f] == c);
535  assert(map[l] == d);
536  assert(map[a] == r);
537  assert(map[b] == r);
538  assert(map[c] == r);
539  assert(map[d] == r);
540  assert(map[e] == r);
541  assert(map[h] == r);
542  assert(map[i] == r);
543  assert(map[k] == r);
544  auto domFrontMap = nom::algorithm::dominanceFrontierMap(&graph, r);
545  }
546 
547  // https://www.seas.harvard.edu/courses/cs252/2011sp/slides/Lec04-SSA.pdf
548  // using example on page 24
549  {
551  auto entry = graph.createNode(std::string("entry"));
552  auto n1 = graph.createNode(std::string("1"));
553  auto n2 = graph.createNode(std::string("2"));
554  auto n3 = graph.createNode(std::string("3"));
555  auto n4 = graph.createNode(std::string("4"));
556  auto n5 = graph.createNode(std::string("5"));
557  auto n6 = graph.createNode(std::string("6"));
558  auto n7 = graph.createNode(std::string("7"));
559  auto exit = graph.createNode(std::string("exit"));
560  graph.createEdge(entry, n1);
561  graph.createEdge(n1, n2);
562  graph.createEdge(n1, n5);
563  graph.createEdge(n5, n1);
564  graph.createEdge(n2, n3);
565  graph.createEdge(n2, n4);
566  graph.createEdge(n3, n6);
567  graph.createEdge(n4, n6);
568  graph.createEdge(n6, n7);
569  graph.createEdge(n5, n7);
570  graph.createEdge(n7, exit);
571 
572  auto domFrontMap = nom::algorithm::dominanceFrontierMap(&graph, entry);
573  using noderef = nom::Graph<std::string>::NodeRef;
574  std::unordered_map<noderef, std::unordered_set<noderef>> checkMap = {
575  {n1, {n1}},
576  {n2, {n7}},
577  {n3, {n6}},
578  {n4, {n6}},
579  {n5, {n1, n7}},
580  {n6, {n7}}
581  };
582  for (auto pair : domFrontMap) {
583  assert(pair.second == checkMap[pair.first]);
584  }
585  }
586  // Test modifying the DFG without explicitly modifying the CFG
587  {
588  caffe2::NetDef net;
589  {
590  caffe2::OperatorDef *rdef = net.add_op();
591  rdef->set_type("Instr1");
592  rdef->add_input("X");
593  rdef->add_output("X");
594  }
595  {
596  caffe2::OperatorDef *rdef = net.add_op();
597  rdef->set_type("Instr2");
598  rdef->add_input("X");
599  rdef->add_output("X");
600  }
601  {
602  caffe2::OperatorDef *rdef = net.add_op();
603  rdef->set_type("Instr3");
604  rdef->add_input("X");
605  rdef->add_output("X");
606  }
607  auto nn = nom::converters::convertFromCaffe2Proto(net);
608 
609  {
610  auto dot = nom::converters::convertToDotString(&nn.controlFlow, bbprinter);
611  std::ofstream out("dfg_test_in.dot");
612  out << dot;
613  out.close();
614  }
615 
616  auto randomNode = nn.dataFlow.getMutableNodes()[0];
617  nn.dataFlow.deleteNode(randomNode);
618 
619  {
620  auto dot = nom::converters::convertToDotString(&nn.controlFlow, bbprinter);
621  std::ofstream out("dfg_test_out.dot");
622  out << dot;
623  out.close();
624  }
625 
626  }
627  {
629  auto entry = graph.createNode(std::string("entry"));
630  auto n1 = graph.createNode(std::string("1"));
631  auto n2 = graph.createNode(std::string("2"));
632  auto n3 = graph.createNode(std::string("3"));
633  auto n4 = graph.createNode(std::string("4"));
634  auto n5 = graph.createNode(std::string("5"));
635  auto n6 = graph.createNode(std::string("6"));
636  auto n7 = graph.createNode(std::string("7"));
637  auto exit = graph.createNode(std::string("exit"));
638  graph.createEdge(entry, n1);
639  graph.createEdge(n1, n2);
640  graph.createEdge(n1, n5);
641  graph.createEdge(n5, n1);
642  graph.createEdge(n2, n3);
643  graph.createEdge(n2, n4);
644  graph.createEdge(n3, n6);
645  graph.createEdge(n4, n6);
646  graph.createEdge(n6, n7);
647  graph.createEdge(n5, n7);
648  graph.createEdge(n7, exit);
649 
650  nom::Graph<std::string> match_graph;
651  auto m1 = match_graph.createNode(std::string("1"));
652  auto m2 = match_graph.createNode(std::string("2"));
653  match_graph.createEdge(m1, m2);
654 
655  nom::Match<decltype(graph)> m(match_graph);
656  assert(m.match(graph).size() == 1);
657  }
658 
659  {
660  caffe2::NetDef net;
661  {
662  caffe2::OperatorDef *rdef = net.add_op();
663  rdef->set_type("Instr1");
664  rdef->add_input("X");
665  rdef->add_output("X");
666  }
667  {
668  caffe2::OperatorDef *rdef = net.add_op();
669  rdef->set_type("Instr2");
670  rdef->add_input("X");
671  rdef->add_output("X");
672  }
673  {
674  caffe2::OperatorDef *rdef = net.add_op();
675  rdef->set_type("Instr3");
676  rdef->add_input("X");
677  rdef->add_output("X");
678  }
679  auto nn = nom::converters::convertFromCaffe2Proto(net);
680 
681  caffe2::NetDef matchnet;
682  {
683  caffe2::OperatorDef *rdef = matchnet.add_op();
684  rdef->set_type("Instr1");
685  }
686  auto matchnn = nom::converters::convertFromCaffe2Proto(matchnet);
687  nom::Match<decltype(nn.dataFlow), NNEquality> m(matchnn.dataFlow);
688  assert(m.match(nn.dataFlow).size() == 1);
689  }
690 
691  return 0;
692 }
NodeRef createNode(T &&data)
Creates a node and retains ownership of it.
Definition: Graph.h:164
void deleteNode(NodeRef n, bool deleteEdges=true)
Deletes a node from the graph.
Definition: Graph.h:264
void deleteEdge(EdgeRef e)
Deletes an edge from the graph.
Definition: Graph.h:283
A simple graph implementation.
Definition: Graph.h:30
EdgeRef createEdge(NodeRef tail, NodeRef head)
Creates a directed edge and retains ownership of it.
Definition: Graph.h:230