1 #include "nomnigraph/Representations/NeuralNet.h" 6 NeuralNetOperator::~NeuralNetOperator() {}
// Returns this operator's name string, selected by NNKind case labels.
// For NNKind::GenericOperator the name is delegated to
// GenericOperator::getName() via dyn_cast.
// NOTE(review): this fragment is incomplete. The `case` labels have no visible
// `switch` header, the embedded source line numbers skip (8 -> 22 -> 24 ...),
// and no default case or closing braces appear before the next definition.
// As shown, ConvRelu would fall through and return "DynamicInput"; restore
// the missing lines from the full source before relying on this.
// NOTE(review): returning `const std::string` by value blocks moving from the
// result (e.g. `s = op.getName();` copies); changing it needs a matching
// header change.
8 const std::string NeuralNetOperator::getName()
const {
22 case NNKind::ConvRelu:
24 case NNKind::DynamicInput:
25 return "DynamicInput";
// Generic operators carry their own name; delegate to the subclass.
26 case NNKind::GenericOperator:
27 return dyn_cast<GenericOperator>(
this)->getName();
33 NeuralNetData::~NeuralNetData() {}
// Returns this data node's name string, selected by NNDataKind case labels;
// for NNDataKind::Tensor the name is delegated to Tensor::getName() via
// dyn_cast.
// NOTE(review): this fragment is incomplete. The `case` label has no visible
// `switch` header (embedded source line 36 is missing), and the closing
// braces and any other cases/default are absent before the next definition.
// Restore from the full source before relying on it.
// NOTE(review): returning `const std::string` by value blocks moving from the
// result; changing it needs a matching header change.
35 const std::string NeuralNetData::getName()
const {
37 case NNDataKind::Tensor: {
38 return dyn_cast<
Tensor>(
this)->getName();
47 bool hasProducer(NNGraph::NodeRef n) {
48 return n->getInEdges().size() != 0;
51 NNGraph::NodeRef getProducer(NNGraph::NodeRef n) {
52 assert(is<NeuralNetData>(n) &&
"getProducer only works with NeuralNetData types.");
53 auto inEdges = n->getInEdges();
54 assert(inEdges.size() > 0 &&
"Tensor does not have a producer.");
55 assert(inEdges.size() == 1 &&
"Malformed NNGraph, NeuralNetData has multiple producers.");
56 return inEdges.front()->tail();
59 std::vector<NNGraph::NodeRef> getConsumers(NNGraph::NodeRef n) {
60 assert(is<NeuralNetData>(n) &&
"getProducer only works with NeuralNetData types.");
61 std::vector<NNGraph::NodeRef> out;
62 for (
auto outEdge : n->getOutEdges()) {
63 out.emplace_back(outEdge->head());
68 bool hasInputs(NNGraph::NodeRef n) {
69 return n->getInEdges().size() != 0;
72 std::vector<NNGraph::NodeRef> getInputs(NNGraph::NodeRef n) {
73 assert(is<NeuralNetOperator>(n) &&
"getInputs only works with NeuralNetOperator types.");
74 std::vector<NNGraph::NodeRef> out;
75 for (
auto inEdge : n->getInEdges()) {
76 out.emplace_back(inEdge->tail());
81 std::vector<NNGraph::NodeRef> getOutputs(NNGraph::NodeRef n) {
82 assert(is<NeuralNetOperator>(n) &&
"getOutputs only works with NeuralNetOperator types.");
83 std::vector<NNGraph::NodeRef> out;
84 for (
auto outEdge : n->getOutEdges()) {
85 out.emplace_back(outEdge->head());
90 size_t coalesceInsertedDataDependenciesHelper(repr::NNModule* m) {
92 std::unordered_set<repr::NNGraph::NodeRef> cfTrackedNodes;
93 for (
const auto &bbNode : m->controlFlow.getMutableNodes()) {
94 auto bb = repr::nn::get<repr::BasicBlockType<repr::NNGraph>>(bbNode);
95 for (
const auto node : bb->getInstructions()) {
96 cfTrackedNodes.insert(node);
100 for (
auto &bbNode : m->controlFlow.getMutableNodes()) {
101 auto bb = repr::nn::get<repr::BasicBlockType<repr::NNGraph>>(bbNode);
104 auto instrsCopy = bb->getInstructions();
105 for (
const auto instr : instrsCopy) {
106 for (
const auto input : repr::nn::getInputs(instr)) {
107 if (!repr::nn::hasProducer(input)) {
continue; }
108 auto producer = repr::nn::getProducer(input);
109 if (!cfTrackedNodes.count(producer)) {
110 bb->insertInstructionBefore(producer, instr);
111 cfTrackedNodes.insert(producer);
117 return cfTrackedNodes.size();
122 void coalesceInsertedDataDependencies(repr::NNModule* m) {
127 newSize = coalesceInsertedDataDependenciesHelper(m);
128 }
while (newSize != oldSize);