from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from collections import defaultdict
from contextlib import contextmanager
from copy import copy
from itertools import chain

from future.utils import viewkeys
from six import binary_type, text_type

from caffe2.proto.caffe2_pb2 import OperatorDef, NetDef
from caffe2.python.checkpoint import Job
from caffe2.python.core import Net, ExecutionStep, Plan
from caffe2.python.task import Task, TaskGroup, WorkspaceType, TaskOutput
22 def register(cls, Type):
23 if not(hasattr(cls,
'visitors')):
27 cls.visitors.append((Type, func))
32 def __call__(self, obj, *args, **kwargs):
35 for Type, func
in self.__class__.visitors:
36 if isinstance(obj, Type):
37 return func(self, obj, *args, **kwargs)
38 raise TypeError(
'%s: unsupported object type: %s' % (
39 self.__class__.__name__, type(obj)))
43 PREFIXES_TO_IGNORE = {
'distributed_ctx_init'}
46 self.
workspaces = defaultdict(
lambda: defaultdict(
lambda: 0))
54 def set_workspace(self, node=None, ws=None, do_copy=False):
57 elif node
is not None:
63 self.workspace_ctx.append(ws)
67 def define_blob(self, blob):
70 def need_blob(self, blob):
71 if any(blob.startswith(p)
for p
in Analyzer.PREFIXES_TO_IGNORE):
73 assert blob
in self.
workspace,
'Blob undefined: %s' % blob
76 @Analyzer.register(OperatorDef)
77 def analyze_op(analyzer, op):
81 analyzer.define_blob(x)
84 @Analyzer.register(Net)
85 def analyze_net(analyzer, net):
86 for x
in net.Proto().op:
90 @Analyzer.register(ExecutionStep)
91 def analyze_step(analyzer, step):
93 with analyzer.set_workspace(do_copy=proto.create_workspace):
95 with analyzer.set_workspace(do_copy=
True):
96 analyzer(step.get_net(proto.report_net))
98 substeps = step.Substeps() + [step.get_net(n)
for n
in proto.network]
99 for substep
in substeps:
100 with analyzer.set_workspace(
101 do_copy=proto.concurrent_substeps)
as ws_in:
103 if proto.should_stop_blob:
104 analyzer.need_blob(proto.should_stop_blob)
105 if proto.concurrent_substeps:
106 new_blobs = set(viewkeys(ws_in)) - set(viewkeys(analyzer.workspace))
107 assert len(all_new_blobs & new_blobs) == 0, (
108 'Error: Blobs created by multiple parallel steps: %s' % (
109 ', '.join(all_new_blobs & new_blobs)))
110 all_new_blobs |= new_blobs
111 for x
in all_new_blobs:
112 analyzer.define_blob(x)
115 @Analyzer.register(Task)
116 def analyze_task(analyzer, task):
118 step = task.get_step()
119 plan =
Plan(task.node)
121 proto_len = len(plan.Proto().SerializeToString())
122 assert proto_len < 2 ** 26, (
123 'Due to a protobuf limitation, serialized tasks must be smaller ' 124 'than 64Mb, but this task has {} bytes.' % proto_len)
126 is_private = task.workspace_type() != WorkspaceType.GLOBAL
127 with analyzer.set_workspace(do_copy=is_private):
131 @Analyzer.register(TaskGroup)
132 def analyze_task_group(analyzer, tg):
133 for task
in tg.tasks_by_node().tasks():
134 with analyzer.set_workspace(node=task.node):
138 @Analyzer.register(Job)
139 def analyze_job(analyzer, job):
140 analyzer(job.init_group)
141 analyzer(job.epoch_group)
146 Given a Job, visits all the execution steps making sure that: 147 - no undefined blobs will be found during excution 148 - no blob with same name is defined in concurrent steps 162 self.
add(
'with %s:' % text)
164 self._lines_in_context.append(0)
174 self.lines.append((
' ' * self.
_indent) + text)
177 return '\n'.join(self.
lines)
181 def __init__(self, factor_prefixes=False, c2_syntax=True):
182 super(Visitor, self).__init__()
183 super(Text, self).__init__()
189 def _sanitize_str(s):
190 if isinstance(s, text_type):
192 elif isinstance(s, binary_type):
193 sanitized = s.decode(
'ascii', errors=
'ignore')
196 if len(sanitized) < 64:
197 return "'%s'" % sanitized
199 return "'%s'" % sanitized[:64] +
'...<+len=%d>' % (len(sanitized) - 64)
203 if arg.HasField(
'f'):
205 if arg.HasField(
'i'):
207 if arg.HasField(
's'):
208 return _sanitize_str(arg.s)
210 return str(list(arg.floats))
212 return str(list(arg.ints))
214 return str([_sanitize_str(s)
for s
in arg.strings])
219 "Given a list of strings, returns the longest common prefix" 224 for i, c
in enumerate(s1):
230 def format_value(val):
231 if isinstance(val, list):
232 return '[%s]' %
', '.join(
"'%s'" % str(v)
for v
in val)
237 def factor_prefix(vals, do_it):
238 vals = [format_value(v)
for v
in vals]
239 prefix = commonprefix(vals)
if len(vals) > 1
and do_it
else '' 240 joined =
', '.join(v[len(prefix):]
for v
in vals)
241 return '%s[%s]' % (prefix, joined)
if prefix
else joined
244 def call(op, inputs=None, outputs=None, factor_prefixes=False):
248 inputs_v = [a
for a
in inputs
if not isinstance(a, tuple)]
249 inputs_kv = [a
for a
in inputs
if isinstance(a, tuple)]
253 [factor_prefix(inputs_v, factor_prefixes)],
254 (
'%s=%s' % kv
for kv
in inputs_kv),
258 call =
'%s(%s)' % (op, inputs)
259 return call
if not outputs
else '%s = %s' % (
260 factor_prefix(outputs, factor_prefixes), call)
263 def format_device_option(dev_opt):
264 if not dev_opt
or not (
265 dev_opt.device_type
or dev_opt.cuda_gpu_id
or dev_opt.node_name):
269 [dev_opt.device_type, dev_opt.cuda_gpu_id,
"'%s'" % dev_opt.node_name])
272 @Printer.register(OperatorDef)
273 def print_op(text, op):
274 args = [(a.name, _arg_val(a))
for a
in op.arg]
275 dev_opt_txt = format_device_option(op.device_option)
277 args.append((
'device_option', dev_opt_txt))
281 text.c2_net_name +
'.' + op.type,
282 [list(op.input), list(op.output)] + args))
286 list(op.input) + args,
288 factor_prefixes=text.factor_prefixes))
290 if arg.HasField(
'n'):
291 with text.context(
'arg: %s' % arg.name):
294 @Printer.register(NetDef)
295 def print_net_def(text, net_def):
297 text.add(call(
'core.Net', [
"'%s'" % net_def.name], [net_def.name]))
298 text.c2_net_name = net_def.name
300 text.add(
'# net: %s' % net_def.name)
301 for op
in net_def.op:
304 text.c2_net_name =
None 307 @Printer.register(Net)
308 def print_net(text, net):
312 def _get_step_context(step):
314 if proto.should_stop_blob:
315 return call(
'loop'),
False 316 if proto.num_iter
and proto.num_iter != 1:
317 return call(
'loop', [proto.num_iter]),
False 318 if proto.num_concurrent_instances > 1:
321 [(
'num_instances', proto.num_concurrent_instances)]),
322 len(step.Substeps()) > 1)
323 concurrent = proto.concurrent_substeps
and len(step.Substeps()) > 1
325 return call(
'parallel'),
True 327 return call(
'run_once'),
False 331 @Printer.register(ExecutionStep)
332 def print_step(text, step):
334 step_ctx, do_substep = _get_step_context(step)
335 with text.context(step_ctx):
337 with text.context(call(
'report_net', [proto.report_interval])):
338 text(step.get_net(proto.report_net))
339 substeps = step.Substeps() + [step.get_net(n)
for n
in proto.network]
340 for substep
in substeps:
342 substep.Proto()
if isinstance(substep, ExecutionStep)
else None)
343 if sub_proto
is not None and sub_proto.run_every_ms:
346 [str(substep), (
'interval_ms', sub_proto.run_every_ms)])
350 if sub_proto
is not None and sub_proto.create_workspace
else 352 substep_ctx = call(title, [str(substep)])
355 with text.context(substep_ctx):
357 if proto.should_stop_blob:
358 text.add(call(
'yield stop_if', [proto.should_stop_blob]))
361 def _print_task_output(x):
362 assert isinstance(x, TaskOutput)
363 return 'Output[' +
', '.join(str(x)
for x
in x.names) +
']' 366 @Printer.register(Task)
367 def print_task(text, task):
368 outs =
', '.join(_print_task_output(o)
for o
in task.outputs())
369 context = [(
'node', task.node), (
'name', task.name), (
'outputs', outs)]
370 with text.context(call(
'Task', context)):
371 text(task.get_step())
374 @Printer.register(TaskGroup)
375 def print_task_group(text, tg, header=None):
376 with text.context(header
or call(
'TaskGroup')):
377 for task
in tg.tasks_by_node().tasks():
381 @Printer.register(Job)
382 def print_job(text, job):
383 text(job.init_group,
'Job.current().init_group')
384 text(job.epoch_group,
'Job.current().epoch_group')
385 with text.context(
'Job.current().stop_signals'):
386 for out
in job.stop_signals:
387 text.add(_print_task_output(out))
388 text(job.download_group,
'Job.current().download_group')
389 text(job.exit_group,
'Job.current().exit_group')
392 def to_string(obj, **kwargs):
394 Given a Net, ExecutionStep, Task, TaskGroup or Job, produces a string 395 with detailed description of the execution steps. 404 Given a Net, produce another net that logs info about the operator call 405 before each operator execution. Use for debugging purposes. 407 assert isinstance(net, Net)
408 debug_net =
Net(str(net))
409 assert isinstance(net, Net)
410 for op
in net.Proto().op:
413 debug_net.LogInfo(str(text))
414 debug_net.Proto().op.extend([op])
# Module caffe2.python.workspace.
# Module caffe2.python.context.