from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.proto import caffe2_pb2


def gen_do_gradient(op, g_output):
    """
    Generates gradient Do operator, given forward Do op and a list
    of gradient blobs corresponding to forward op's outputs.
    Returns a gradient op and a list of blobs corresponding to input gradients.
    """
    from caffe2.python.core import BlobReference
    subnet, outer_to_inner_map, inner_to_outer_map, workspace_blob_name = \
        _do_op_sanity_check_and_process(op)
    assert len(g_output) == len(op.output), \
        "Different number of gradient blobs and Do op outputs"

    grad_ops, deduped_g_output = dedupe_g_output(op, g_output)
    g_output = deduped_g_output
    # From the outer net's point of view, Do's gradient op writes into the
    # input gradient blobs; from the inner net's point of view, we forward
    # Do's output gradients into the inner workspace, run backward pass
    # generation there, and surface the computed input gradients back outside.

    op_output = [str(o) for o in op.output]
    op_output = op_output[:-1]  # remove workspace pointer blob
    op_input = [str(i) for i in op.input]
    op_input = op_input[:-1]  # remove workspace pointer blob

    ordered_inner_output_blob_names = [outer_to_inner_map[o] for o in op_output]
    backward_pass_initial_grad_map = {}
    initial_grad_map = {}
    for inner_output_name, outer_grad_output_name in \
            zip(ordered_inner_output_blob_names, g_output):
        # link inner_output_name to the corresponding inner_grad_output_name
        # for backward pass generation
        if outer_grad_output_name:
            inner_grad_output_name = inner_output_name + "/_DO_OPERATOR_INNER_GRAD_"
            backward_pass_initial_grad_map[BlobReference(inner_output_name)] = \
                BlobReference(inner_grad_output_name)
            initial_grad_map[inner_grad_output_name] = str(outer_grad_output_name)
    assert len(initial_grad_map) > 0, "Empty initial gradient map for Do op"

    inner_grad_ops, inner_grad_names_map = _gen_subgradient_pass(
        subnet, backward_pass_initial_grad_map)
    if len(inner_grad_ops) == 0:
        return [], []

    grad_copy_ops = []
    g_input = []
    new_op_outputs = []
    new_blob_bindings = {}
    for outer_input_name in op_input:
        inner_input_name = outer_to_inner_map[outer_input_name]
        if inner_input_name in inner_grad_names_map:
            inner_grad_input_name = inner_grad_names_map[inner_input_name]
            outer_grad_input_name = outer_input_name + "_grad"
            # The computed inner gradient blob may need to be exposed under a
            # new outer name; copy it into a dedicated blob first, since the
            # blob binding scheme requires a bijection between inner blob
            # names and outer (Do's input and output) blob names.
            # TODO(iliacher): Remove unnecessary blob copying
            new_inner_grad_input_name = \
                inner_input_name + "/_DO_OPERATOR_INNER_GRAD_COPY_"
            grad_copy_ops.append(_prepare_blob_copy_op(
                inner_grad_input_name, new_inner_grad_input_name))

            new_blob_bindings[new_inner_grad_input_name] = outer_grad_input_name
            new_op_outputs.append(outer_grad_input_name)
            g_input.append(outer_grad_input_name)
        else:
            g_input.append(None)

    new_op_inputs = []
    overwritten_names = set()
    saved_local_blob_names = set()
    for grad_op in inner_grad_ops:
        grad_op_input = [str(i) for i in grad_op.input]
        grad_op_output = [str(o) for o in grad_op.output]
        for grad_op_input_name in grad_op_input:
            if grad_op_input_name in overwritten_names:
                continue
            # check if this is an external blob
            outer_name = inner_to_outer_map.get(grad_op_input_name, None)
            if not outer_name:
                # check if this is an external gradient blob
                outer_name = initial_grad_map.get(grad_op_input_name, None)
            if outer_name:
                outer_name = str(outer_name)
                if outer_name not in new_op_inputs:
                    new_op_inputs.append(outer_name)
                new_blob_bindings[grad_op_input_name] = outer_name
            else:
                # this is a local blob, its value comes from the saved
                # forward op workspace
                saved_local_blob_names.add(grad_op_input_name)
        overwritten_names.update(grad_op_output)
    inner_grad_ops += grad_copy_ops
    gradient_do_def = _prepare_gradient_do_op(
        fwd_op=op,
        fwd_net=subnet,
        grad_ops=inner_grad_ops,
        inputs=new_op_inputs,
        outputs=new_op_outputs,
        blob_bindings=new_blob_bindings,
        saved_fwd_blobs=saved_local_blob_names,
        workspace_blob_name=workspace_blob_name)
    grad_ops.append(gradient_do_def)

    _do_op_sanity_check_and_process(gradient_do_def)

    return grad_ops, g_input

def dedupe_g_output(op, g_output):
    # When generating a gradient op it's possible to receive the same gradient
    # blob for different forward op output blobs; the Do operator requires a
    # bijection between inner and outer names, so deduplicate here
    grad_ops = []
    deduped_g_output = []
    init_grad_map = {}
    for output_name, grad_name in zip(op.output, g_output):
        if not grad_name:
            deduped_g_output.append(grad_name)
            continue

        if output_name in init_grad_map:
            deduped_g_output.append(init_grad_map[output_name])
        else:
            if grad_name not in init_grad_map.values():
                init_grad_map[output_name] = grad_name
                deduped_g_output.append(grad_name)
            else:
                deduped_grad_name = output_name + "_" + grad_name + "_DEDUP"
                assert deduped_grad_name not in init_grad_map.values()
                grad_copy_op = caffe2_pb2.OperatorDef()
                grad_copy_op.type = "Copy"
                grad_copy_op.input.extend([grad_name])
                grad_copy_op.output.extend([deduped_grad_name])
                grad_ops.append(grad_copy_op)
                deduped_g_output.append(deduped_grad_name)
                init_grad_map[output_name] = deduped_grad_name
    return grad_ops, deduped_g_output
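
# A minimal sketch of the dedup above (hypothetical blob names): when the same
# gradient blob arrives for two different outputs, the second occurrence is
# rerouted through a Copy op so names stay bijective:
#
#   op.output = ["x", "y"]; g_output = ["g", "g"]
#   grad_ops, deduped = dedupe_g_output(op, g_output)
#   # deduped == ["g", "y_g_DEDUP"]; grad_ops holds one Copy: "g" -> "y_g_DEDUP"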

def gen_while_gradient(op, g_output):
    """
    Generates gradient While operator
    """
    from caffe2.python.core import BlobReference
    assert op.type == "While", "Expected While op"
    # first input is the condition blob
    assert len(op.input) > 0, "Expected at least one input in While op"

    assert len(op.output) == len(g_output), \
        "Different number of gradient blobs and While op outputs"

    grad_ops, deduped_g_output = dedupe_g_output(op, g_output)
    g_output = deduped_g_output

    init_grad_map = {}
    op_output = [str(o) for o in op.output]
    for output_name, grad_output_name in zip(op_output, g_output):
        if grad_output_name:
            init_grad_map[BlobReference(output_name)] = \
                BlobReference(grad_output_name)
    assert len(init_grad_map) > 0, "Empty initial gradient map for While op"

    loop_net = _get_net_argument(op, "loop_net")
    assert loop_net, "Expected loop subnet in While op"
    assert len(loop_net.op) == 1 and loop_net.op[0].type == "Do", \
        "Gradient While op requires single Do op as a loop body"
    do_op = loop_net.op[0]
    do_args = _get_do_arguments(do_op)
    assert "reuse_workspace" not in do_args or not do_args["reuse_workspace"], \
        "Gradient While op requires Do loop body op without reuse_workspace set"

    assert len(do_op.output) > 0, "Expected Do op with at least one output"
    workspace_blob = do_op.output[-1]

    loop_grad_net, loop_grad_map, loop_input_names, loop_output_names = \
        _gen_subnet_gradient(loop_net, init_grad_map)
    assert loop_grad_net, "Failed to get gradient net for loop body in While op"

    grad_ops += _prepare_gradient_while_ops(
        fwd_op=op,
        input_names=loop_input_names,
        output_names=loop_output_names,
        loop_grad_net=loop_grad_net,
        workspace_blob=workspace_blob,
        init_grad_map=init_grad_map,
        loop_grad_map=loop_grad_map)

    op_input = [str(i) for i in op.input]
    g_input = [loop_grad_map.get(i, None) for i in op_input]
    return grad_ops, g_input

# Constructs gradient While op, arguments:
#  fwd_op - forward While op
#  input_names - input blob names for the gradient op
#  output_names - output blob names for the gradient op
#  loop_grad_net - gradient loop body net
#  workspace_blob - blob that holds the stack of forward workspaces
#  init_grad_map - initial gradient to forward blob map
#  loop_grad_map - gradient blob map for the loop's body
def _prepare_gradient_while_ops(
        fwd_op, input_names, output_names, loop_grad_net, workspace_blob,
        init_grad_map, loop_grad_map):
    gradient_while_def = caffe2_pb2.OperatorDef()
    gradient_while_def.CopyFrom(fwd_op)
    if gradient_while_def.name:
        gradient_while_def.name += "_grad"

    loop_net_arg = caffe2_pb2.Argument()
    loop_net_arg.name = "loop_net"
    loop_net_arg.n.CopyFrom(loop_grad_net)

    cond_net_arg = caffe2_pb2.Argument()
    cond_net_arg.name = "cond_net"

    from caffe2.python.core import Net, BlobReference
    # construct the condition net - check that there are still forward
    # workspaces left on the stack using the HasScope op
    cond_net = Net('gradient_loop_cond_net')
    cond_init_net = Net('gradient_loop_cond_net_init')
    cond_blob = cond_net.NextScopedBlob(cond_net.Name() + '/cond')
    cond_init_net.HasScope(workspace_blob, cond_blob)
    cond_net.HasScope(workspace_blob, cond_blob)
    for blob, init_grad_blob in init_grad_map.items():
        blob_name = str(blob)
        init_grad_blob_name = str(init_grad_blob)
        if blob_name in loop_grad_map and \
                loop_grad_map[blob_name] != init_grad_blob_name:
            cond_net.Copy(
                BlobReference(loop_grad_map[blob_name]), init_grad_blob)
            cond_init_net.Copy(
                init_grad_blob, BlobReference(loop_grad_map[blob_name]))
    cond_net_arg.n.CopyFrom(cond_net.Proto())

    del gradient_while_def.arg[:]
    gradient_while_def.arg.extend([loop_net_arg, cond_net_arg])

    del gradient_while_def.control_input[:]
    del gradient_while_def.input[:]
    gradient_while_def.input.extend(
        [str(cond_blob).encode('utf-8')] + list(input_names))
    del gradient_while_def.output[:]
    gradient_while_def.output.extend(output_names)
    gradient_while_def.is_gradient_op = True
    return [o for o in cond_init_net.Proto().op] + [gradient_while_def]
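
# Rough shape of what the helper above emits (illustrative, not an exact
# proto): the forward While is cloned, its nets are swapped for gradient nets,
# and the new cond_net queries the saved-workspace stack via HasScope, so the
# gradient loop runs once per recorded forward iteration:
#
#   While(input=[cond_blob, <grad inputs>], output=[<grad outputs>],
#         loop_net=<loop body gradient net>,
#         cond_net=<HasScope(workspace_blob) -> cond_blob>)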

def _get_do_arguments(do_op):
    assert do_op.type == "Do", "Expected Do op"
    args = {}
    for arg in do_op.arg:
        if not arg.name:
            continue
        if arg.name == "net":
            assert arg.n, "Expected non empty net argument"
            args["net"] = arg.n
        elif arg.name == "reuse_workspace":
            assert arg.i, "Expected non empty reuse_workspace argument"
            args["reuse_workspace"] = bool(arg.i)
        elif arg.name == "inner_blobs":
            assert arg.strings, "Expected non empty inner_blobs argument"
            args["inner_blobs"] = arg.strings
        elif arg.name == "outer_blobs_idx":
            assert arg.ints, "Expected non empty outer_blobs_idx argument"
            args["outer_blobs_idx"] = arg.ints
    return args

def gen_if_gradient(op, g_output):
    """
    Generates gradient If operator, given forward If op and a list
    of gradient blobs corresponding to forward op's outputs.
    Returns a gradient op and a list of blobs corresponding to input gradients.
    """
    from caffe2.python.core import BlobReference
    assert op.type == "If", "Expected If op"
    # first input is the condition blob
    assert len(op.input) > 0, "Expected at least one input in If op"

    assert len(op.output) == len(g_output), \
        "Different number of gradient blobs and If op outputs"

    grad_ops, deduped_g_output = dedupe_g_output(op, g_output)
    g_output = deduped_g_output

    init_grad_map = {}  # map from If's output blob to output gradient blob
    op_input = [str(i) for i in op.input]
    op_output = [str(o) for o in op.output]
    for output_name, grad_output_name in zip(op_output, g_output):
        if grad_output_name:
            init_grad_map[BlobReference(output_name)] = \
                BlobReference(grad_output_name)
    # shouldn't call without at least one output gradient available
    assert len(init_grad_map) > 0, "Empty initial gradient map for If op"

    grad_map = {}  # map from blob to gradient blob
    then_net = _get_net_argument(op, "then_net")
    assert then_net, "Expected then subnet in If op"
    then_grad_net, then_grad_map, then_input_names, then_output_names = \
        _gen_subnet_gradient(then_net, init_grad_map)
    assert then_grad_net, "Failed to get gradient net for then in If op"
    grad_map.update(then_grad_map)

    else_input_names = set()
    else_output_names = set()
    else_grad_map = {}
    else_grad_net = None
    else_net = _get_net_argument(op, "else_net")
    if else_net:
        else_grad_net, else_grad_map, else_input_names, else_output_names = \
            _gen_subnet_gradient(else_net, init_grad_map)
        assert else_grad_net, "Failed to get gradient net for else in If op"
        for else_blob, else_grad_blob in else_grad_map.items():
            if else_blob in then_grad_map:
                then_grad_blob = then_grad_map[else_blob]
                # if the branches map the same blob to different grad blob
                # names, one branch didn't use the blob and kept the original
                # grad blob name in its grad map
                if then_grad_blob != else_grad_blob:
                    init_grad_name = init_grad_map[else_blob] \
                        if else_blob in init_grad_map else None

                    if then_grad_blob == init_grad_name:
                        grad_map[else_blob] = else_grad_blob
                    elif else_grad_blob == init_grad_name:
                        grad_map[else_blob] = then_grad_blob
                    else:
                        raise Exception(
                            "Unexpected grad blob name " + else_blob + ", " +
                            else_grad_blob + ", " + then_grad_blob)
            else:
                grad_map[else_blob] = else_grad_blob

    # make sure gradients of blobs that were not computed
    # by the selected If branch are initialized with zeros
    then_other_output_names = \
        then_output_names - (then_output_names & else_output_names)
    then_other_grad_output_names = set(
        [o for o in then_other_output_names if o in then_grad_map.values()])
    zero_then = _gen_grad_zero_init_ops(
        init_grad_map, then_grad_map, then_other_grad_output_names)
    if else_grad_net:
        else_grad_net.op.extend(zero_then)
    elif len(zero_then) > 0:
        # no else net: add a new net that initializes gradients with zeros
        else_grad_net = caffe2_pb2.NetDef()
        else_grad_net.CopyFrom(then_grad_net)
        if else_grad_net.name:
            else_grad_net.name += "_auto_else_zero_blobs_"
        del else_grad_net.op[:]
        else_grad_net.op.extend(zero_then)
        del else_grad_net.external_input[:]
        del else_grad_net.external_output[:]

    else_other_output_names = \
        else_output_names - (then_output_names & else_output_names)
    else_other_grad_output_names = set(
        [o for o in else_other_output_names if o in else_grad_map.values()])
    zero_else = _gen_grad_zero_init_ops(
        init_grad_map, else_grad_map, else_other_grad_output_names)
    then_grad_net.op.extend(zero_else)

    output_names = list(then_output_names | else_output_names)
    input_names = then_input_names | else_input_names
    # make sure the condition blob is the first in the list
    input_names = [op_input[0]] + list(input_names - {op_input[0]})
    gradient_if_def = _prepare_gradient_if_op(
        fwd_op=op,
        input_names=input_names,
        output_names=output_names,
        then_grad_net=then_grad_net,
        else_grad_net=else_grad_net)
    g_input = [grad_map.get(i, None) for i in op_input]
    return grad_ops + [gradient_if_def], g_input
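
# Calling convention shared by gen_do_gradient, gen_while_gradient and
# gen_if_gradient (sketch, hypothetical names): the caller supplies one
# gradient blob name (or None) per forward output and receives the gradient
# ops plus one gradient blob name (or None) per forward input:
#
#   grad_ops, g_input = gen_if_gradient(if_op, ["y_grad"])
#   # g_input[i] is the gradient for if_op.input[i], or None if no gradient
#   # flows into that input (e.g. the condition blob)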

def _gen_subnet_gradient(subnet, init_grad):
    grad_ops, grad_names_map = _gen_subgradient_pass(
        subnet, init_grad)

    output_names = set()
    input_names = set()
    for grad_op in grad_ops:
        for grad_op_input in grad_op.input:
            if str(grad_op_input) not in output_names:
                input_names.add(str(grad_op_input))
        for grad_op_output in grad_op.output:
            output_names.add(str(grad_op_output))

    gradient_net_def = caffe2_pb2.NetDef()
    gradient_net_def.CopyFrom(subnet)
    if gradient_net_def.name:
        gradient_net_def.name += "_grad"
    del gradient_net_def.op[:]
    gradient_net_def.op.extend(grad_ops)
    del gradient_net_def.external_input[:]
    del gradient_net_def.external_output[:]

    return gradient_net_def, grad_names_map, input_names, output_names

def _get_net_argument(op, net_name):
    for arg in op.arg:
        if arg.name and arg.name == net_name:
            assert arg.n, "Expected non empty net argument " + net_name
            return arg.n
    return None

def _gen_subgradient_pass(subnet, init_grad):
    from caffe2.python.core import IR
    subnet_ir = IR(subnet.op)
    grad_ops, grad_blob_map = \
        subnet_ir.GetBackwardPass(init_grad)
    grad_names_map = {}
    for b, g in grad_blob_map.items():
        grad_names_map[str(b)] = str(g)
    return grad_ops, grad_names_map
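
# _gen_subgradient_pass delegates to core.IR.GetBackwardPass, which takes a
# map from forward blobs to their output gradient blobs and returns gradient
# operators plus a blob-to-gradient map. A rough sketch (hypothetical blobs):
#
#   from caffe2.python.core import BlobReference
#   init_grad = {BlobReference("y"): BlobReference("y_grad")}
#   grad_ops, grad_names_map = _gen_subgradient_pass(subnet, init_grad)
#   # grad_names_map might look like {"x": "x_grad", "w": "w_grad"}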

def _do_op_sanity_check_and_process(op):
    assert op.type == "Do", "Expected Do op"

    subnet = _get_net_argument(op, "net")
    assert subnet, "No net argument found in Do op"

    inner_blobs = None
    outer_blobs_idx = None
    for arg in op.arg:
        if arg.name and arg.name == "inner_blobs":
            assert not inner_blobs, "inner_blobs redefinition"
            assert arg.strings and len(arg.strings) > 0, \
                "Empty inner_blobs argument in Do op"
            inner_blobs = [s.decode('utf-8') for s in arg.strings]
        if arg.name and arg.name == "outer_blobs_idx":
            assert not outer_blobs_idx, "outer_blobs_idx redefinition"
            assert arg.ints and len(arg.ints) > 0, \
                "Empty outer_blobs_idx argument in Do op"
            outer_blobs_idx = arg.ints
        if inner_blobs and outer_blobs_idx:
            break

    assert inner_blobs, "No inner_blobs argument found in Do op"
    assert outer_blobs_idx, "No outer_blobs_idx argument found in Do op"

    assert len(inner_blobs) == len(outer_blobs_idx), \
        "Arguments inner_blobs and outer_blobs_idx of different length in Do op"

    all_inner_blobs = set(inner_blobs)
    assert len(all_inner_blobs) == len(inner_blobs), \
        "Found duplicates in inner_blobs in Do op"

    op_input = [str(i) for i in op.input]
    assert len(op_input) > 0, "Expected at least one input blob"
    # remove the last input blob that holds the pointer to the workspace
    input_workspace_blob_name = op_input[-1]
    op_input = op_input[:-1]

    op_output = [str(o) for o in op.output]
    assert len(op_output) > 0, "Expected at least one output blob"
    # remove the last output blob that holds the pointer to the workspace
    workspace_blob_name = op_output[-1]
    assert input_workspace_blob_name == workspace_blob_name, \
        "Expected same input/output workspace blob"
    op_output = op_output[:-1]

    all_op_input_blob_names = set(op_input)
    assert len(all_op_input_blob_names) == len(op_input), \
        "Found duplicates in Do op inputs"
    all_op_output_blob_names = set(op_output)
    assert len(all_op_output_blob_names) == len(op_output), \
        "Found duplicates in Do op outputs"

    ordered_outer_blob_names = op_input + op_output
    all_outer_blob_names = set(ordered_outer_blob_names)
    used_outer_blob_names = set()
    outer_to_inner_map = {}
    inner_to_outer_map = {}
    for inner_name, outer_blob_idx in zip(inner_blobs, outer_blobs_idx):
        assert outer_blob_idx >= 0 and \
            outer_blob_idx < len(ordered_outer_blob_names), \
            "Outer blob index is out of bounds in Do op"
        outer_name = ordered_outer_blob_names[outer_blob_idx]
        assert outer_name not in used_outer_blob_names, \
            "Reuse of outer blob name " + outer_name + " in Do op"
        used_outer_blob_names.add(outer_name)
        outer_to_inner_map[outer_name] = inner_name
        inner_to_outer_map[inner_name] = outer_name

    assert len(used_outer_blob_names) == len(all_outer_blob_names), \
        "Not all outer blob names are used in blob bindings in Do op"

    return subnet, outer_to_inner_map, inner_to_outer_map, workspace_blob_name

def _prepare_blob_copy_op(from_name, to_name):
    copy_op_def = caffe2_pb2.OperatorDef()
    copy_op_def.type = "Copy"
    copy_op_def.input.extend([from_name])
    copy_op_def.output.extend([to_name])
    return copy_op_def
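
# For instance, _prepare_blob_copy_op("a", "b") returns an OperatorDef
# equivalent to the text proto:
#
#   type: "Copy"
#   input: "a"
#   output: "b"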

def _prepare_gradient_do_op(
        fwd_op, fwd_net, grad_ops, inputs, outputs, blob_bindings, saved_fwd_blobs,
        workspace_blob_name):
    gradient_net_def = caffe2_pb2.NetDef()
    gradient_net_def.CopyFrom(fwd_net)
    if gradient_net_def.name:
        gradient_net_def.name += "_grad"
    del gradient_net_def.op[:]
    gradient_net_def.op.extend(grad_ops)
    del gradient_net_def.external_input[:]
    del gradient_net_def.external_output[:]

    gradient_do_def = caffe2_pb2.OperatorDef()
    gradient_do_def.CopyFrom(fwd_op)
    if gradient_do_def.name and len(gradient_do_def.name) > 0:
        gradient_do_def.name += "_grad"

    del gradient_do_def.input[:]
    gradient_do_def.input.extend(inputs)
    # workspace pointer blob
    gradient_do_def.input.append(workspace_blob_name)
    del gradient_do_def.output[:]
    gradient_do_def.output.extend(outputs)
    # workspace pointer blob
    gradient_do_def.output.append(workspace_blob_name)

    net_arg = caffe2_pb2.Argument()
    net_arg.name = "net"
    net_arg.n.CopyFrom(gradient_net_def)

    ordered_new_outer_names = inputs + outputs
    inner_blobs = blob_bindings.keys()
    new_outer_blobs_idx = [ordered_new_outer_names.index(blob_bindings[b])
                           for b in inner_blobs]

    inner_blobs_arg = caffe2_pb2.Argument()
    inner_blobs_arg.name = "inner_blobs"
    inner_blobs_arg.strings.extend([b.encode('utf-8') for b in inner_blobs])

    outer_blobs_idx_arg = caffe2_pb2.Argument()
    outer_blobs_idx_arg.name = "outer_blobs_idx"
    outer_blobs_idx_arg.ints.extend(new_outer_blobs_idx)

    saved_blobs_arg = caffe2_pb2.Argument()
    saved_blobs_arg.name = "saved_fwd_blobs"
    saved_blobs_arg.strings.extend(
        [b.encode('utf-8') for b in saved_fwd_blobs])

    del gradient_do_def.arg[:]
    gradient_do_def.arg.extend([
        net_arg, inner_blobs_arg, outer_blobs_idx_arg, saved_blobs_arg])
    del gradient_do_def.control_input[:]

    gradient_do_def.is_gradient_op = True

    return gradient_do_def

def _gen_grad_zero_init_ops(init_grad_map, grad_map, grad_output_names):
    grad_init_ops = []
    for grad_output in grad_output_names:
        # get the corresponding output blob and use it in ConstantFill so
        # that grad_output has the same shape
        output_name = None
        for o, g in grad_map.items():
            if g == grad_output:
                output_name = o
                break
        assert output_name, "Unknown gradient output " + grad_output

        grad_init_op = None
        # make sure we do not overwrite an existing gradient with zeros
        if output_name in init_grad_map:
            init_grad_name = init_grad_map[output_name]
            # in case a different gradient blob name is used, copy the gradient
            if init_grad_name != grad_output:
                grad_init_op = caffe2_pb2.OperatorDef()
                grad_init_op.type = "Copy"
                grad_init_op.input.extend([str(init_grad_name)])
                grad_init_op.output.extend([str(grad_output)])
        else:
            grad_init_op = caffe2_pb2.OperatorDef()
            grad_init_op.type = "ConstantFill"
            grad_init_op.input.extend([output_name])
            grad_init_op.output.extend([grad_output])
            value_arg = caffe2_pb2.Argument()
            value_arg.name = "value"
            value_arg.f = 0.0
            grad_init_op.arg.extend([value_arg])

        if grad_init_op:
            grad_init_ops.append(grad_init_op)
    return grad_init_ops
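
# A minimal sketch of the zero-init behavior above (hypothetical names): when
# a gradient blob is produced by only one branch and no initial gradient
# exists for it, the other branch gets a zero filler shaped like the forward
# blob:
#
#   ops = _gen_grad_zero_init_ops({}, {"y": "y_grad"}, ["y_grad"])
#   # ops == [ConstantFill(["y"] -> ["y_grad"], value=0.0)]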

def _prepare_gradient_if_op(
        fwd_op, input_names, output_names, then_grad_net, else_grad_net):
    gradient_if_def = caffe2_pb2.OperatorDef()
    gradient_if_def.CopyFrom(fwd_op)
    del gradient_if_def.input[:]
    gradient_if_def.input.extend(input_names)
    del gradient_if_def.output[:]
    gradient_if_def.output.extend(output_names)

    then_net_arg = caffe2_pb2.Argument()
    then_net_arg.name = "then_net"
    then_net_arg.n.CopyFrom(then_grad_net)
    gradient_args = [then_net_arg]
    if else_grad_net:
        else_net_arg = caffe2_pb2.Argument()
        else_net_arg.name = "else_net"
        else_net_arg.n.CopyFrom(else_grad_net)
        gradient_args.append(else_net_arg)

    del gradient_if_def.arg[:]
    gradient_if_def.arg.extend(gradient_args)
    if gradient_if_def.name:
        gradient_if_def.name += "_grad"
    del gradient_if_def.control_input[:]
    gradient_if_def.is_gradient_op = True
    return gradient_if_def