1 #include "nomnigraph/Transformations/OperatorFusion.h" 3 #include "nomnigraph/Support/Casting.h" 4 #include "nomnigraph/Support/Pointer.h" 7 namespace transformations {
9 bool fuseConvRelu(repr::NNGraph *g) {
10 for (
auto node : g->getMutableNodes()) {
12 if (!isa<repr::NeuralNetOperator>(node->data())) {
18 dyn_cast<repr::NeuralNetOperator>(node->data().get()))) {
23 if (node->getOutEdges().size() != 1) {
28 auto *tensorNode = node->getOutEdges()[0]->head();
29 if (tensorNode->getOutEdges().size() != 1) {
34 auto *nextNode = tensorNode->getOutEdges()[0]->head();
36 dyn_cast<repr::NeuralNetOperator>(nextNode->data().get()))) {
41 auto *convNode = node;
42 auto *reluNode = nextNode;
45 auto conv =
static_cast<repr::Conv *
>(convNode->mutableData()->release());
49 g->createNode(util::make_unique<repr::ConvRelu>(std::move(conv)));
51 for (
const auto &inEdge : convNode->getInEdges()) {
52 auto *parent = inEdge->tail();
53 g->createEdge(parent, convReluNode);
55 for (
const auto &outEdge : reluNode->getOutEdges()) {
56 auto *child = outEdge->head();
57 g->createEdge(convReluNode, child);
60 g->deleteNode(convNode);
61 g->deleteNode(tensorNode);
62 g->deleteNode(reluNode);