Caffe2 - Python API
A deep learning, cross platform ML framework
net_printer.py
## @package net_printer
# Module caffe2.python.net_printer
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.proto.caffe2_pb2 import OperatorDef, NetDef
from caffe2.python.checkpoint import Job
from caffe2.python.core import Net, ExecutionStep, Plan
from caffe2.python.task import Task, TaskGroup, WorkspaceType, TaskOutput
from collections import defaultdict
from contextlib import contextmanager
from copy import copy
from future.utils import viewkeys
from itertools import chain
from six import binary_type, text_type

class Visitor(object):
    @classmethod
    def register(cls, Type):
        if not hasattr(cls, 'visitors'):
            cls.visitors = []

        def _register(func):
            cls.visitors.append((Type, func))
            return func

        return _register

    def __call__(self, obj, *args, **kwargs):
        if obj is None:
            return
        for Type, func in self.__class__.visitors:
            if isinstance(obj, Type):
                return func(self, obj, *args, **kwargs)
        raise TypeError('%s: unsupported object type: %s' % (
            self.__class__.__name__, type(obj)))

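# --- Illustrative sketch (not part of the module), assuming only the imports
# above: Visitor keeps a per-subclass registry mapping types to handler
# functions, and calling an instance dispatches on the argument's runtime
# type. `MyVisitor` and `visit_op` are hypothetical names for illustration.
#
#     class MyVisitor(Visitor):
#         pass
#
#     @MyVisitor.register(OperatorDef)
#     def visit_op(visitor, op):
#         print('saw op:', op.type)
#
#     op = OperatorDef()
#     op.type = 'Relu'
#     MyVisitor()(op)  # prints: saw op: Relu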

class Analyzer(Visitor):
    PREFIXES_TO_IGNORE = {'distributed_ctx_init'}

    def __init__(self):
        self.workspaces = defaultdict(lambda: defaultdict(lambda: 0))
        self.workspace_ctx = []

    @property
    def workspace(self):
        return self.workspace_ctx[-1]

    @contextmanager
    def set_workspace(self, node=None, ws=None, do_copy=False):
        if ws is not None:
            pass
        elif node is not None:
            ws = self.workspaces[str(node)]
        else:
            ws = self.workspace
        if do_copy:
            ws = copy(ws)
        self.workspace_ctx.append(ws)
        yield ws
        del self.workspace_ctx[-1]

    def define_blob(self, blob):
        self.workspace[blob] += 1

    def need_blob(self, blob):
        if any(blob.startswith(p) for p in Analyzer.PREFIXES_TO_IGNORE):
            return
        assert blob in self.workspace, 'Blob undefined: %s' % blob

@Analyzer.register(OperatorDef)
def analyze_op(analyzer, op):
    for x in op.input:
        analyzer.need_blob(x)
    for x in op.output:
        analyzer.define_blob(x)


@Analyzer.register(Net)
def analyze_net(analyzer, net):
    for x in net.Proto().op:
        analyzer(x)


@Analyzer.register(ExecutionStep)
def analyze_step(analyzer, step):
    proto = step.Proto()
    with analyzer.set_workspace(do_copy=proto.create_workspace):
        if proto.report_net:
            with analyzer.set_workspace(do_copy=True):
                analyzer(step.get_net(proto.report_net))
        all_new_blobs = set()
        substeps = step.Substeps() + [step.get_net(n) for n in proto.network]
        for substep in substeps:
            with analyzer.set_workspace(
                    do_copy=proto.concurrent_substeps) as ws_in:
                analyzer(substep)
                if proto.should_stop_blob:
                    analyzer.need_blob(proto.should_stop_blob)
            if proto.concurrent_substeps:
                new_blobs = (
                    set(viewkeys(ws_in)) - set(viewkeys(analyzer.workspace)))
                assert len(all_new_blobs & new_blobs) == 0, (
                    'Error: Blobs created by multiple parallel steps: %s' % (
                        ', '.join(all_new_blobs & new_blobs)))
                all_new_blobs |= new_blobs
    for x in all_new_blobs:
        analyzer.define_blob(x)

@Analyzer.register(Task)
def analyze_task(analyzer, task):
    # check that our plan protobuf is not too large (limit of 64Mb)
    step = task.get_step()
    plan = Plan(task.node)
    plan.AddStep(step)
    proto_len = len(plan.Proto().SerializeToString())
    assert proto_len < 2 ** 26, (
        'Due to a protobuf limitation, serialized tasks must be smaller '
        'than 64Mb, but this task has {} bytes.'.format(proto_len))

    is_private = task.workspace_type() != WorkspaceType.GLOBAL
    with analyzer.set_workspace(do_copy=is_private):
        analyzer(step)


@Analyzer.register(TaskGroup)
def analyze_task_group(analyzer, tg):
    for task in tg.tasks_by_node().tasks():
        with analyzer.set_workspace(node=task.node):
            analyzer(task)


@Analyzer.register(Job)
def analyze_job(analyzer, job):
    analyzer(job.init_group)
    analyzer(job.epoch_group)

def analyze(obj):
    """
    Given a Job, visits all the execution steps making sure that:
    - no undefined blobs will be found during execution
    - no blob with the same name is defined in concurrent steps
    """
    Analyzer()(obj)
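# --- Illustrative usage sketch (not part of the module): `job` is a
# hypothetical caffe2.python.checkpoint.Job assembled elsewhere. analyze()
# raises an AssertionError if a step reads a blob no earlier step defined,
# or if two concurrent substeps define a blob with the same name.
#
#     from caffe2.python import net_printer
#     net_printer.analyze(job)  # passes silently when the job is well-formed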

class Text(object):
    def __init__(self):
        self._indent = 0
        self._lines_in_context = [0]
        self.lines = []

    @contextmanager
    def context(self, text):
        if text is not None:
            self.add('with %s:' % text)
            self._indent += 4
            self._lines_in_context.append(0)
        yield
        if text is not None:
            if self._lines_in_context[-1] == 0:
                self.add('pass')
            self._indent -= 4
            del self._lines_in_context[-1]

    def add(self, text):
        self._lines_in_context[-1] += 1
        self.lines.append((' ' * self._indent) + text)

    def __str__(self):
        return '\n'.join(self.lines)


class Printer(Visitor, Text):
    def __init__(self, factor_prefixes=False, c2_syntax=True):
        super(Visitor, self).__init__()
        super(Text, self).__init__()
        self.factor_prefixes = factor_prefixes
        self.c2_syntax = c2_syntax
        self.c2_net_name = None

def _sanitize_str(s):
    if isinstance(s, text_type):
        sanitized = s
    elif isinstance(s, binary_type):
        sanitized = s.decode('ascii', errors='ignore')
    else:
        sanitized = str(s)
    if len(sanitized) < 64:
        return "'%s'" % sanitized
    else:
        return "'%s'" % sanitized[:64] + '...<+len=%d>' % (len(sanitized) - 64)


def _arg_val(arg):
    if arg.HasField('f'):
        return str(arg.f)
    if arg.HasField('i'):
        return str(arg.i)
    if arg.HasField('s'):
        return _sanitize_str(arg.s)
    if arg.floats:
        return str(list(arg.floats))
    if arg.ints:
        return str(list(arg.ints))
    if arg.strings:
        return str([_sanitize_str(s) for s in arg.strings])
    return '[]'

def commonprefix(m):
    "Given a list of strings, returns the longest common prefix"
    if not m:
        return ''
    s1 = min(m)
    s2 = max(m)
    for i, c in enumerate(s1):
        if c != s2[i]:
            return s1[:i]
    return s1


def format_value(val):
    if isinstance(val, list):
        return '[%s]' % ', '.join("'%s'" % str(v) for v in val)
    else:
        return str(val)


def factor_prefix(vals, do_it):
    vals = [format_value(v) for v in vals]
    prefix = commonprefix(vals) if len(vals) > 1 and do_it else ''
    joined = ', '.join(v[len(prefix):] for v in vals)
    return '%s[%s]' % (prefix, joined) if prefix else joined


def call(op, inputs=None, outputs=None, factor_prefixes=False):
    if not inputs:
        inputs = ''
    else:
        inputs_v = [a for a in inputs if not isinstance(a, tuple)]
        inputs_kv = [a for a in inputs if isinstance(a, tuple)]
        inputs = ', '.join(
            x
            for x in chain(
                [factor_prefix(inputs_v, factor_prefixes)],
                ('%s=%s' % kv for kv in inputs_kv),
            )
            if x
        )
    call = '%s(%s)' % (op, inputs)
    return call if not outputs else '%s = %s' % (
        factor_prefix(outputs, factor_prefixes), call)

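# --- Illustrative sketch (not part of the module): what `call` renders.
# Positional inputs and (key, value) tuples become a pseudo-Python call,
# outputs become the assignment target, and factor_prefixes=True collapses
# a shared name prefix.
#
#     call('Add', ['x', 'y'], ['z'])
#     # -> "z = Add(x, y)"
#     call('FC', ['data', ('axis', 1)], ['fc1'])
#     # -> "fc1 = FC(data, axis=1)"
#     factor_prefix(['model/w', 'model/b'], True)
#     # -> "model/[w, b]"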

def format_device_option(dev_opt):
    if not dev_opt or not (
            dev_opt.device_type or dev_opt.cuda_gpu_id or dev_opt.node_name):
        return None
    return call(
        'DeviceOption',
        [dev_opt.device_type, dev_opt.cuda_gpu_id, "'%s'" % dev_opt.node_name])

@Printer.register(OperatorDef)
def print_op(text, op):
    args = [(a.name, _arg_val(a)) for a in op.arg]
    dev_opt_txt = format_device_option(op.device_option)
    if dev_opt_txt:
        args.append(('device_option', dev_opt_txt))

    if text.c2_net_name:
        text.add(call(
            text.c2_net_name + '.' + op.type,
            [list(op.input), list(op.output)] + args))
    else:
        text.add(call(
            op.type,
            list(op.input) + args,
            op.output,
            factor_prefixes=text.factor_prefixes))
    for arg in op.arg:
        if arg.HasField('n'):
            with text.context('arg: %s' % arg.name):
                text(arg.n)

@Printer.register(NetDef)
def print_net_def(text, net_def):
    if text.c2_syntax:
        text.add(call('core.Net', ["'%s'" % net_def.name], [net_def.name]))
        text.c2_net_name = net_def.name
    else:
        text.add('# net: %s' % net_def.name)
    for op in net_def.op:
        text(op)
    if text.c2_syntax:
        text.c2_net_name = None


@Printer.register(Net)
def print_net(text, net):
    text(net.Proto())

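# --- Illustrative sketch (not part of the module): with c2_syntax=True (the
# default), a net named 'train' containing one Relu op prints as rebuildable
# core.Net calls; with c2_syntax=False it prints as assignments. `my_net` is
# a hypothetical core.Net built elsewhere.
#
#     print(to_string(my_net))                   # -> train = core.Net('train')
#                                                #    train.Relu(['x'], ['y'])
#     print(to_string(my_net, c2_syntax=False))  # -> # net: train
#                                                #    y = Relu(x)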

def _get_step_context(step):
    proto = step.Proto()
    if proto.should_stop_blob:
        return call('loop'), False
    if proto.num_iter and proto.num_iter != 1:
        return call('loop', [proto.num_iter]), False
    if proto.num_concurrent_instances > 1:
        return (
            call('parallel',
                 [('num_instances', proto.num_concurrent_instances)]),
            len(step.Substeps()) > 1)
    concurrent = proto.concurrent_substeps and len(step.Substeps()) > 1
    if concurrent:
        return call('parallel'), True
    if proto.report_net:
        return call('run_once'), False
    return None, False

@Printer.register(ExecutionStep)
def print_step(text, step):
    proto = step.Proto()
    step_ctx, do_substep = _get_step_context(step)
    with text.context(step_ctx):
        if proto.report_net:
            with text.context(call('report_net', [proto.report_interval])):
                text(step.get_net(proto.report_net))
        substeps = step.Substeps() + [step.get_net(n) for n in proto.network]
        for substep in substeps:
            sub_proto = (
                substep.Proto() if isinstance(substep, ExecutionStep) else None)
            if sub_proto is not None and sub_proto.run_every_ms:
                substep_ctx = call(
                    'reporter',
                    [str(substep), ('interval_ms', sub_proto.run_every_ms)])
            elif do_substep:
                title = (
                    'workspace'
                    if sub_proto is not None and sub_proto.create_workspace else
                    'step')
                substep_ctx = call(title, [str(substep)])
            else:
                substep_ctx = None
            with text.context(substep_ctx):
                text(substep)
        if proto.should_stop_blob:
            text.add(call('yield stop_if', [proto.should_stop_blob]))

def _print_task_output(x):
    assert isinstance(x, TaskOutput)
    return 'Output[' + ', '.join(str(x) for x in x.names) + ']'


@Printer.register(Task)
def print_task(text, task):
    outs = ', '.join(_print_task_output(o) for o in task.outputs())
    context = [('node', task.node), ('name', task.name), ('outputs', outs)]
    with text.context(call('Task', context)):
        text(task.get_step())


@Printer.register(TaskGroup)
def print_task_group(text, tg, header=None):
    with text.context(header or call('TaskGroup')):
        for task in tg.tasks_by_node().tasks():
            text(task)


@Printer.register(Job)
def print_job(text, job):
    text(job.init_group, 'Job.current().init_group')
    text(job.epoch_group, 'Job.current().epoch_group')
    with text.context('Job.current().stop_signals'):
        for out in job.stop_signals:
            text.add(_print_task_output(out))
    text(job.download_group, 'Job.current().download_group')
    text(job.exit_group, 'Job.current().exit_group')


def to_string(obj, **kwargs):
    """
    Given a Net, ExecutionStep, Task, TaskGroup or Job, produces a string
    with a detailed description of the execution steps.
    """
    printer = Printer(**kwargs)
    printer(obj)
    return str(printer)
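# --- Illustrative usage sketch (not part of the module): `job` is a
# hypothetical Job; keyword arguments are forwarded to Printer.
#
#     from caffe2.python import net_printer
#     print(net_printer.to_string(job, factor_prefixes=True))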

def debug_net(net):
    """
    Given a Net, produce another net that logs info about the operator call
    before each operator execution. Use for debugging purposes.
    """
    assert isinstance(net, Net)
    debug_net = Net(str(net))
    for op in net.Proto().op:
        # Printer (rather than a bare Text) is needed here, since print_op
        # reads text.c2_net_name and text.factor_prefixes.
        text = Printer()
        print_op(text, op)
        debug_net.LogInfo(str(text))
        debug_net.Proto().op.extend([op])
    return debug_net
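# --- Illustrative usage sketch (not part of the module): wrap an existing
# net so each operator is preceded by a LogInfo op describing it; `my_net`
# is a hypothetical core.Net built elsewhere.
#
#     from caffe2.python import net_printer, workspace
#     logged = net_printer.debug_net(my_net)
#     workspace.RunNetOnce(logged)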