1 #include "nomnigraph/Transformations/ConnectNet.h" 4 namespace transformations {
8 const std::string getDeviceFromNode(
const NNGraph::NodeRef &node) {
9 auto nnOp = nn::get<NeuralNetOperator>(node);
10 const Annotation *annotation = nnOp->getAnnotation();
11 if (annotation && isa<DeviceAnnotation>(annotation)) {
12 auto device_annotation = dyn_cast<DeviceAnnotation>(annotation);
13 return device_annotation->getDevice();
18 bool connectNet(NNGraph *g) {
20 for (
auto tensor_node_pair : nn::dataIterator<NeuralNetData>(*g)) {
22 NNGraph::NodeRef tensorNode;
23 NeuralNetData* tensor;
24 std::tie(tensor, tensorNode) = tensor_node_pair;
28 if (!nn::hasProducer(tensorNode)) {
continue; }
30 auto producerDevice = getDeviceFromNode(nn::getProducer(tensorNode));
31 for (
auto& consumerNode : nn::getConsumers(tensorNode)) {
33 auto consumerDevice = getDeviceFromNode(consumerNode);
34 if (consumerDevice == producerDevice) {
continue; }
36 auto sendNode = g->createNode(util::make_unique<Send>());
37 g->createEdge(tensorNode, sendNode);
39 auto sendTensorNode = g->createNode(
40 util::make_unique<Tensor>(tensor->getName() +
"_send"));
41 g->createEdge(sendNode, sendTensorNode);
43 auto recvNode = g->createNode(util::make_unique<Receive>());
44 g->createEdge(sendTensorNode, recvNode);
46 auto recvTensorNode = g->createNode(
47 util::make_unique<Tensor>(tensor->getName() +
"_recv"));
48 g->createEdge(recvNode, recvTensorNode);
50 g->createEdge(recvTensorNode, consumerNode);
53 g->deleteEdge(g->getEdge(tensorNode, consumerNode));