Caffe2 - C++ API
A deep learning, cross platform ML framework
ConnectNet.cc
#include "nomnigraph/Transformations/ConnectNet.h"

#include <string>
#include <unordered_map>
2 
3 namespace nom {
4 namespace transformations {
5 
6 using namespace repr;
7 
8 const std::string getDeviceFromNode(const NNGraph::NodeRef &node) {
9  auto nnOp = nn::get<NeuralNetOperator>(node);
10  const Annotation *annotation = nnOp->getAnnotation();
11  if (annotation && isa<DeviceAnnotation>(annotation)) {
12  auto device_annotation = dyn_cast<DeviceAnnotation>(annotation);
13  return device_annotation->getDevice();
14  }
15  return "";
16 }
17 
18 bool connectNet(NNGraph *g) {
19  // Iterate through all the tensors in the graph.
20  for (auto tensor_node_pair : nn::dataIterator<NeuralNetData>(*g)) {
21 
22  NNGraph::NodeRef tensorNode;
23  NeuralNetData* tensor;
24  std::tie(tensor, tensorNode) = tensor_node_pair;
25 
26  // This is an edge case for when a tensor is created from outside
27  // the execution graph.
28  if (!nn::hasProducer(tensorNode)) { continue; }
29 
30  auto producerDevice = getDeviceFromNode(nn::getProducer(tensorNode));
31  for (auto& consumerNode : nn::getConsumers(tensorNode)) {
32 
33  auto consumerDevice = getDeviceFromNode(consumerNode);
34  if (consumerDevice == producerDevice) { continue; }
35 
36  auto sendNode = g->createNode(util::make_unique<Send>());
37  g->createEdge(tensorNode, sendNode);
38 
39  auto sendTensorNode = g->createNode(
40  util::make_unique<Tensor>(tensor->getName() + "_send"));
41  g->createEdge(sendNode, sendTensorNode);
42 
43  auto recvNode = g->createNode(util::make_unique<Receive>());
44  g->createEdge(sendTensorNode, recvNode);
45 
46  auto recvTensorNode = g->createNode(
47  util::make_unique<Tensor>(tensor->getName() + "_recv"));
48  g->createEdge(recvNode, recvTensorNode);
49 
50  g->createEdge(recvTensorNode, consumerNode);
51 
52  // This is safe because we copied the edge list.
53  g->deleteEdge(g->getEdge(tensorNode, consumerNode));
54  }
55  }
56  return true;
57 }
58 
59 } // namespace transformations
60 } // namespace nom
Definition: Caffe2.cc:16