Caffe2 - C++ API
A deep learning, cross-platform ML framework
NeuralNet.cc
1 #include "nomnigraph/Representations/NeuralNet.h"
2 
3 namespace nom {
4 namespace repr {
5 
6 NeuralNetOperator::~NeuralNetOperator() {}
7 
8 const std::string NeuralNetOperator::getName() const {
9  switch (getKind()) {
10  case NNKind::Conv:
11  return "Conv";
12  case NNKind::Relu:
13  return "Relu";
14  case NNKind::Send:
15  return "Send";
16  case NNKind::Receive:
17  return "Receive";
18  case NNKind::While:
19  return "While";
20  case NNKind::NNPhi:
21  return "Phi";
22  case NNKind::ConvRelu:
23  return "ConvRelu";
24  case NNKind::DynamicInput:
25  return "DynamicInput";
26  case NNKind::GenericOperator:
27  return dyn_cast<GenericOperator>(this)->getName();
28  default:
29  return "Unknown";
30  }
31 }
32 
33 NeuralNetData::~NeuralNetData() {}
34 
35 const std::string NeuralNetData::getName() const {
36  switch (getKind()) {
37  case NNDataKind::Tensor: {
38  return dyn_cast<Tensor>(this)->getName();
39  }
40  default:
41  return "";
42  }
43 }
44 
45 namespace nn {
46 
47 bool hasProducer(NNGraph::NodeRef n) {
48  return n->getInEdges().size() != 0;
49 }
50 
51 NNGraph::NodeRef getProducer(NNGraph::NodeRef n) {
52  assert(is<NeuralNetData>(n) && "getProducer only works with NeuralNetData types.");
53  auto inEdges = n->getInEdges();
54  assert(inEdges.size() > 0 && "Tensor does not have a producer.");
55  assert(inEdges.size() == 1 && "Malformed NNGraph, NeuralNetData has multiple producers.");
56  return inEdges.front()->tail();
57 }
58 
59 std::vector<NNGraph::NodeRef> getConsumers(NNGraph::NodeRef n) {
60  assert(is<NeuralNetData>(n) && "getProducer only works with NeuralNetData types.");
61  std::vector<NNGraph::NodeRef> out;
62  for (auto outEdge : n->getOutEdges()) {
63  out.emplace_back(outEdge->head());
64  }
65  return out;
66 }
67 
68 bool hasInputs(NNGraph::NodeRef n) {
69  return n->getInEdges().size() != 0;
70 }
71 
72 std::vector<NNGraph::NodeRef> getInputs(NNGraph::NodeRef n) {
73  assert(is<NeuralNetOperator>(n) && "getInputs only works with NeuralNetOperator types.");
74  std::vector<NNGraph::NodeRef> out;
75  for (auto inEdge : n->getInEdges()) {
76  out.emplace_back(inEdge->tail());
77  }
78  return out;
79 }
80 
81 std::vector<NNGraph::NodeRef> getOutputs(NNGraph::NodeRef n) {
82  assert(is<NeuralNetOperator>(n) && "getOutputs only works with NeuralNetOperator types.");
83  std::vector<NNGraph::NodeRef> out;
84  for (auto outEdge : n->getOutEdges()) {
85  out.emplace_back(outEdge->head());
86  }
87  return out;
88 }
89 
90 size_t coalesceInsertedDataDependenciesHelper(repr::NNModule* m) {
91  // Get all nodes tracked by CF graph
92  std::unordered_set<repr::NNGraph::NodeRef> cfTrackedNodes;
93  for (const auto &bbNode : m->controlFlow.getMutableNodes()) {
94  auto bb = repr::nn::get<repr::BasicBlockType<repr::NNGraph>>(bbNode);
95  for (const auto node : bb->getInstructions()) {
96  cfTrackedNodes.insert(node);
97  }
98  }
99 
100  for (auto &bbNode : m->controlFlow.getMutableNodes()) {
101  auto bb = repr::nn::get<repr::BasicBlockType<repr::NNGraph>>(bbNode);
102  // We mutate the instructions of the bb, so we copy here.
103  // TODO make this an iterator and simply promote it on insertion.
104  auto instrsCopy = bb->getInstructions();
105  for (const auto instr : instrsCopy) {
106  for (const auto input : repr::nn::getInputs(instr)) {
107  if (!repr::nn::hasProducer(input)) { continue; }
108  auto producer = repr::nn::getProducer(input);
109  if (!cfTrackedNodes.count(producer)) {
110  bb->insertInstructionBefore(producer, instr);
111  cfTrackedNodes.insert(producer);
112  }
113  }
114  }
115  }
116 
117  return cfTrackedNodes.size();
118 }
119 
120 // TODO: move this to more generic location.
121 // TODO: [algo] improve this algorithm, as it is horrendously inefficient.
122 void coalesceInsertedDataDependencies(repr::NNModule* m) {
123  size_t oldSize = 0;
124  size_t newSize = 0;
125  do {
126  oldSize = newSize;
127  newSize = coalesceInsertedDataDependenciesHelper(m);
128  } while (newSize != oldSize);
129 }
130 
131 } // namespace nn
132 
133 } // namespace repr
134 } // namespace nom
Definition: Caffe2.cc:16