Caffe2 - C++ API
A deep learning, cross-platform ML framework
OperatorFusion.cc
1 #include "nomnigraph/Transformations/OperatorFusion.h"
2 
3 #include "nomnigraph/Support/Casting.h"
4 #include "nomnigraph/Support/Pointer.h"
5 
6 namespace nom {
7 namespace transformations {
8 
9 bool fuseConvRelu(repr::NNGraph *g) {
10  for (auto node : g->getMutableNodes()) {
11  // Skip non-operators (tensors and supplementary nodes).
12  if (!isa<repr::NeuralNetOperator>(node->data())) {
13  continue;
14  }
15 
16  // Conv check.
17  if (!isa<repr::Conv>(
18  dyn_cast<repr::NeuralNetOperator>(node->data().get()))) {
19  continue;
20  }
21 
22  // Single output (somewhat redundant).
23  if (node->getOutEdges().size() != 1) {
24  continue;
25  }
26 
27  // Single user check.
28  auto *tensorNode = node->getOutEdges()[0]->head();
29  if (tensorNode->getOutEdges().size() != 1) {
30  continue;
31  }
32 
33  // Followed by Relu check.
34  auto *nextNode = tensorNode->getOutEdges()[0]->head();
35  if (!isa<repr::Relu>(
36  dyn_cast<repr::NeuralNetOperator>(nextNode->data().get()))) {
37  continue;
38  }
39 
40  // Now we do the swap.
41  auto *convNode = node;
42  auto *reluNode = nextNode;
43 
44  // TODO make this a little safer, static_cast is messy.
45  auto conv = static_cast<repr::Conv *>(convNode->mutableData()->release());
46 
47  // Seize ownership of the conv node's data
48  auto *convReluNode =
49  g->createNode(util::make_unique<repr::ConvRelu>(std::move(conv)));
50 
51  for (const auto &inEdge : convNode->getInEdges()) {
52  auto *parent = inEdge->tail();
53  g->createEdge(parent, convReluNode);
54  }
55  for (const auto &outEdge : reluNode->getOutEdges()) {
56  auto *child = outEdge->head();
57  g->createEdge(convReluNode, child);
58  }
59 
60  g->deleteNode(convNode);
61  g->deleteNode(tensorNode);
62  g->deleteNode(reluNode);
63 
64  return true;
65  }
66  return false;
67 }
68 
69 } // namespace transformations
70 } // namespace nom
Definition: Caffe2.cc:16