Caffe2 - C++ API
A deep-learning, cross-platform ML framework
Match.h
1 //=== nomnigraph/Transformations/Match.h - Graph matching utils -*- C++ -*-===//
2 //
3 // TODO Licensing.
4 //
5 //===----------------------------------------------------------------------===//
6 //
7 // This file defines utilities for matching subgraphs.
8 //
9 //===----------------------------------------------------------------------===//
10 
11 #ifndef NOM_TRANFORMATIONS_MATCH_H
12 #define NOM_TRANFORMATIONS_MATCH_H
13 
14 #include "nomnigraph/Graph/Algorithms.h"
15 
16 #include <algorithm>
17 #include <vector>
18 
19 namespace nom {
20 
/// Default node-equality policy used by Match.
///
/// Two nodes are considered equal when the values returned by their
/// data() accessors compare equal with operator==.
///
/// @tparam T Pointer-like node handle (e.g. a graph's NodeRef) that supports
///           `a->data()`; the returned type must be equality-comparable.
template <typename T>
struct NodeEqualityDefault {
  /// @param a First node handle to compare.
  /// @param b Second node handle to compare.
  /// @return true iff `a->data() == b->data()`.
  static bool equal(const T& a, const T& b) {
    return a->data() == b->data();
  }
};
27 
28 template <typename G, typename EqualityClass = NodeEqualityDefault<typename G::NodeRef>>
29 class Match {
30 public:
32 
33  Match(G& g) : MatchGraph(g) {
34  // First we sort both the matching graph topologically.
35  // This could give us a useful anchor in the best case.
36  auto topoMatch = nom::algorithm::tarjans(&MatchGraph);
37  for (auto scc : topoMatch) {
38  for (auto node : scc.getNodes()) {
39  MatchNodeList.emplace_back(node);
40  }
41  }
42  std::reverse(MatchNodeList.begin(), MatchNodeList.end());
43  }
44 
45  std::vector<SubgraphType> recursiveMatch(typename G::NodeRef candidateNode, std::vector<typename G::NodeRef> stack, SubgraphType currentSubgraph) {
46  if (EqualityClass::equal(stack.back(), candidateNode)) {
47  currentSubgraph.addNode(candidateNode);
48 
49  // Base case
50  if (stack.size() == MatchNodeList.size()) {
51  return std::vector<SubgraphType>{currentSubgraph};
52  }
53 
54  // Recurse and accumulate matches
55  stack.emplace_back(MatchNodeList.at(stack.size()));
56 
57  std::vector<SubgraphType> matchingSubgraphs;
58  for (auto outEdge : candidateNode->getOutEdges()) {
59  for (auto subgraph : recursiveMatch(outEdge->head(), stack, currentSubgraph)) {
60  matchingSubgraphs.emplace_back(subgraph);
61  }
62  }
63  return matchingSubgraphs;
64  }
65 
66  // No match here, early bailout
67  return std::vector<SubgraphType>{};
68  }
69 
70  std::vector<SubgraphType> match(G& g) {
71  std::vector<SubgraphType> out;
72 
73  std::vector<typename G::NodeRef> stack;
74  stack.emplace_back(MatchNodeList.front());
75 
76  // Try each node in the candidate graph as the anchor.
77  for (auto n : g.getMutableNodes()) {
78  for (auto subgraph : recursiveMatch(n, stack, SubgraphType())) {
79  out.emplace_back(subgraph);
80  }
81  }
82 
83  return out;
84  }
85 
86 private:
87  G& MatchGraph;
88  std::vector<typename G::NodeRef> MatchNodeList;
89 };
90 
91 }
92 
93 #endif // NOM_TRANFORMATIONS_MATCH_H
Effectively a constant reference to a graph.
Definition: Graph.h:110
Definition: Caffe2.cc:16