Caffe2 - C++ API
A deep-learning, cross-platform ML framework
TarjansImpl.h
1 #ifndef NOM_GRAPH_ALGORITHMS_H
2 #error "This should only be included by Graph/Algorithms.h"
3 #endif
4 
5 template <typename T, typename U> struct GraphWrapper {
6  struct NodeWrapper {
7  using NodeRef = typename Graph<T, U>::NodeRef;
8  NodeWrapper(NodeRef n) : node(n) {}
9  NodeRef node;
10  int Index = -1;
11  int LowLink = -1;
12  bool OnStack = false;
13  };
14 
15  struct EdgeWrapper {
16  typename Graph<T, U>::EdgeRef edge;
17  };
18 };
19 
36 template <typename T, typename U = T> class Tarjans {
37  using NodeWrapper = typename GraphWrapper<T, U>::NodeWrapper;
38  using EdgeWrapper = typename GraphWrapper<T, U>::EdgeWrapper;
39  using WrappedGraph = Graph<NodeWrapper, EdgeWrapper>;
40  using WrappedSubgraph = Subgraph<NodeWrapper, EdgeWrapper>;
41 
42 private:
43  int Index = 0;
44  std::vector<typename WrappedGraph::NodeRef> Stack;
45  Graph<T, U> *InputGraph;
46  WrappedGraph WrappedInputGraph;
47  std::vector<WrappedSubgraph> WrappedSCCs;
48 
49 public:
53  Tarjans(Graph<T, U> *g) : InputGraph(g) {
54  // Wrap Graph with node labels
55  std::unordered_map<typename Graph<T, U>::NodeRef,
56  typename WrappedGraph::NodeRef>
57  n_to_wrappedNode;
58 
59  for (const auto &n : InputGraph->getMutableNodes()) {
60  NodeWrapper wrappedNode(n);
61  n_to_wrappedNode[n] =
62  WrappedInputGraph.createNode(std::move(wrappedNode));
63  }
64 
65  for (const auto &e : InputGraph->getMutableEdges()) {
66  EdgeWrapper wrappedEdge = {e};
67  WrappedInputGraph.createEdge(n_to_wrappedNode[e->tail()],
68  n_to_wrappedNode[e->head()],
69  std::move(wrappedEdge));
70  }
71  }
72 
75  void connect(typename WrappedGraph::NodeRef n) {
76  n->mutableData()->Index = Index;
77  n->mutableData()->LowLink = Index;
78  Index++;
79 
80  Stack.emplace_back(n);
81  n->mutableData()->OnStack = true;
82 
83  for (const auto &outEdge : n->getOutEdges()) {
84  typename WrappedGraph::NodeRef newNode = outEdge->head();
85  // Check if we've considered this node before.
86  if (newNode->data().Index == -1) {
87  connect(newNode);
88  n->mutableData()->LowLink =
89  std::min(n->data().LowLink, newNode->data().LowLink);
90  // Check if newNode is in the SCC.
91  } else if (newNode->data().OnStack) {
92  n->mutableData()->LowLink =
93  std::min(n->data().LowLink, newNode->data().Index);
94  }
95  }
96 
97  // If our node is a root node, pop it from the stack (we've found an SCC)
98  if (n->data().LowLink == n->data().Index) {
99  WrappedSubgraph wrappedSCC;
100 
101  typename WrappedGraph::NodeRef w;
102  do {
103  w = Stack.back();
104  w->mutableData()->OnStack = false;
105  Stack.pop_back();
106  wrappedSCC.addNode(w);
107  } while (w != n);
108 
109  // Add all the edges into the SCC.
110  // TODO include edges in the SCC in a smarter way.
111  const auto &sccNodes = wrappedSCC.getNodes();
112  for (const auto &sccNode : sccNodes) {
113  for (const auto &outEdge : sccNode->getOutEdges()) {
114  if (std::find(sccNodes.begin(), sccNodes.end(), outEdge->head()) !=
115  sccNodes.end()) {
116  wrappedSCC.addEdge(outEdge);
117  }
118  }
119  }
120  WrappedSCCs.emplace_back(wrappedSCC);
121  }
122  }
123 
128  inline Subgraph<T, U> unwrapSubgraph(const WrappedSubgraph &wrappedSubgraph) {
129  Subgraph<T, U> s;
130  for (auto wrappedNode : wrappedSubgraph.getNodes()) {
131  s.addNode(wrappedNode->data().node);
132  }
133  for (auto wrappedEdge : wrappedSubgraph.getEdges()) {
134  s.addEdge(wrappedEdge->data().edge);
135  }
136  return s;
137  }
138 
139  std::vector<Subgraph<T, U>> run() {
140  for (auto n : WrappedInputGraph.getMutableNodes()) {
141  if (n->data().Index == -1) {
142  connect(n);
143  }
144  }
145 
146  std::vector<Subgraph<T, U>> sccs;
147  for (auto wrappedSCC : WrappedSCCs) {
148  sccs.emplace_back(unwrapSubgraph(wrappedSCC));
149  }
150 
151  return sccs;
152  }
153 };
154 
156 template <typename T, typename U>
157 std::vector<Subgraph<T, U>> tarjans(Graph<T, U> *g) {
158  Tarjans<T, U> t(g);
159  return t.run();
160 }
Subgraph< T, U > unwrapSubgraph(const WrappedSubgraph &wrappedSubgraph)
Helper function for recovering a valid subgraph output.
Definition: TarjansImpl.h:128
Tarjans(Graph< T, U > *g)
Constructor wraps the input graph with an annotated graph set up with the data structures needed for the algorithm.
Definition: TarjansImpl.h:53
void connect(typename WrappedGraph::NodeRef n)
Helper function for finding strongly connected components.
Definition: TarjansImpl.h:75
Tarjan's algorithm implementation.
Definition: TarjansImpl.h:36