Caffe2 - Python API
A deep learning, cross-platform ML framework
net_drawer.py
1 ## @package net_drawer
2 # Module caffe2.python.net_drawer
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 import argparse
8 import json
9 import logging
10 from collections import defaultdict
11 from caffe2.python import utils
12 from future.utils import viewitems
13 
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# pydot is treated as optional at import time: when it is missing the module
# still loads, the user is told how to install it, and `pydot` is bound to
# None.
try:
    import pydot
except ImportError:
    pydot = None
    logger.info(
        'Cannot import pydot, which is required for drawing a network. This '
        'can usually be installed in python with "pip install pydot". Also, '
        'pydot requires graphviz to convert dot files to pdf: in ubuntu, this '
        'can usually be installed with "sudo apt-get install graphviz".'
    )
    print(
        'net_drawer will not run correctly. Please install the correct '
        'dependencies.'
    )
31 
32 from caffe2.proto import caffe2_pb2
33 
# Default pydot node attributes for operator nodes: a filled box with white
# label text.
OP_STYLE = {
    'shape': 'box',
    'color': '#0F9D58',
    'style': 'filled',
    'fontcolor': '#FFFFFF',
}
# Default pydot node attributes for blob nodes.
BLOB_STYLE = {
    'shape': 'octagon',
}
41 
42 
def _rectify_operator_and_name(operators_or_net, name):
    """Normalizes the input into an (operators, name) pair for drawing.

    Args:
        operators_or_net: a caffe2_pb2.NetDef, an object with a Proto()
            method returning a NetDef, or any other object, which is then
            used as the operator collection unchanged.
        name: the graph name. When None, the NetDef's own name is used, or
            "unnamed" when the input is not a net.

    Returns:
        A tuple (operators, name).

    Raises:
        RuntimeError: if Proto() returns something that is not a NetDef.
    """
    if isinstance(operators_or_net, caffe2_pb2.NetDef):
        net_def = operators_or_net
    elif hasattr(operators_or_net, 'Proto'):
        net_def = operators_or_net.Proto()
        if not isinstance(net_def, caffe2_pb2.NetDef):
            raise RuntimeError(
                "Expecting NetDef, but got {}".format(type(net_def)))
    else:
        # Not a net at all: the argument itself is the operator collection.
        return operators_or_net, ("unnamed" if name is None else name)
    return net_def.op, (net_def.name if name is None else name)
62 
63 
64 def _escape_label(name):
65  # json.dumps is poor man's escaping
66  return json.dumps(name)
67 
68 
def GetOpNodeProducer(append_output, **kwargs):
    """Returns a callable that builds the pydot node for an operator.

    Args:
        append_output: when true, each output blob name of the operator is
            appended to the node name on its own line.
        **kwargs: attributes forwarded to every pydot.Node created.

    Returns:
        A function (op, op_id) -> pydot.Node.
    """
    def ReallyGetOpNode(op, op_id):
        # "<type> (op#<id>)", prefixed with "<name>/" when the op is named.
        title = '%s (op#%d)' % (op.type, op_id)
        if op.name:
            title = '%s/%s' % (op.name, title)
        lines = [title]
        if append_output:
            lines.extend(op.output)
        return pydot.Node('\n'.join(lines), **kwargs)
    return ReallyGetOpNode
80 
81 
def GetPydotGraph(
    operators_or_net,
    name=None,
    rankdir='LR',
    node_producer=None
):
    """Builds a pydot graph containing both operator nodes and blob nodes.

    Every operator gets a node from `node_producer`. Every blob gets a node
    connected to the ops that read and write it. When an op writes a blob
    that already has a node, a new node with the same label is created, so
    later readers connect to the newest version.

    Args:
        operators_or_net: anything accepted by _rectify_operator_and_name.
        name: graph name; derived from the net when None.
        rankdir: value of the graph's rankdir attribute.
        node_producer: callable (op, op_id) -> pydot.Node; defaults to
            GetOpNodeProducer(False, **OP_STYLE).

    Returns:
        The populated pydot.Dot graph.
    """
    if node_producer is None:
        node_producer = GetOpNodeProducer(False, **OP_STYLE)
    operators, name = _rectify_operator_and_name(operators_or_net, name)
    graph = pydot.Dot(name, rankdir=rankdir)
    # Most recent pydot node for each blob name.
    current_blob_nodes = {}
    # How many times each blob name has been overwritten so far.
    overwrite_counts = defaultdict(int)

    def new_blob_node(blob_name):
        # The node id carries the overwrite count so versions stay distinct,
        # while the visible label is just the blob name.
        node_id = _escape_label(
            blob_name + str(overwrite_counts[blob_name]))
        return pydot.Node(
            node_id, label=_escape_label(blob_name), **BLOB_STYLE)

    for op_id, op in enumerate(operators):
        op_node = node_producer(op, op_id)
        graph.add_node(op_node)
        for input_name in op.input:
            if input_name in current_blob_nodes:
                input_node = current_blob_nodes[input_name]
            else:
                input_node = new_blob_node(input_name)
                current_blob_nodes[input_name] = input_node
            graph.add_node(input_node)
            graph.add_edge(pydot.Edge(input_node, op_node))
        for output_name in op.output:
            if output_name in current_blob_nodes:
                # Overwriting an existing blob: bump its count so the new
                # node gets a fresh id.
                overwrite_counts[output_name] += 1
            output_node = new_blob_node(output_name)
            current_blob_nodes[output_name] = output_node
            graph.add_node(output_node)
            graph.add_edge(pydot.Edge(op_node, output_node))
    return graph
127 
128 
def GetPydotGraphMinimal(
    operators_or_net,
    name=None,
    rankdir='LR',
    minimal_dependency=False,
    node_producer=None,
):
    """Builds a pydot graph that shows only op nodes, with no blob nodes.

    An edge a->b means op b reads a blob whose latest producer is op a.

    If minimal_dependency is set, only the edges to the minimal necessary
    ancestors are drawn for each op. For example, if op c depends on ops a
    and b, and op b depends on a, then only the edge b->c is drawn because
    a->c is implied.

    Args:
        operators_or_net: anything accepted by _rectify_operator_and_name.
        name: graph name; derived from the net when None.
        rankdir: value of the graph's rankdir attribute.
        minimal_dependency: whether to drop edges implied transitively.
        node_producer: callable (op, op_id) -> pydot.Node; defaults to
            GetOpNodeProducer(False, **OP_STYLE).

    Returns:
        The populated pydot.Dot graph.
    """
    if node_producer is None:
        node_producer = GetOpNodeProducer(False, **OP_STYLE)
    operators, name = _rectify_operator_and_name(operators_or_net, name)
    graph = pydot.Dot(name, rankdir=rankdir)
    # Blob name -> node of the op that most recently produced it.
    blob_parents = {}
    # Op node -> set of all op nodes it transitively depends on.
    op_ancestry = defaultdict(set)
    for op_id, op in enumerate(operators):
        op_node = node_producer(op, op_id)
        graph.add_node(op_node)
        # Direct parents are the producers of this op's inputs; inputs
        # with no known producer are skipped.
        parents = [
            blob_parents[blob] for blob in op.input if blob in blob_parents
        ]
        ancestors = op_ancestry[op_node]
        ancestors.update(parents)
        for parent in parents:
            ancestors.update(op_ancestry[parent])
        if minimal_dependency:
            # Keep only parents that are not an ancestor of another parent.
            edge_sources = [
                parent for parent in parents
                if not any(
                    parent in op_ancestry[other] for other in parents)
            ]
        else:
            edge_sources = parents
        for parent in edge_sources:
            graph.add_edge(pydot.Edge(parent, op_node))
        # From now on this op is the producer of its output blobs.
        for output_name in op.output:
            blob_parents[output_name] = op_node
    return graph
178 
179 
def GetOperatorMapForPlan(plan_def):
    """Maps a graph name to the operator list of each network in a plan.

    The key is "<plan name>_<net name>" when the network has its name field
    set, and "<plan name>_network_<index>" otherwise.

    Args:
        plan_def: an object with `name` and `network` attributes, where each
            network offers HasField('name'), `name` and `op`.

    Returns:
        A dict from graph name to the network's `op` collection.
    """
    operator_map = {}
    for index, net in enumerate(plan_def.network):
        if net.HasField('name'):
            suffix = net.name
        else:
            suffix = "network_%d" % index
        operator_map[plan_def.name + "_" + suffix] = net.op
    return operator_map
188 
189 
def _draw_nets(nets, g):
    """Adds one node per net to graph `g`, chained in order by edges.

    Args:
        nets: sequence of values passed through _escape_label to form the
            node names.
        g: the pydot graph to add nodes and edges to.

    Returns:
        The list of created pydot nodes, in input order.
    """
    nodes = []
    previous = None
    for net in nets:
        current = pydot.Node(_escape_label(net))
        g.add_node(current)
        if previous is not None:
            g.add_edge(pydot.Edge(previous, current))
        nodes.append(current)
        previous = current
    return nodes
198 
199 
def _draw_steps(steps, g, skip_step_edges=False):  # noqa
    """Recursively draws execution steps and their children into graph `g`.

    Each step becomes an OP_STYLE node. Consecutive steps are linked by a
    plain edge unless `skip_step_edges` is set. A step's children (its
    networks, or else its substeps) are attached with dashed edges: to every
    drawn child for concurrent steps, to the first child only otherwise.
    For concurrent steps at most kMaxParallelSteps substeps are drawn and
    the remainder is summarized by a single "N more steps" node.

    Args:
        steps: sequence of execution-step objects.
        g: the pydot graph to draw into.
        skip_step_edges: when true, do not link consecutive steps.

    Returns:
        The list of pydot nodes created for `steps`, in order.

    Raises:
        ValueError: if a step has neither networks nor substeps.
    """
    kMaxParallelSteps = 3

    def describe(step):
        # Step name followed by a blank line, then one line per property.
        parts = [step.name + '\n']
        if step.report_net:
            parts.append('Reporter: {}'.format(step.report_net))
        if step.should_stop_blob:
            parts.append('Stopper: {}'.format(step.should_stop_blob))
        if step.concurrent_substeps:
            parts.append('Concurrent')
        if step.only_once:
            parts.append('Once')
        return '\n'.join(parts)

    def dashed_edge(src, dst):
        return pydot.Edge(src, dst, arrowhead='dot', style='dashed')

    nodes = []
    for step in steps:
        step_node = pydot.Node(_escape_label(describe(step)), **OP_STYLE)
        g.add_node(step_node)
        if nodes and not skip_step_edges:
            g.add_edge(pydot.Edge(nodes[-1], step_node))
        nodes.append(step_node)

        is_parallel = step.concurrent_substeps
        if step.network:
            children = _draw_nets(step.network, g)
        elif step.substep:
            if is_parallel:
                children = _draw_steps(
                    step.substep[:kMaxParallelSteps], g, skip_step_edges=True)
            else:
                children = _draw_steps(step.substep, g)
        else:
            raise ValueError('invalid step')

        if not is_parallel:
            # Sequential children are already chained; link the first only.
            g.add_edge(dashed_edge(step_node, children[0]))
            continue

        for child in children:
            g.add_edge(dashed_edge(step_node, child))
        hidden_count = len(step.substep) - kMaxParallelSteps
        if hidden_count > 0:
            ellipsis = pydot.Node(
                '{} more steps'.format(hidden_count), **OP_STYLE)
            g.add_node(ellipsis)
            g.add_edge(dashed_edge(step_node, ellipsis))

    return nodes
251 
252 
def GetPlanGraph(plan_def, name=None, rankdir='TB'):
    """Builds a pydot graph of the execution steps of a plan.

    Args:
        plan_def: an object whose `execution_step` sequence is drawn.
        name: graph name, passed to pydot.Dot unchanged.
        rankdir: value of the graph's rankdir attribute.

    Returns:
        The populated pydot.Dot graph.
    """
    plan_graph = pydot.Dot(name, rankdir=rankdir)
    _draw_steps(plan_def.execution_step, plan_graph)
    return plan_graph
257 
258 
def GetGraphInJson(operators_or_net, output_filepath):
    """Writes the op/blob graph of a net to a JSON file.

    The file holds a single object {"nodes": [...], "edges": [...]}. A node
    is {"id", "label", "type"}, where type is "op" (with an extra "op_id")
    or "blob"; its id is its index in the nodes list. An edge is
    {"source", "target"} holding node ids. When an op writes a blob that
    already has a node, a new blob node with the same label is created, and
    later readers are connected to the newest one.

    Args:
        operators_or_net: anything accepted by _rectify_operator_and_name.
        output_filepath: path of the JSON file to write.
    """
    operators, _ = _rectify_operator_and_name(operators_or_net, None)
    # (blob name, version) -> node id. A tuple key is used rather than the
    # name and version concatenated into one string, because concatenation
    # is ambiguous: blob "x1" at version 0 and blob "x" at version 10 would
    # both become "x10" and wrongly share a node.
    blob_node_ids = {}
    # Current version of each blob name, i.e. how often it was overwritten.
    blob_versions = defaultdict(int)
    nodes = []
    edges = []

    def blob_node_id(blob_name):
        # Returns the node id of the blob's current version, creating the
        # node on first use.
        key = (blob_name, blob_versions[blob_name])
        if key not in blob_node_ids:
            blob_node_ids[key] = len(nodes)
            nodes.append({
                'id': len(nodes),
                'label': blob_name,
                'type': 'blob'
            })
        return blob_node_ids[key]

    for op_id, op in enumerate(operators):
        op_label = op.name + '/' + op.type if op.name else op.type
        op_node_id = len(nodes)
        nodes.append({
            'id': op_node_id,
            'label': op_label,
            'op_id': op_id,
            'type': 'op'
        })
        for input_name in op.input:
            edges.append({
                'source': blob_node_id(input_name),
                'target': op_node_id
            })
        for output_name in op.output:
            if (output_name, blob_versions[output_name]) in blob_node_ids:
                # Overwriting an existing blob: move on to a new version so
                # that a fresh node is created for it.
                blob_versions[output_name] += 1
            edges.append({
                'source': op_node_id,
                'target': blob_node_id(output_name)
            })

    with open(output_filepath, 'w') as f:
        json.dump({'nodes': nodes, 'edges': edges}, f)
315 
316 
317 # A dummy minimal PNG image used by GetGraphPngSafe as a
318 # placeholder when rendering fail to run.
319 _DummyPngImage = (
320  b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00'
321  b'\x01\x01\x00\x00\x00\x007n\xf9$\x00\x00\x00\nIDATx\x9cc`\x00\x00'
322  b'\x00\x02\x00\x01H\xaf\xa4q\x00\x00\x00\x00IEND\xaeB`\x82')
323 
324 
def GetGraphPngSafe(func, *args, **kwargs):
    """
    Invokes `func` (e.g. GetPydotGraph) with the given arguments and renders
    the resulting graph to PNG. If anything fails, the error is logged and a
    placeholder image is returned instead of raising an exception.

    Args:
        func: callable expected to return a pydot.Dot.
        *args, **kwargs: forwarded to `func`.

    Returns:
        The result of the graph's create_png(), or _DummyPngImage on any
        failure.
    """
    try:
        graph = func(*args, **kwargs)
        if isinstance(graph, pydot.Dot):
            return graph.create_png()
        raise ValueError("func is expected to return pydot.Dot")
    except Exception as err:
        # Deliberately broad: this helper must never raise to its caller.
        logger.error("Failed to draw graph: {}".format(err))
        return _DummyPngImage
338 
339 
def main():
    """Command-line entry point: draws each net found in a protobuf file.

    The --input file is parsed as either a PlanDef or a NetDef. For every
    resulting graph a "<output_prefix><graph name>.dot" file is written,
    followed by an attempt to write the matching ".pdf". A failure to write
    the pdf is reported on stdout and does not stop processing.
    """
    parser = argparse.ArgumentParser(description="Caffe2 net drawer.")
    parser.add_argument(
        "--input", type=str, required=True,
        help="The input protobuf file.")
    parser.add_argument(
        "--output_prefix", type=str, default="",
        help="The prefix to be added to the output filename.")
    parser.add_argument(
        "--minimal", action="store_true",
        help="If set, produce a minimal visualization.")
    parser.add_argument(
        "--minimal_dependency", action="store_true",
        help="If set, only draw minimal dependency.")
    parser.add_argument(
        "--append_output", action="store_true",
        help="If set, append the output blobs to the operator names.")
    parser.add_argument(
        "--rankdir", type=str, default="LR",
        help="The rank direction of the pydot graph.")
    args = parser.parse_args()

    with open(args.input, 'r') as fid:
        content = fid.read()
    # Each handler turns the parsed proto into {graph name: operators}.
    graphs = utils.GetContentFromProtoString(
        content, {
            caffe2_pb2.PlanDef: GetOperatorMapForPlan,
            caffe2_pb2.NetDef: lambda net: {net.name: net.op},
        }
    )

    producer = GetOpNodeProducer(args.append_output, **OP_STYLE)
    for key, operators in viewitems(graphs):
        if args.minimal:
            graph = GetPydotGraphMinimal(
                operators,
                name=key,
                rankdir=args.rankdir,
                node_producer=producer,
                minimal_dependency=args.minimal_dependency)
        else:
            graph = GetPydotGraph(
                operators,
                name=key,
                rankdir=args.rankdir,
                node_producer=producer)
        dot_filename = args.output_prefix + graph.get_name() + '.dot'
        graph.write(dot_filename, format='raw')
        # Swap the trailing "dot" for "pdf".
        pdf_filename = dot_filename[:-3] + 'pdf'
        try:
            graph.write_pdf(pdf_filename)
        except Exception:
            print(
                'Error when writing out the pdf file. Pydot requires graphviz '
                'to convert dot files to pdf, and you may not have installed '
                'graphviz. On ubuntu this can usually be installed with "sudo '
                'apt-get install graphviz". We have generated the .dot file '
                'but will not be able to generate pdf file for now.'
            )


if __name__ == '__main__':
    main()