from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from collections import namedtuple, OrderedDict, defaultdict
from past.builtins import basestring
from future.utils import viewitems, viewkeys, viewvalues
from itertools import chain
from six import binary_type, string_types, text_type

from caffe2.proto import caffe2_pb2
from caffe2.python import scope, utils, workspace
from caffe2.python.control_ops_grad import \
    gen_do_gradient, gen_if_gradient, gen_while_gradient

import caffe2.python._import_c_extension as C

import os
import pickle
import sys
import traceback

import numpy as np


if (sys.platform == 'darwin' and 'leveldb' in C.registered_dbs()):
    print('If you are using homebrew leveldb on a Mac OS, you might see an '
          'error warning you that malloc_zone_unregister() failed. This is '
          'not a caffe2 issue but is due to the homebrew leveldb having an '
          'incompatible memory allocator. It does not affect usage.')


DeviceScope = scope.DeviceScope
NameScope = scope.NameScope


# Bring the datatype enums into the main namespace.
class DataType:
    pass


def _InitDataType():
    for name, value in caffe2_pb2.TensorProto.DataType.items():
        setattr(DataType, name, value)


_InitDataType()


def _GetRegisteredOperators():
    return set(workspace.RegisteredOperators())


_REGISTERED_OPERATORS = _GetRegisteredOperators()


def RefreshRegisteredOperators():
    global _REGISTERED_OPERATORS
    _REGISTERED_OPERATORS = _GetRegisteredOperators()


_GLOBAL_INIT_ARGS = []


def GlobalInit(args):
    _GLOBAL_INIT_ARGS.extend(args[1:])
    C.global_init(args)


def GetGlobalInitArgs():
    # Return a copy of the global init args, so callers cannot mutate them.
    return _GLOBAL_INIT_ARGS[:]


def IsOperator(op_type):
    return IsOperatorWithEngine(op_type, engine='DEFAULT')


def IsOperatorWithEngine(op_type, engine):
    return C.op_registry_key(op_type, engine) in _REGISTERED_OPERATORS


def DeviceOption(device_type, cuda_gpu_id=0, random_seed=None, node_name=None):
    option = caffe2_pb2.DeviceOption()
    option.device_type = device_type
    option.cuda_gpu_id = cuda_gpu_id
    if node_name is not None:
        option.node_name = node_name
    if random_seed is not None:
        option.random_seed = random_seed
    return option


def device_option_equal(opt1, opt2, ignore_node_name=True,
                        ignore_random_seed=True):
    if not opt1 or not opt2:
        return opt1 == opt2
    if not ignore_node_name and opt1.node_name != opt2.node_name:
        return False
    if not ignore_random_seed and opt1.random_seed != opt2.random_seed:
        return False
    if not opt1.device_type or not opt2.device_type:
        # At least one option is for CPU; check whether both are.
        return not opt1.device_type and not opt2.device_type
    return opt1.cuda_gpu_id == opt2.cuda_gpu_id
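
# Illustrative usage sketch (an editor's example, not original source): an
# unset device_type is treated as CPU by device_option_equal, so a default
# DeviceOption compares equal to an explicit CPU option even though protobuf
# equality would say otherwise.
#
#   cpu_opt = DeviceOption(caffe2_pb2.CPU)
#   gpu_opt = DeviceOption(caffe2_pb2.CUDA, cuda_gpu_id=1)
#   assert not device_option_equal(cpu_opt, gpu_opt)
#   assert device_option_equal(cpu_opt, caffe2_pb2.DeviceOption())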


def InferBlobDevices(net):
    """
    Compute a mapping from blob names to device options by looking at the
    device option of the op that creates each blob.
    """
    mapping = {}
    for op in net.Proto().op:
        op_device = op.device_option
        if op_device is None:
            op_device = caffe2_pb2.DeviceOption(caffe2_pb2.CPU)
        for b in op.output:
            mapping[b] = op_device
    return mapping


def InferOpBlobDevices(op):
    device_info = C.infer_op_input_output_device(op.SerializeToString())
    input_info = []
    output_info = []
    for dev_str in device_info[0]:
        device_option = caffe2_pb2.DeviceOption()
        device_option.ParseFromString(dev_str)
        input_info.append(device_option)
    for dev_str in device_info[1]:
        device_option = caffe2_pb2.DeviceOption()
        device_option.ParseFromString(dev_str)
        output_info.append(device_option)
    return input_info, output_info


def InferOpDeviceAsBlobDevices(op):
    op_dev = op.device_option if op.device_option else caffe2_pb2.DeviceOption()
    input_dev = [op_dev] * len(op.input)
    output_dev = [op_dev] * len(op.output)
    return input_dev, output_dev


GradientSlice = namedtuple('GradientSlice', ['indices', 'values'])
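
# Illustrative usage sketch (an editor's example, not original source): sparse
# gradients are represented throughout this module as a GradientSlice, a pair
# of blobs holding the touched row indices and the corresponding value rows.
#
#   g = GradientSlice(indices='w_grad_indices', values='w_grad_values')
#   assert g.indices == 'w_grad_indices' and g.values == 'w_grad_values'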


class BlobReference(object):
    """A wrapper around a blob in a net.

    BlobReference gives us a way to refer to the network that the blob is
    generated from. Note that blobs are, essentially, just strings in the
    current workspace.
    """

    def __init__(self, name, net=None):
        """Initializes a blob reference.

        Note that this does not prepend the namescope. If needed, use
        ScopedBlobReference() to prepend the existing namespace.
        """
        if isinstance(name, string_types):
            self._name = name
        elif isinstance(name, binary_type):
            self._name = name.decode('utf-8')
        else:
            self._name = str(name)
        self._from_net = net
        # meta allows helper functions to store whatever meta-information
        # they need here.
        self.meta = {}

    def __hash__(self):
        return hash(self._name)

    def __eq__(self, other):
        if isinstance(other, string_types):
            return self._name == other
        elif isinstance(other, binary_type):
            return self._name == other.decode('utf-8')
        elif isinstance(other, BlobReference):
            return self._name == other._name
        else:
            return False

    def __ne__(self, other):
        return not (self == other)

    def __str__(self):
        return self._name

    def __repr__(self):
        return 'BlobReference("{}")'.format(self._name)

    def __add__(self, other):
        if not isinstance(other, string_types):
            raise RuntimeError('Cannot add BlobReference to a non-string.')
        return BlobReference(self._name + other, self._from_net)

    def __radd__(self, other):
        if not isinstance(other, string_types):
            raise RuntimeError('Cannot add a non-string to BlobReference.')
        return BlobReference(other + self._name, self._from_net)

    def Net(self):
        return self._from_net

    def GetNameScope(self):
        return self._name[:self._name.rfind(scope._NAMESCOPE_SEPARATOR) + 1]

    def _CreateAndAddToNet(self, op_type, inputs=None, *args, **kwargs):
        """Internal function that routes the operator generation to the
        network's __getattr__ function.
        """
        inputs = [] if inputs is None else inputs
        if isinstance(inputs, BlobReference) or \
                isinstance(inputs, string_types):
            inputs = [inputs]
        # Add self to the input list.
        inputs.insert(0, self)
        return self._from_net.__getattr__(op_type)(inputs, *args, **kwargs)

    def __getattr__(self, op_type):
        """A wrapper allowing one to initiate operators from a blob reference.

        Example: for a blob reference b that comes from network n, doing
            b.Relu(...)
        is equivalent to doing
            n.Relu([b], ...)
        """
        if op_type.startswith('__'):
            raise AttributeError('Attribute {} not found.'.format(op_type))
        if self._from_net is None:
            raise AttributeError(
                'You cannot use a blob reference that does not have a net '
                'source to create operators. Create the operator from an '
                'explicit net object.')
        if not IsOperator(op_type):
            raise AttributeError(
                'Method ' + op_type + ' is not a registered operator.' +
                ' Did you mean: [' +
                ",".join(workspace.C.nearby_opnames(op_type)) + ']'
            )
        return lambda *args, **kwargs: self._CreateAndAddToNet(
            op_type, *args, **kwargs)

    def __dir__(self):
        additional_methods = [
            op
            for op in _REGISTERED_OPERATORS
            if '_ENGINE_' not in op or '_ENGINE_CUDNN' in op]
        return sorted(set(chain(
            dir(type(self)),
            viewkeys(self.__dict__),
            additional_methods
        )))
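
# Illustrative usage sketch (an editor's example, not original source): the
# two calls below append the same Relu op to the net the blob came from. Blob
# and net names are hypothetical.
#
#   n = Net('example')
#   x = n.AddExternalInput('x')
#   y1 = n.Relu([x], 'y1')    # explicit form
#   y2 = x.Relu([], 'y2')     # BlobReference.__getattr__ routes to n.Relu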


def ScopedName(name):
    """prefix the name with the current scope."""
    if isinstance(name, binary_type):
        name = name.decode('ascii')
    return scope.CurrentNameScope() + name


def ScopedBlobReference(name, *args, **kwargs):
    """Returns a blob reference with scope prefixed."""
    return BlobReference(ScopedName(name), *args, **kwargs)


def _RectifyInputOutput(blobs, net=None):
    """A helper function to rectify the input or output of the CreateOperator
    interface.
    """
    if isinstance(blobs, string_types) or isinstance(blobs, binary_type):
        # If blobs is a single string, prepend the current scope and wrap it
        # in a list.
        return [ScopedBlobReference(blobs, net=net)]
    elif type(blobs) is BlobReference:
        # If blobs is a BlobReference, simply wrap it in a list.
        return [blobs]
    elif type(blobs) in (list, tuple):
        # If blobs is a list, go through it and type-check every element.
        rectified = []
        for blob in blobs:
            if isinstance(blob, string_types) or \
                    isinstance(blob, binary_type):
                rectified.append(ScopedBlobReference(blob, net=net))
            elif type(blob) is BlobReference:
                rectified.append(blob)
            else:
                raise TypeError(
                    "I/O blob #{} of unsupported type: {} of type {}"
                    .format(len(rectified), str(blob), type(blob)))
        return rectified
    else:
        raise TypeError(
            "Unknown input/output type: %s of type %s." %
            (str(blobs), type(blobs))
        )


def CreateOperator(
    operator_type,
    inputs,
    outputs,
    name='',
    control_input=None,
    device_option=None,
    arg=None,
    engine=None,
    **kwargs
):
    """A function wrapper that allows one to create operators based on the
    operator type. The type should be a string corresponding to an operator
    registered with Caffe2.
    """
    operator = caffe2_pb2.OperatorDef()
    if os.environ.get('CAFFE2_DEBUG'):
        stack = traceback.format_stack()
        operator.debug_info = "".join(stack[:-1])

    operator.type = operator_type
    operator.name = name
    # Add rectified inputs and outputs.
    inputs = _RectifyInputOutput(inputs)
    outputs = _RectifyInputOutput(outputs)
    operator.input.extend([text_type(i) for i in inputs])
    operator.output.extend([text_type(o) for o in outputs])
    if control_input:
        control_input = _RectifyInputOutput(control_input)
        operator.control_input.extend([text_type(i) for i in control_input])
    # Set the device option:
    # (1) if device_option is explicitly set, use it;
    # (2) otherwise, if scope.CurrentDeviceScope() is set, use that;
    # (3) otherwise, leave the device option unset.
    if device_option is not None:
        operator.device_option.CopyFrom(device_option)
    elif scope.CurrentDeviceScope() is not None:
        operator.device_option.CopyFrom(scope.CurrentDeviceScope())
    if engine is not None:
        operator.engine = engine
    # The random seed lives in the device option, so it needs special care.
    if 'random_seed' in kwargs:
        operator.device_option.random_seed = kwargs['random_seed']
        del kwargs['random_seed']
    # Add given arguments that do not need parsing.
    if arg is not None:
        operator.arg.extend(arg)
    # Add all other arguments.
    for key, value in viewitems(kwargs):
        operator.arg.add().CopyFrom(utils.MakeArgument(key, value))

    if workspace.IsImmediate():
        workspace.RunOperatorImmediate(operator)
    return operator
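
# Illustrative usage sketch (an editor's example, not original source):
# building a standalone FC OperatorDef. Unrecognized keyword arguments (here,
# `axis`) become operator `arg` entries.
#
#   op = CreateOperator(
#       "FC", ["x", "w", "b"], ["y"],
#       device_option=DeviceOption(caffe2_pb2.CPU),
#       axis=1,
#   )
#   assert op.type == "FC" and op.arg[0].name == "axis"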


def _RegisterPythonImpl(
    f, grad_f=None, python_func_type=None, pass_workspace=False
):
    if python_func_type:
        func = python_func_type(f)
        f = func.forward
        grad_f = func.backward
    else:
        if isinstance(f, tuple):
            f = f[0](*f[1], **f[2])
        if isinstance(grad_f, tuple):
            grad_f = grad_f[0](*grad_f[1], **grad_f[2])

    token = C.register_python_op(f, pass_workspace, '')
    if grad_f:
        C.register_python_gradient_op(token, grad_f)
    return token


def CreatePythonOperator(
    f, inputs,
    outputs,
    grad_f=None,
    pass_workspace=False,
    python_func_type=None,
    *args,
    **kwargs
):
    """
    `f` should have a signature (inputs, outputs).

    If `pass_workspace` is True, the signature is changed to
    (inputs, outputs, workspace) where `workspace` is the workspace the op
    is going to run on. This is potentially dangerous (as the op can
    manipulate the workspace directly), so use at your own risk.
    """
    kwargs["token"] = _RegisterPythonImpl(
        f, grad_f, python_func_type, pass_workspace=pass_workspace
    )
    return CreateOperator("Python", inputs, outputs, *args, **kwargs)


def GetIndexFromGradientList(g_list, name):
    """A helper function to get the index from a gradient list; returns None
    if there is no matching entry."""
    for i, g in enumerate(g_list):
        if g == name:
            return i
        elif type(g) is GradientSlice:
            if (g.indices == name or g.values == name):
                return i
    return None


OpSSA = namedtuple('OpSSA', ['op', 'in_versions', 'out_versions'])
GradGenMeta = namedtuple('GradGenMeta', ['grad_op', 'idx', 'gradient'])
SparseGradGenMeta = namedtuple('SparseGradGenMeta', [
    'grad_op_indices', 'idx_indices',
    'grad_op_values', 'idx_values',
    'gradient',
])


class IR(object):
    """A simple IR class to keep track of all intermediate representations used
    in the gradient computation.
    """

    def __init__(self, operators):
        # input_usages maps each (blob, version) pair to the list of op
        # indices (into self.ssa) that read that version of the blob.
        self.input_usages = defaultdict(lambda: defaultdict(list))
        # ssa is the list of OpSSA records; frontier tracks the latest
        # version of each blob; gradient_frontier records the version each
        # gradient corresponds to; gradient_generators maps blob/version to
        # its gradient generators; the version histories feed debug messages.
        self.ssa = []
        self.frontier = defaultdict(int)
        self.gradient_frontier = {}
        self.gradient_generators = defaultdict(lambda: defaultdict(list))
        self.out_version_history = defaultdict(list)
        self.in_version_history = defaultdict(list)

        for op in operators:
            self.Play(op)

        self.SanityCheck(operators)

    def SanityCheck(self, operators):
        # Validate StopGradient usage by checking that StopGradient's output
        # is actually passed forward.
        for op in operators:
            if op.type == 'StopGradient':
                if op.output[0] not in self.input_usages:
                    raise ValueError("""StopGradient's output '{}' is orphan.
You typically want to specify same input and output for
StopGradient. Op:\n\n{}""".format(op.output[0], str(op)))

    def Play(self, op):
        """Adds an op to the current IR, and updates the internal states to
        reflect the blobs and versions after the execution of the op.
        """
        # Inputs are read at their current version in the frontier.
        in_versions = {}
        for s in op.input:
            in_versions[s] = self.frontier[s]
            self.input_usages[s][self.frontier[s]].append(len(self.ssa))
            self.in_version_history[s].append((op, self.frontier[s]))
        # Outputs get the current version plus one; a newly created blob
        # starts at version zero.
        out_versions = {}
        for s in op.output:
            if s in self.frontier:
                self.frontier[s] += 1
            out_versions[s] = self.frontier[s]
            self.out_version_history[s].append((op, self.frontier[s]))
        # Add to SSA for bookkeeping.
        self.ssa.append(OpSSA(op, in_versions, out_versions))

    def CheckGradientOperatorInput(
            self, grad_op_input, g_output, fwd_op_idx,
            locally_generated_blobs):
        """Checks if the gradient operators can be correctly carried out."""
        forward_op, in_versions, out_versions = self.ssa[fwd_op_idx]
        original_index = GetIndexFromGradientList(g_output, grad_op_input)

        # Helpers that generate debug messages for version mismatches.
        def versionMismatchInfoOut(name):
            s = "DEBUG HELP:\n"
            s += "Maybe you use same output blob twice for different ops?\n"
            s += "== Version history of blob [{}]\n".format(name)
            for (op, vers) in self.out_version_history[name]:
                s += "Version (out) {} <-- {}".format(vers, op)
                s += "\n"
            return s

        def versionMismatchInfoIn(name):
            s = "DEBUG HELP:\n"
            s += "Maybe the blob was overwritten by another op?\n"
            s += "== Version history of blob [{}]\n".format(name)
            for (op, vers) in self.in_version_history[name]:
                s += "version (in) {} <-- {}".format(vers, op)
                s += "\n"
            return s

        # If it is a dense or sparse gradient name, it should match the
        # version of the corresponding output.
        if original_index is not None:
            original_name = forward_op.output[original_index]
            if (out_versions[original_name] !=
                    self.gradient_frontier[original_name]):
                raise RuntimeError(
                    'Gradient name "%s" is expected to correspond '
                    'to version %d of "%s", but currently we have '
                    'version %d.\n\n' % (
                        grad_op_input, out_versions[original_name],
                        original_name,
                        self.gradient_frontier[original_name]) +
                    versionMismatchInfoOut(original_name))
        # If it is an output name, the current version should match the
        # version when the operator was run.
        elif grad_op_input in out_versions:
            if self.frontier[grad_op_input] != out_versions[grad_op_input]:
                raise RuntimeError(
                    'Gradient operator needs output "%s" at version'
                    ' %d, but currently we have version %d.\n\n' % (
                        grad_op_input, out_versions[grad_op_input],
                        self.frontier[grad_op_input]
                    ) + versionMismatchInfoOut(grad_op_input)
                )
        # If it is an input name, the current version should match the
        # version when the operator was run.
        elif grad_op_input in in_versions:
            if (self.frontier[grad_op_input] != in_versions[grad_op_input]):
                raise RuntimeError(
                    'Gradient operator needs input "%s" at version '
                    '%d, but currently we have version %d.\n\n' % (
                        grad_op_input, in_versions[grad_op_input],
                        self.frontier[grad_op_input]
                    ) + versionMismatchInfoIn(grad_op_input)
                )
        # Otherwise, the blob must have been generated locally by one of the
        # previous gradient operators.
        elif grad_op_input not in locally_generated_blobs:
            raise RuntimeError(
                'Blob name "%s" not in the scope of operator: '
                '%s\nand is not generated by any of the local '
                'gradient operators.' % (grad_op_input, str(forward_op))
            )

    def AppendSparseGenerators(self, sparse_generators):
        # Merge the indices and values generators for sparse gradients.
        for name, input_generators in viewitems(sparse_generators):
            for version, generators in viewitems(input_generators):
                if len(generators) == 1:
                    # Either indices or values are generated (but not both).
                    generator = generators[0]
                else:
                    # Both indices and values are generated.
                    assert(len(generators) == 2)
                    op1_i, idx1_i, op1_v, idx1_v, g1 = generators[0]
                    op2_i, idx2_i, op2_v, idx2_v, g2 = generators[1]
                    assert(g1 == g2)
                    assert(op1_i is None or op2_i is None)
                    assert(op1_v is None or op2_v is None)
                    assert(idx1_i == 0 or idx2_i == 0)
                    assert(idx1_v == 0 or idx2_v == 0)
                    generator = SparseGradGenMeta(
                        op1_i or op2_i, idx1_i + idx2_i,
                        op1_v or op2_v, idx1_v + idx2_v,
                        g1)
                self.gradient_generators[name][version].append(generator)

    def BuildGradientGenerators(  # NOQA
            self, fwd_op_idx, gradient_ops, g_output, g_input):
        """Updates gradient_generators and gradient_frontier."""
        forward_op, in_versions, out_versions = self.ssa[fwd_op_idx]
        locally_generated_blobs = []
        sparse_generators = defaultdict(lambda: defaultdict(list))

        for grad_op in gradient_ops:
            # (1) Check that the inputs are valid.
            for s in grad_op.input:
                self.CheckGradientOperatorInput(
                    s, g_output, fwd_op_idx, locally_generated_blobs)

            # (2) Add outputs to the locally generated blobs. If an output
            # corresponds to the gradient of an input, record it in
            # gradient_generators as well.
            locally_generated_blobs.extend([str(s) for s in grad_op.output])
            for i, output in enumerate(grad_op.output):
                input_index = GetIndexFromGradientList(g_input, output)
                if input_index is not None:
                    input_name = forward_op.input[input_index]
                    input_version = in_versions[input_name]
                    g = g_input[input_index]
                    if type(g) is GradientSlice:
                        # The output corresponds either to the indices or the
                        # values of the sparse gradient. In either case we
                        # create a (partial) SparseGradGenMeta; indices and
                        # values generators for the same gradient get merged
                        # in step (3).
                        if g.indices == output:
                            m = SparseGradGenMeta(grad_op, i, None, 0, g)
                        else:
                            assert(g.values == output)
                            m = SparseGradGenMeta(None, 0, grad_op, i, g)
                        sparse_generators[input_name][input_version].append(m)
                    else:
                        self.gradient_generators[input_name][input_version] \
                            .append(GradGenMeta(grad_op, i, g))

        # (3) Merge indices and values generators for sparse gradients, and
        # add them to gradient_generators.
        self.AppendSparseGenerators(sparse_generators)

        # (4) For ops whose gradient outputs are passed through directly from
        # inputs rather than computed by gradient ops, create a GradGenMeta
        # with None grad_op and idx so that gradient_generators knows such
        # gradients are already available.
        for input_index, g in enumerate(g_input):
            input_name = forward_op.input[input_index]
            input_version = in_versions[input_name]
            if not g:
                continue
            if type(g) is GradientSlice:
                if str(g.indices) not in locally_generated_blobs and \
                        str(g.values) not in locally_generated_blobs:
                    self.gradient_generators[input_name][input_version] \
                        .append(SparseGradGenMeta(None, 0, None, 0, g))
            else:
                if str(g) not in locally_generated_blobs:
                    self.gradient_generators[input_name][input_version] \
                        .append(GradGenMeta(None, 0, g))

        # Finally, for the gradients specified in g_input, update the gradient
        # frontier to reflect the input versions the gradients correspond to.
        for i, g in enumerate(g_input):
            if g is not None:
                input_name = forward_op.input[i]
                input_version = in_versions[input_name]
                self.gradient_frontier[input_name] = input_version

    def _GetSumOpOutputName(self, generator, input_name):
        def remove_suffix(s, suffix):
            if s.endswith(suffix):
                return s[:-len(suffix)]
            return s

        for g in generator:
            if type(g) is GradGenMeta:
                grad_op, idx, _ = g
                if grad_op:
                    return grad_op.output[idx]
            else:
                assert(type(g) is SparseGradGenMeta)
                op_i, idx_i, op_v, idx_v, _ = g
                if op_i:
                    return remove_suffix(op_i.output[idx_i], '_indices')
                if op_v:
                    return remove_suffix(op_v.output[idx_v], '_values')

        return input_name + '_grad'

    def _SetSumOpsDeviceOption(self, sum_ops, generators):
        # We already checked that the device options are consistent, so we
        # can just use the first one we find.
        for generator in generators:
            grad_op = generator.grad_op if type(generator) is GradGenMeta \
                else generator.grad_op_values or generator.grad_op_indices
            if grad_op:
                if grad_op.HasField('device_option'):
                    for op in sum_ops:
                        op.device_option.CopyFrom(grad_op.device_option)
                break

    def _DisambiguateGradOpOutput(self, grad_op, idx, cnt):
        grad_op.output[idx] = (
            '_' + grad_op.output[idx] + '_autosplit_{}'.format(cnt))
        return grad_op.output[idx], cnt + 1

    def _CheckSumOpsConflict(self, out_base_name, g):
        if str(out_base_name) == str(g):
            raise RuntimeError(
                'The gradient output of empty gradient op can not '
                'be the same as the normal name of the current '
                'input gradient.')

    def _MakeDenseSumOps(self, generators, out_base_name):
        sum_op_input = []
        cnt = 0

        assert len(generators) > 1

        first_grad_op = True
        for generator in generators:
            grad_op, idx, g = generator
            assert(type(g) is not GradientSlice)
            if grad_op:
                if first_grad_op:
                    first_grad_op = False
                    out = grad_op.output[idx]
                    self._CheckSumOpsConflict(out_base_name, g)
                    sum_op_input.append(out)
                else:
                    out, cnt = self._DisambiguateGradOpOutput(
                        grad_op, idx, cnt)
                    sum_op_input.append(out)
            else:
                self._CheckSumOpsConflict(out_base_name, g)
                sum_op_input.append(str(g))

        if out_base_name in sum_op_input:
            # Sum's in-place mode works only for the first input, so swap it
            # into the first position if needed.
            idx = sum_op_input.index(out_base_name)
            sum_op_input[0], sum_op_input[idx] = (
                sum_op_input[idx], sum_op_input[0]
            )
        sum_ops = [CreateOperator(
            "Sum",
            [BlobReference(x) for x in sum_op_input],
            BlobReference(out_base_name))]
        return sum_ops, out_base_name

    def _MakeSparseSumOps(self, generators, out_base_name):
        indices_concat_input = []
        values_concat_input = []
        cnt_i = 0
        cnt_v = 0

        for generator in generators:
            assert(type(generator) is SparseGradGenMeta)
            op_i, idx_i, op_v, idx_v, g = generator
            if op_i:
                out, cnt_i = self._DisambiguateGradOpOutput(op_i, idx_i, cnt_i)
                indices_concat_input.append(out)
            else:
                self._CheckSumOpsConflict(out_base_name, g.indices)
                indices_concat_input.append(g.indices)
            if op_v:
                out, cnt_v = self._DisambiguateGradOpOutput(op_v, idx_v, cnt_v)
                values_concat_input.append(out)
            else:
                self._CheckSumOpsConflict(out_base_name, g.values)
                values_concat_input.append(g.values)

        indices_concat_output = out_base_name + '_indices_concat'
        indices_concat_split = out_base_name + '_indices_concat_split'
        values_concat_output = out_base_name + '_values_concat'
        values_concat_split = out_base_name + '_values_concat_split'
        # Sum the given sparse representations by simply concatenating the
        # indices (resp. values) tensors together. No deduplication of
        # indices happens here; that is done as needed before the optimizer
        # is called.
        sum_ops = [
            CreateOperator(
                "Concat",
                [BlobReference(x) for x in indices_concat_input],
                [BlobReference(x) for x in
                    [indices_concat_output, indices_concat_split]],
                axis=0
            ),
            CreateOperator(
                "Concat",
                [BlobReference(x) for x in values_concat_input],
                [BlobReference(x) for x in
                    [values_concat_output, values_concat_split]],
                axis=0
            ),
        ]
        sum_op_output = GradientSlice(
            indices=indices_concat_output,
            values=values_concat_output,
        )
        return sum_ops, sum_op_output

    def _MakeSumOps(self, input_name, input_version):
        generators = self.gradient_generators[input_name][input_version]
        out_base_name = self._GetSumOpOutputName(generators, input_name)
        types = list(set(type(x) for x in generators))
        assert(len(types) == 1)
        if types[0] is GradGenMeta:
            sum_ops, g = self._MakeDenseSumOps(generators, out_base_name)
        else:
            assert(types[0] is SparseGradGenMeta)
            sum_ops, g = self._MakeSparseSumOps(generators, out_base_name)
        self._SetSumOpsDeviceOption(sum_ops, generators)
        return sum_ops, g

    def _VerifyGradientGenerators(self, generator):
        # (1) Check that all gradients are of the same type. Aggregating a
        # mix of sparse and dense gradients is not supported yet.
        if len({type(g) for g in generator}) > 1:
            raise RuntimeError(
                'Automatic aggregation of a mix of sparse and dense gradients '
                'is not supported yet')

        # If none or only one of the operators that used the input produced a
        # gradient, no additional sum needs to be carried out.
        if len(generator) < 2:
            return False

        all_gradient_names = []
        all_device_options = []
        for g in generator:
            if type(g) is GradGenMeta:
                if g.grad_op:
                    all_gradient_names.append(g.gradient)
                    all_device_options.append(g.grad_op.device_option)
            else:
                assert(type(g) is SparseGradGenMeta)
                if g.grad_op_indices:
                    all_device_options.append(g.grad_op_indices.device_option)
                if g.grad_op_values:
                    all_device_options.append(g.grad_op_values.device_option)
                all_gradient_names.append(g.gradient.values)

        # Check that all grad op device options are the same.
        if len(all_device_options) >= 2 and not all(
                device_option_equal(d, all_device_options[0])
                for d in all_device_options[1:]):
            raise RuntimeError('Unexpected behavior: not all grad ops '
                               'have the same device option.')
        return True

    def DoGradientAccumulation(self, fwd_op_idx):
        """For each input name in the forward op, check if we will need to
        add gradient accumulation. If so, do gradient accumulation and return
        the list of gradient operators.

        The criteria for doing gradient accumulation are:
        (1) the specific input version has been used by multiple operators.
        (2) the current fwd_op_idx is the first to use that input, i.e. in the
            backward pass, it is the last to optionally generate the gradient
            for the op.
        (3) For the operators that used the input, their gradient operators
            have generated more than 1 gradient.

        When accumulating operators, our current solution is to rename all the
        created gradients with an internal intermediate name, and then add a
        Sum() operator that adds up all the gradients. This may use more memory
        due to intermediate storage, but is usually the fastest approach as one
        can do one single sum for multiple intermediate gradients.
        """
        forward_op, in_versions, out_versions = self.ssa[fwd_op_idx]
        additional_sum_ops = []
        grad_map = {}
        for _i, input_name in enumerate(set(forward_op.input)):
            input_version = in_versions[input_name]
            input_usage = self.input_usages[input_name][input_version]
            if (len(input_usage) <= 1 or fwd_op_idx != input_usage[0]):
                # We do not need to do gradient accumulation yet.
                continue
            generator = self.gradient_generators[input_name][input_version]
            try:
                if not self._VerifyGradientGenerators(generator):
                    continue
            except RuntimeError as err:
                raise RuntimeError(
                    "Gradients for param ''{}'' failed to verify: {}".format(
                        input_name, err))

            # Finally, create the sum operator.
            sum_ops, g = self._MakeSumOps(input_name, input_version)
            additional_sum_ops.extend(sum_ops)
            grad_map[input_name] = g
        return additional_sum_ops, grad_map

    def _AppendAutoGradGenerator(self, y, grad, autograd_op):
        # The gradient here is not sparse, as it was generated by a
        # ConstantFill operator. Autogeneration for sparse gradients is not
        # supported.
        generator = GradGenMeta(
            autograd_op, 0 if autograd_op else None, str(grad))

        self.gradient_generators[str(y)][self.frontier[str(y)]].append(
            generator)

    def _GetInitGradients(self, ys):
        input_to_grad = {}
        gradient_ops = []

        for y, g in viewitems(ys):
            autograd_op = None
            if g is None:
                autograd_op = CreateOperator(
                    "ConstantFill", [y], [str(y) + "_autogen_grad"],
                    value=1.0)
                gradient_ops.append(autograd_op)
                g = autograd_op.output[0]
            # Since the C++ gradient registry has no notion of NameScopes,
            # convert all references to strings.
            input_to_grad[str(y)] = (
                GradientSlice(str(g[0]), str(g[1]))
                if isinstance(g, GradientSlice) else str(g))
            # Autogenerated gradients are assumed to be provided for the last
            # input version.
            if autograd_op is not None:
                self._AppendAutoGradGenerator(y, g, autograd_op)

        return input_to_grad, gradient_ops

    def _GenerateGradientsForForwardOp(
            self, forward_op_idx, input_to_grad):
        new_input_to_grad = {}
        gradient_ops = []
        forward_op, in_versions, out_versions = self.ssa[forward_op_idx]
        g_output = list(
            input_to_grad.get(name, None) for name in forward_op.output)

        if not all(g is None for g in g_output) or (
                forward_op.type == "ZeroGradient"):
            gradient_ops, g_input = GradientRegistry.GetGradientForOp(
                forward_op, g_output)
            # Check that the gradient operators are legal, and update
            # gradient_generators and gradient_frontier.
            self.BuildGradientGenerators(
                forward_op_idx, gradient_ops, g_output, g_input)
            # Record the gradient map to new_input_to_grad.
            for name, grad in zip(forward_op.input, g_input):
                # Do not overwrite an existing gradient with a None unless
                # the input is also an output of the op, since the blob
                # version is updated when a blob is the output of an operator.
                if grad is not None or \
                    name not in input_to_grad or \
                        name in list(forward_op.output):
                    new_input_to_grad[name] = grad

        return new_input_to_grad, gradient_ops

    def GetBackwardPass(self, ys):
        """Gets the backward pass that computes the derivatives of given blobs.

        Inputs:
          ys: a list or a dictionary specifying what blobs we want to compute
              derivatives of. If the input is a list, we will automatically
              generate their gradients with all-one values; if the input is a
              dictionary, for any dictionary entries that are not None, we will
              take the corresponding blobs as their gradients; for all those
              that are None, we will auto-fill them with 1.
        """
        if isinstance(ys, list):
            ys = dict((y, None) for y in ys)
        elif not isinstance(ys, dict):
            raise TypeError("ys should either be a list or a dict.")

        # (1) Set the gradient frontier with the initialized external
        # gradients.
        for y in viewkeys(ys):
            self.gradient_frontier[y] = self.frontier[y]
            self.input_usages[str(y)][self.frontier[str(y)]].append(
                len(self.ssa))

        all_input_to_grad, all_gradient_ops = self._GetInitGradients(ys)

        # (2) Play the ops backwards, creating the gradients along the path.
        for forward_op_idx in reversed(range(len(self.ssa))):
            input_to_grad, gradient_ops = \
                self._GenerateGradientsForForwardOp(
                    forward_op_idx, all_input_to_grad)
            all_input_to_grad.update(input_to_grad)
            all_gradient_ops += gradient_ops

            # If a blob is used by multiple ops, do gradient accumulation.
            additional_sum_ops, grad_map = self.DoGradientAccumulation(
                forward_op_idx)
            # Ops that have not produced gradients in an accumulation must
            # not overwrite the general map, hence updating with grad_map
            # only.
            all_input_to_grad.update(grad_map)
            all_gradient_ops += additional_sum_ops

        # (3) Post-process: convert everything in the output map to
        # BlobReferences for easier handling in Python.
        all_input_to_grad_out = {}
        for key, val in viewitems(all_input_to_grad):
            if val is not None:
                if (isinstance(val, string_types) or
                        isinstance(val, binary_type)):
                    grad_out = BlobReference(val)
                else:
                    grad_out = GradientSlice(BlobReference(val[0]),
                                             BlobReference(val[1]))
                all_input_to_grad_out[BlobReference(key)] = grad_out
        return all_gradient_ops, all_input_to_grad_out


class GradientRegistry(object):
    """GradientRegistry holds the mapping from operators to their gradients."""
    gradient_registry_ = {}

    @classmethod
    def RegisterGradient(cls, op_type):
        """A decorator for registering gradient mappings."""

        def Wrapper(func):
            cls.gradient_registry_[op_type] = func
            return func

        return Wrapper

    @classmethod
    def _GetGradientForOpCC(cls, op_def, g_output):
        def from_untyped(grad):
            if grad is None:
                w = C.GradientWrapper()
                assert w.is_empty()
                return w
            try:
                (indices, values) = grad
                w = C.GradientWrapper()
                w.indices = indices
                w.values = values
                assert w.is_sparse()
                return w
            except ValueError:
                pass
            w = C.GradientWrapper()
            w.dense = grad
            assert w.is_dense()
            return w

        g_output = [from_untyped(grad) for grad in g_output]
        grad_defs_str, g_input = C.get_gradient_defs(
            op_def.SerializeToString(), g_output)

        def to_untyped(grad_wrapper):
            if grad_wrapper.is_empty():
                return None
            if grad_wrapper.is_sparse():
                return GradientSlice(grad_wrapper.indices, grad_wrapper.values)
            assert grad_wrapper.is_dense()
            return grad_wrapper.dense

        g_input = [to_untyped(grad_wrapper) for grad_wrapper in g_input]
        grad_defs = []
        for grad_def_str in grad_defs_str:
            grad_def = caffe2_pb2.OperatorDef()
            grad_def.ParseFromString(grad_def_str)
            grad_defs.append(grad_def)
        return grad_defs, g_input

    @classmethod
    def GetGradientForOp(cls, op, g_output):
        try:
            gradient_ops, g_input = cls._GetGradientForOpCC(op, g_output)
        except Exception as e:
            # Not supported in C++; try the Python registry next.
            if op.type in cls.gradient_registry_:
                gradient_ops, g_input = cls.gradient_registry_[op.type](
                    op, g_output)
            else:
                raise Exception(
                    "Exception when creating gradient for [{}]:{}.\nOp: \n{}".
                    format(op.type, e, str(op)))

        if gradient_ops is None:
            return [], g_input
        if type(gradient_ops) is not list:
            gradient_ops = [gradient_ops]
        return gradient_ops, g_input

    @classmethod
    def GetBackwardPass(cls, operators, ys):
        """Gets the backward pass for the list of operators.

        Args:
            operators: a list of operators constituting the forward pass.
            ys: a list or a dictionary specifying what blobs we want to compute
                derivatives of. If the input is a list, we will automatically
                generate their gradients with all-one values; if the input is a
                dictionary, for any dictionary entries that are not None, we'll
                take the corresponding blobs as their gradients; for all those
                that are None, we will auto-fill them with 1.
        Returns:
            gradient_ops: a list of gradient operators to run.
            all_input_to_grads: a map from inputs to their corresponding
                gradients.
        """
        ir = IR(operators)
        return ir.GetBackwardPass(ys)


GradientRegistry.RegisterGradient('Do')(gen_do_gradient)
GradientRegistry.RegisterGradient('If')(gen_if_gradient)
GradientRegistry.RegisterGradient('While')(gen_while_gradient)
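
# Illustrative usage sketch (an editor's example, not original source):
# registering a Python gradient mapping for a hypothetical op type. The
# decorated function receives the forward OperatorDef and the output
# gradients, and returns (gradient_ops, g_input) where g_input is aligned
# with op.input.
#
#   @GradientRegistry.RegisterGradient('MyIdentity')
#   def AddMyIdentityGradient(op, g_output):
#       grad_op = CreateOperator(
#           'Copy', [g_output[0]], [op.input[0] + '_grad'])
#       return [grad_op], [op.input[0] + '_grad']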


def get_ssa(net, blob_versions=None):
    """
    Given a net, return a structure containing the version of each input and
    output blob used by each operator.

    Args:
        net:            either a Net or a NetDef
        blob_versions:  (optional) map with current version number for given
                        blob names. If not provided or blob not found, start
                        from version 0.
    Returns:
        Tuple (ssa, blob_versions)
        ssa:            list of tuples (versioned_inputs, versioned_outputs)
                        for each op in the net. A versioned input is a tuple
                        (blob_name, version).
        blob_versions:  updated map with the latest version of each blob
                        found in the net.
    """
    proto = net.Proto() if isinstance(net, Net) else net
    assert isinstance(proto, caffe2_pb2.NetDef)
    if blob_versions is None:
        blob_versions = {}
    if isinstance(net, list):
        return [get_ssa(n, blob_versions) for n in net], blob_versions
    for i in proto.external_input:
        if i not in blob_versions:
            blob_versions[str(i)] = 0
    ssa = []
    for op in proto.op:
        if not proto.external_input:
            for i in op.input:
                if i not in blob_versions:
                    blob_versions[i] = 0
        inputs = [(str(i), blob_versions.get(str(i), 0)) for i in op.input]
        for o in op.output:
            blob_versions[str(o)] = blob_versions.get(str(o), 0) + 1
        outputs = [(str(o), blob_versions[str(o)]) for o in op.output]
        ssa.append((inputs, outputs))
    return ssa, blob_versions
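
# Illustrative usage sketch (an editor's example, not original source): SSA
# versioning for a tiny two-op NetDef that writes blob 'y' twice.
#
#   netdef = caffe2_pb2.NetDef()
#   netdef.op.extend([
#       CreateOperator('Relu', ['x'], ['y']),
#       CreateOperator('Relu', ['y'], ['y']),
#   ])
#   ssa, versions = get_ssa(netdef)
#   # ssa == [([('x', 0)], [('y', 1)]), ([('y', 1)], [('y', 2)])]
#   # versions == {'x': 0, 'y': 2}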


def get_undefined_blobs(ssa):
    """
    Given a ssa in the format produced by get_ssa(), return a set of blobs
    that are used before they are defined, which corresponds to inputs at
    version 0.
    """
    undef_blobs = set()
    for inputs, _outputs in ssa:
        undef_blobs |= set(name for (name, ver) in inputs if ver == 0)
    return undef_blobs


def get_output_producers(ssa):
    """
    Given a ssa in the format produced by get_ssa(), returns a map from a
    versioned blob to the operator index that produces that version of the
    blob. A versioned blob is a tuple (blob_name, version).
    """
    producers = {}
    for i, (_inputs, outputs) in enumerate(ssa):
        for o in outputs:
            producers[o] = i
    return producers


def get_op_ids_in_path(ssa, blob_versions, inputs, outputs):
    """
    Given a ssa and blob_versions as produced by get_ssa(), returns the list
    of op indices that are necessary in order to generate the blobs in
    `outputs`, given blobs in `inputs`.
    Note that the `inputs` are assumed to be given at their latest version.
    """
    inputs_set = set((str(i), blob_versions[str(i)]) for i in inputs)
    producers = get_output_producers(ssa)
    queue = [(str(o), blob_versions[str(o)]) for o in outputs]
    used_op_ids = set()
    while len(queue) > 0:
        o = queue.pop()
        if (o not in inputs_set) and (o in producers):
            op_id = producers[o]
            if op_id not in used_op_ids:
                used_op_ids |= {op_id}
                inputs, _ = ssa[op_id]
                queue.extend(inputs)
    return sorted(used_op_ids)


def recurrent_network_op_remap(op, prefix, blob_remap):
    """
    Args:
        op:         Caffe2 operator (RecurrentNetworkOp or
                    RecurrentNetworkGradientOp).
        prefix:     not used in this function; kept for legacy support.
        blob_remap: dictionary mapping old blob names to new ones.

    Updates blob names in arguments of RecurrentNetworkOp and
    RecurrentNetworkGradientOp to conform to the cloned input and output of
    both operators, and also makes sure names of locally generated blobs in
    arguments have the same prefix as the input and output of the operators.
    """

    def get_remapped_str(blob_str):
        if isinstance(blob_str, binary_type):
            blob_str = blob_str.decode('utf-8')
        return blob_remap.get(blob_str, blob_str).encode('utf-8')

    for argument in op.arg:
        if len(argument.strings) > 0:
            for i in range(len(argument.strings)):
                argument.strings[i] = get_remapped_str(argument.strings[i])
        elif argument.name == 'timestep':
            argument.s = get_remapped_str(argument.s)
        elif argument.name.endswith('step_net'):
            # The argument is a net proto; remap it recursively.
            remap_proto(argument, blob_remap)


def control_op_remap(op, prefix, blob_remap):
    if op.type == "If":
        net_arg_names = ['then_net', 'else_net']
    else:
        net_arg_names = ['loop_net', 'cond_net']
    for argument in op.arg:
        if argument.name in net_arg_names:
            assert argument.n, \
                "Expected non-empty net in " + op.type + \
                "'s " + argument.name + " argument"
            subnet = Net(argument.n)
            remapped_subnet = subnet.Clone(
                name=(subnet._net.name if subnet._net.name else '') +
                '_remapped',
                blob_remap=blob_remap)
            argument.n.CopyFrom(remapped_subnet.Proto())


DEFAULT_REMAP_FUNCS = {
    'RecurrentNetwork': recurrent_network_op_remap,
    'RecurrentNetworkGradient': recurrent_network_op_remap,
    'If': control_op_remap,
    'While': control_op_remap,
}


def remap_proto(argument, blob_remap):
    subnet = Net(argument.n)

    cloned_sub_net = subnet.Clone(
        'cloned_sub_net',
        blob_remap,
    )

    argument.n.CopyFrom(cloned_sub_net.Proto())


def clone_and_bind_net(net, name, prefix, blob_remap=None, inputs=None,
                       keep_schema=True):
    """
    Clone the given Net, binding its input schema to the given `inputs`
    record. Blob names defined by the net are prepended with the given
    `prefix`.

    Args:
        net:        the net to clone
        name:       the name of the new net
        prefix:     the prefix to append to local blobs
        blob_remap: (optional) dict with additional blob name remapping.
        inputs:     (optional) input record that will provide actual input
                    values for the cloned net. Must be compatible with the
                    net's input schema or be a strict superset of it.
        keep_schema: by default (True), the original schema will be kept and
                     remapped accordingly. Otherwise, the schema will be set
                     from `inputs`, or left empty if `inputs` is not given.
    Returns:
        Tuple (cloned_net, blob_remap)
        cloned_net: the cloned Net
        blob_remap: a map from original blob names into remapped blob names
    """
    from caffe2.python import schema
    assert isinstance(net, Net)
    if blob_remap is None:
        blob_remap = {}
    if inputs is not None:
        assert isinstance(inputs, schema.Field)
        original = net.input_record()
        assert original is not None
        diff = set(original.field_names()) - set(inputs.field_names())
        assert len(diff) == 0, (
            "Schemas don't match, extra fields {diff} found in the net {name}. "
            "original: {original}; inputs: {inputs}"
            .format(
                diff=diff, name=net.Name(), original=original.field_names(),
                inputs=inputs.field_names()
            )
        )
        original_mapping = dict(zip(original.field_names(),
                                    original.field_blobs()))
        for fn, fb in zip(inputs.field_names(), inputs.field_blobs()):
            if fn in original_mapping:
                blob_remap[str(original_mapping[fn])] = str(fb)
    proto = net.Proto()
    ssa, blob_versions = get_ssa(proto)
    undef_blobs = get_undefined_blobs(ssa)

    for blob in viewkeys(blob_versions):
        if blob in blob_remap:
            continue
        elif blob in undef_blobs:
            blob_remap[blob] = blob
        else:
            blob_remap[blob] = prefix + blob
    cloned_net = net.Clone(name, blob_remap, keep_schema=keep_schema)
    if not keep_schema and inputs:
        cloned_net.set_input_record(inputs)
    return cloned_net, blob_remap


def _get_blob_ref(blob_name_or_ref):
    # Note: the original tested `isinstance(input, BlobReference)`, which
    # checked the `input` builtin rather than the argument; fixed here.
    return (
        blob_name_or_ref if isinstance(blob_name_or_ref, BlobReference)
        else BlobReference(blob_name_or_ref)
    )


def _recover_record_by_prefix(names, prefix=''):
    """
    Tries to recover a record by taking the subset of blob names with a given
    prefix and interpreting them as schema column names.
    """
    from caffe2.python import schema
    column_names = [name[len(prefix):] for name in names
                    if name.startswith(prefix)]
    if not column_names:
        return None
    return schema.from_column_list(
        column_names,
        col_blobs=[_get_blob_ref(prefix + name) for name in column_names])


class Net(object):
    _net_names_used = set()
    operator_registry_ = {}

    @staticmethod
    def current_prefix():
        from caffe2.python.net_builder import NetBuilder
        builder = NetBuilder.current(required=False)
        return builder.name if builder else ''

    @staticmethod
    def _get_next_net_name(basename):
        name = basename = '/'.join(
            x for x in [Net.current_prefix(), basename] if x
        )
        next_idx = 1
        while name in Net._net_names_used:
            name = basename + '_' + str(next_idx)
            next_idx += 1
        Net._net_names_used |= set([name])
        return name

    def __init__(self, name_or_proto):
        """
        Create a Net.
        Args:
            name_or_proto: If a NetDef is provided, clone it. Otherwise,
                           create an empty net with the given name.
        """
        self._input_record = None
        self._output_record = None
        # Register blobs so that different calls to NextBlob/NextScopedBlob
        # are guaranteed to return blobs with different names.
        self._registered_blob_names = set()
        self._recreate_lookup_tables = False
        self._op_outputs = set()
        self._external_input_map = set()
        self._attr_dict = defaultdict(list)
        if type(name_or_proto) is caffe2_pb2.NetDef:
            proto = name_or_proto
            # We are initializing the network from the given proto.
            self._net = caffe2_pb2.NetDef()
            self._net.CopyFrom(proto)

            existing_outputs = [list(op.output) for op in self._net.op]

            self._external_input_map.update(list(self._net.external_input))

            # Set the next name index properly.
            existing_names = set(
                sum([list(op.input) for op in self._net.op], []) +
                sum(existing_outputs, [])
            )
            for outs in existing_outputs:
                self._op_outputs.update(outs)

            prefix_len = len(self._net.name + '_blob_')
            autogen_indices = []
            for s in existing_names:
                if s.startswith(self._net.name + '_blob_'):
                    try:
                        autogen_indices.append(int(s[prefix_len:]))
                    except ValueError:
                        pass
            if len(autogen_indices):
                self._next_name_index = max(autogen_indices) + 1
            else:
                self._next_name_index = 0
            name = self._net.name
        else:
            name = name_or_proto
            self._net = caffe2_pb2.NetDef()
            self._next_name_index = 0

        # Make sure that this net name hasn't been used before.
        self._net.name = Net._get_next_net_name(name)
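
    # Illustrative usage sketch (an editor's example, not original source):
    # two nets asked for the same name get deduplicated names, since
    # Net._get_next_net_name tracks every name handed out (in a fresh
    # process).
    #
    #   n1 = Net('train')
    #   n2 = Net('train')
    #   assert n1.Name() == 'train' and n2.Name() == 'train_1'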

    def AppendNet(self, net):
        assert isinstance(net, Net)
        for i in net.Proto().external_input:
            if (
                i not in self.Proto().external_input and
                i not in self._op_outputs
            ):
                self.Proto().external_input.append(i)

        self.Proto().external_output.extend(
            [
                o for o in net.Proto().external_output
                if o not in self.Proto().external_output
            ]
        )
        self._ExtendOps(net.Proto().op)
        return self
1452 for msg_or_blob
in msg_or_blobs:
1453 if not isinstance(msg_or_blob, BlobReference):
1454 blob = self.GivenTensorStringFill(
1456 shape=[], values=[msg_or_blob])
1459 self.Print(blob, [])

    def add_attribute(self, name, obj):
        """
        Add `obj` to the list of attributes in this net under the given
        `name`. Attributes are user-defined objects and have no pre-defined
        semantics.
        """
        self._attr_dict[name].append(obj)

    def get_attributes(self, name):
        """
        Returns the list of attributes in this net for a given `name`.
        Attributes are user-defined objects added with `add_attribute`.
        """
        return self._attr_dict.get(name, [])

    def set_rand_seed(self, seed=100, sequence_seed=True,
                      seed_on_op_def=False):
        """
        Adds a random seed to each op in the net.
        If sequence_seed is set, the i-th op has rand_seed=`seed + i`.
        If seed_on_op_def is set, the op rand_seed=hash(str(op)).
        sequence_seed and seed_on_op_def cannot both be set to True.
        """
        assert not (sequence_seed and seed_on_op_def), (
            'sequence_seed and seed_on_op_def cannot be both set to True.')
        for i, op in enumerate(self.Proto().op):
            if sequence_seed:
                curr_seed = seed + i
            elif seed_on_op_def:
                curr_seed = hash(str(op) + str(seed)) % np.iinfo(np.uint32).max
            else:
                curr_seed = seed
            op.device_option.random_seed = curr_seed

    def Name(self):
        return self._net.name

    def Const(self, array, blob_out=None, dtype=None):
        if isinstance(array, bool):
            return self.ConstantFill(
                [],
                blob_out or 1,
                dtype=DataType.BOOL,
                value=array)

        if dtype is None:
            array = np.array(array)
        else:
            array = np.array(array, dtype=dtype)

        def do_set(operator):
            return operator(
                [],
                blob_out or 1,
                shape=array.shape,
                values=array.flatten().tolist())

        if array.dtype == np.int32:
            return do_set(self.GivenTensorIntFill)
        elif array.dtype == np.int64:
            return do_set(self.GivenTensorInt64Fill)
        elif array.dtype == np.str:
            return do_set(self.GivenTensorStringFill)
        elif array.dtype == np.bool:
            return do_set(self.GivenTensorBoolFill)
        else:
            return do_set(self.GivenTensorFill)

    def BlobIsDefined(self, blob):
        """
        Returns true if the given BlobReference is produced as output of
        an operator in this net, or if it is provided as an external input.
        """
        if self._recreate_lookup_tables:
            self._RecreateLookupTables()
        name = str(blob)
        return (name in self._op_outputs) or \
            (name in self._external_input_map)

    def UsesBlob(self, blob):
        """
        Returns true iff the given BlobReference is used by any operator in
        this net, or if it is one of the external inputs of the net.
        """
        blob_name = str(blob)
        for op in self._net.op:
            for input in op.input:
                if input == blob_name:
                    return True
        return blob_name in self._external_input_map

    def UsedBlobNames(self):
        """
        Returns a set of blob names used in the net.
        """
        blob_names = set()
        for op in self._net.op:
            blob_names |= set(op.input)
            blob_names |= set(op.output)
        if self._net.external_input:
            blob_names |= set(self._net.external_input)
        if self._net.external_output:
            blob_names |= set(self._net.external_output)
        return blob_names

    def GetBlobRef(self, blob_name):
        """
        Given the name of a blob produced by this net, return a BlobReference
        to it. If the blob is not produced by any op in this net, raise
        KeyError.
        """
        blob_name = str(blob_name)
        if not self.BlobIsDefined(blob_name):
            raise KeyError('Net does not define blob %s' % blob_name)
        return BlobReference(blob_name, self)

    def Clone(
        self,
        name,
        blob_remap=None,
        op_id_mask=None,
        remap_funcs=None,
        keep_schema=True
    ):
        """
        Clone this net.
        Args:
            name:        name of the cloned net
            blob_remap:  optional map with list of blob names to replace
            op_id_mask:  optional list of operator indices to include in
                         the cloned net. If not provided, all ops are
                         included.
        """
        orig_remap_funcs = {} if remap_funcs is None else remap_funcs
        # By default the recurrent-network and control-flow ops belong in
        # remap_funcs, since they carry blobs and net protos inside their
        # arguments.
        remap_funcs = DEFAULT_REMAP_FUNCS.copy()
        remap_funcs.update(orig_remap_funcs)
        proto = self._net
        new_proto = caffe2_pb2.NetDef()
        new_proto.CopyFrom(proto)
        new_proto.name = name

        if blob_remap is None:
            blob_remap = {}
        if op_id_mask is None:
            op_id_mask = list(range(0, len(proto.op)))

        def get_remapped_str(blob):
            blob_str = str(blob)
            return str(blob_remap.get(blob_str, blob_str))

        def remap_list(proto_list):
            new_list = [get_remapped_str(b) for b in proto_list]
            del proto_list[:]
            proto_list.extend(new_list)

        def remap_op(op):
            new_op = caffe2_pb2.OperatorDef()
            new_op.CopyFrom(op)
            remap_list(new_op.input)
            remap_list(new_op.output)
            if new_op.type in remap_funcs:
                remap_funcs[new_op.type](
                    new_op,
                    (name + '/') if name else '',
                    blob_remap,
                )
            return new_op

        del new_proto.op[:]
        new_proto.op.extend([remap_op(proto.op[op_id])
                             for op_id in op_id_mask])
        remap_list(new_proto.external_input)
        remap_list(new_proto.external_output)
        new_net = Net(new_proto)

        if keep_schema:
            from caffe2.python import schema
            if self._input_record:
                new_net._input_record = schema.from_blob_list(
                    self._input_record,
                    [
                        BlobReference(get_remapped_str(blob), net=new_net)
                        for blob in self._input_record.field_blobs()
                    ],
                )
            if self._output_record:
                new_net._output_record = schema.from_blob_list(
                    self._output_record,
                    [
                        BlobReference(get_remapped_str(blob), net=new_net)
                        for blob in self._output_record.field_blobs()
                    ],
                )

        new_net._attr_dict.update(self._attr_dict)
        return new_net
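
    # Illustrative usage sketch (an editor's example, not original source):
    # cloning a net while renaming a blob. All op inputs/outputs and the
    # external input/output lists go through the same remapping.
    #
    #   train = Net('train')
    #   x = train.AddExternalInput('x')
    #   train.Relu([x], 'h')
    #   eval_net = train.Clone('eval', blob_remap={'x': 'x_eval'})
    #   assert 'x_eval' in [str(b) for b in eval_net.external_inputs]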

    def ClonePartial(self, name, inputs, outputs, remap_funcs=None):
        """
        Clone this net, including only ops that are necessary in order to
        compute `outputs` given `inputs`. Return references to the cloned
        outputs. Internal blobs (blobs that are produced and consumed inside
        the net but not used as outputs) will be remapped to avoid name
        conflicts.

        Args:
            name:    the name of the cloned net
            inputs:  map where the keys correspond to BlobReferences in the
                     original net, and the values correspond to external
                     inputs in the partially cloned net. If `inputs` is a
                     list, don't remap input names.
            outputs: outputs to be produced by the cloned net.

        Returns:
            Tuple (new_net, new_outputs)
            new_net:     a new Net object.
            new_outputs: list of BlobReferences corresponding to the outputs
                         produced by new_net.
        """
        input_is_pair_list = isinstance(inputs, list) and all(
            isinstance(i, tuple) and len(i) == 2 for i in inputs)
        inputs = (
            inputs if isinstance(inputs, (dict, OrderedDict)) else
            OrderedDict(inputs) if input_is_pair_list else
            OrderedDict(zip(inputs, inputs)))
        for output in outputs:
            assert self.BlobIsDefined(output), \
                "{} is not defined".format(output)
        input_names = {str(k): str(v) for k, v in viewitems(inputs)}
        output_names = [str(o) for o in outputs]
        proto = self._net
        blob_versions = {str(i): 0 for i in inputs}
        ssa, blob_versions = get_ssa(proto, blob_versions)
        used_op_ids = get_op_ids_in_path(ssa, blob_versions, inputs, outputs)
        disallowed_op_ids = get_op_ids_in_path(ssa, blob_versions, [], inputs)
        assert len(set(used_op_ids) & set(disallowed_op_ids)) == 0, (
            'Cannot partially clone net: some of the ops required would ' +
            'generate the given input.')

        sub_ssa = [op for i, op in enumerate(ssa) if i in used_op_ids]
        undef_blobs = get_undefined_blobs(sub_ssa) - set(viewkeys(input_names))
        prefix = (name + '/') if name else ''

        def remap(blob_name):
            if blob_name in input_names:
                return input_names[blob_name]
            elif blob_name in undef_blobs:
                return blob_name
            else:
                return prefix + blob_name

        blob_mapping = {b: remap(b) for b in viewkeys(blob_versions)}
        new_net = self.Clone(name, blob_mapping, used_op_ids, remap_funcs)
        new_in = [
            blob_mapping[i] for i in viewkeys(input_names)] + \
            list(undef_blobs)
        new_out = [blob_mapping[o] for o in output_names]
        del new_net.Proto().external_input[:]
        new_net.Proto().external_input.extend(new_in)
        new_net._external_input_map = set(list(new_in))
        del new_net.Proto().external_output[:]
        new_net.Proto().external_output.extend(new_out)
        return new_net, [new_net.GetBlobRef(o) for o in new_out]

    def Proto(self):
        self._InvalidateLookupTables()
        return self._net

    def PopulateProtoWithFileName(self):
        net_tb = workspace.operator_tracebacks.get(self.Name(), None)
        if net_tb is not None:
            for idx, op in enumerate(self.Proto().op):
                if idx in net_tb:
                    op.name = ':'.join(map(str, net_tb[idx][0]))

    def NextScopedBlob(self, prefix='unnamed'):
        """Return a blob that has not been defined or registered in the
        current net. It returns `ScopedBlobReference(prefix)` if that is
        valid, otherwise `ScopedBlobReference(prefix) + '_auto_' + ?`.
        Different calls are guaranteed to return blobs with different names.
        """
        output_blob_base = ScopedName(prefix)
        return self.NextBlob(output_blob_base)

    def NextBlob(self, prefix='unnamed'):
        """Return a blob that has not been defined or registered in the
        current net. It returns `BlobReference(prefix)` if that is valid,
        otherwise `BlobReference(prefix) + '_auto_' + ?`. Different calls
        are guaranteed to return blobs with different names."""
        output_blob_base = BlobReference(prefix)
        output_blob = output_blob_base
        index = 0
        while str(output_blob) in self._registered_blob_names or (
                self.BlobIsDefined(output_blob)):
            output_blob = output_blob_base + '_auto_' + str(index)
            index += 1

        self._registered_blob_names.add(str(output_blob))
        return output_blob

    def NextName(self, prefix=None, output_id=None):
        """Returns the next name to be used, if you do not want to explicitly
        name your blob. [Deprecated, use NextBlob, NextScopedBlob instead]"""
        if prefix:
            output_name_base = self._net.name + '/' + prefix
            output_name = output_name_base
            if output_id is not None:
                output_name += ':' + str(output_id)
            index = 2
            while self.BlobIsDefined(str(ScopedBlobReference(output_name))):
                output_name = output_name_base + '_' + str(index)
                if output_id is not None:
                    output_name += ':' + str(output_id)
                index += 1
        else:
            output_name = self._net.name + '_blob_' + \
                str(self._next_name_index)
            self._next_name_index += 1
        return str(output_name)

    def _ExtendOps(self, new_ops):
        self._net.op.extend(new_ops)
        for op in new_ops:
            self._op_outputs.update([text_type(o) for o in op.output])

    def _CheckLookupTables(self):
        """
        Called from unit tests to validate that the internal lookup tables
        match the protobuf contents.
        """
        test_op_outputs = set()
        for op in self._net.op:
            for o in op.output:
                test_op_outputs.add(o)

        test_external_inp = set()
        for inp in self._net.external_input:
            test_external_inp.add(inp)

        assert test_op_outputs.difference(self._op_outputs) == set()
        assert test_external_inp.difference(self._external_input_map) == set()
1804 def _RecreateLookupTables(self):
1806 for op
in self._net.op:
1808 self._op_outputs.add(o)
1811 for inp
in self._net.external_input:
1812 self._external_input_map.add(inp)

    def AddGradientOperators(self, ys, skip=0):
        """Add the gradient for operators in the net.

        Inputs:
          ys: a list or a dictionary specifying what blobs we want to compute
              derivatives of. If the input is a list, we will automatically
              generate their gradients with all-one values; if the input is a
              dictionary, for any dictionary entries that are not None, we
              will take the corresponding blobs as their gradients; for all
              those that are None, we will auto-fill them with 1.
          skip: skips the first n operators. This is provided mainly because
              many nets use the first few operators for data generation and
              the like, which really do not need gradients.

        Outputs:
          returns a map from blob names in the input network to a blob
          containing the gradient, or a GradientSlice in the case of a sparse
          gradient.

        Currently, this is hard-coded for float operators if there are
        branches (i.e. a blob is used as input to multiple operators). This
        is because the gradient accumulation (Sum) is float only right now.
        """
        grad_ops, input_to_grad = GradientRegistry.GetBackwardPass(
            self._net.op[skip:], ys)
        # In immediate mode the grad_ops are produced by C++ and bypass the
        # CreateOperator() call, so they have to be run explicitly here.
        if workspace.IsImmediate():
            for op in grad_ops:
                workspace.RunOperatorImmediate(op)
        self._ExtendOps(grad_ops)
        return input_to_grad
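
    # Illustrative usage sketch (an editor's example, not original source):
    # building a tiny net and asking for gradients of its loss. Blob names
    # are hypothetical.
    #
    #   net = Net('example_bw')
    #   x = net.AddExternalInput('x')
    #   loss = x.AveragedLoss([], 'loss')
    #   grad_map = net.AddGradientOperators([loss])
    #   # grad_map maps each input blob to the blob holding its gradient.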

    def AddExternalInput(self, *inputs):
        assert len(inputs) > 0
        refs = []
        for input in inputs:
            input_name = str(input)
            assert input_name not in self._external_input_map, (
                'Net already contains an input named %s' % input_name)
        for input in inputs:
            input_name = str(input)
            self._net.external_input.extend([input_name])
            self._external_input_map.update([input_name])
            refs.append(_get_blob_ref(input_name))

        return refs[0] if len(refs) == 1 else refs

    def AddExternalOutput(self, *outputs):
        for output in outputs:
            assert isinstance(output, BlobReference)
            assert self.BlobIsDefined(output), \
                "{} is not defined".format(output)
        for output in outputs:
            self.Proto().external_output.extend([str(output)])

    def AddScopedExternalInputs(self, *inputs):
        res = self.AddExternalInput(
            *[ScopedBlobReference(b) for b in inputs]
        )
        if not isinstance(res, list):
            res = [res]
        return res

    def AddScopedExternalOutputs(self, *outputs):
        return self.AddExternalOutput(
            *[ScopedBlobReference(b) for b in outputs]
        )

    def AddObserver(self, observer_type):
        return C.add_observer_to_net(self._net.name, observer_type)

    def RemoveObserver(self, observer):
        C.remove_observer_from_net(self._net.name, observer)

    def NumObservers(self):
        return C.num_observers_on_net(self._net.name)

    @property
    def external_inputs(self):
        return [_get_blob_ref(x) for x in self._net.external_input]

    @property
    def external_outputs(self):
        return [_get_blob_ref(x) for x in self._net.external_output]

    def set_input_record(self, input_record):
        from caffe2.python import schema
        assert self._input_record is None or (input_record.has_blobs() and
            set(input_record.field_blobs()) ==
            set(self._input_record.field_blobs())), (
            'Input schema cannot be reset')
        if not input_record.has_blobs():
            with NameScope(self.Name()):
                self._input_record = schema.NewRecord(self, input_record)
        else:
            self._input_record = input_record
            for blob in input_record.field_blobs():
                if blob not in self.external_inputs:
                    self.AddExternalInput(blob)
        return self._input_record

    def recover_input_record_by_prefix(self, prefix):
        """
        Tries to recover the input record by taking the subset of
        external_inputs with a given prefix name and interpreting them as
        schema column names.
        """
        record = _recover_record_by_prefix(self._net.external_input, prefix)
        if record:
            self.set_input_record(record)

    def set_output_record(self, record):
        assert self._output_record is None or (record.has_blobs() and
            set(record.field_blobs()) ==
            set(self._output_record.field_blobs())), (
            'Output schema cannot be reset')
        for blob in record.field_blobs():
            assert self.BlobIsDefined(blob), "{} is not defined".format(blob)
        for blob in record.field_blobs():
            self.AddExternalOutput(blob)
        self._output_record = record

    def recover_output_record_by_prefix(self, prefix):
        """
        Tries to recover the output record by taking the subset of
        external_outputs with a given prefix name and interpreting them as
        schema column names.
        """
        record = _recover_record_by_prefix(self._net.external_output, prefix)
        if record:
            self.set_output_record(record)

    def AppendOutputRecordField(self, field_name, record):
        from caffe2.python import schema
        assert self._output_record is not None, (
            'Tried to append to missing output record'
        )
        for blob in record.field_blobs():
            assert self.BlobIsDefined(blob), "{} is not defined".format(blob)
        for blob in record.field_blobs():
            self.AddExternalOutput(blob)
        self._output_record = self._output_record + schema.Struct(
            (field_name, record)
        )

    def input_record(self):
        return self._input_record

    def output_record(self):
        return self._output_record

    def AddExternalInputs(self, *inputs):
        return self.AddExternalInput(*inputs)

    def AddExternalOutputs(self, *outputs):
        return self.AddExternalOutput(*outputs)

    def DeduplicateGradientSlices(self, g, aggregator='sum'):
        assert isinstance(g, GradientSlice)
        unique, remapping = self.Unique([g.indices], 2, engine='SparseHash')
        if aggregator.lower() == 'sum':
            new_g = self.UnsortedSegmentSum([g.values, remapping], 1)
        elif aggregator.lower() == 'mean':
            new_g = self.UnsortedSegmentMean([g.values, remapping], 1)
        else:
            raise ValueError('{} is not supported'.format(aggregator))
        return GradientSlice(indices=unique, values=new_g)
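
    # Illustrative usage sketch (an editor's example, not original source):
    # collapsing duplicate rows in a sparse gradient before applying it.
    # Blob names are hypothetical.
    #
    #   net = Net('dedup')
    #   idx = net.AddExternalInput('idx')
    #   vals = net.AddExternalInput('vals')
    #   g = GradientSlice(indices=idx, values=vals)
    #   g_dedup = net.DeduplicateGradientSlices(g, aggregator='sum')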

    def RunAllOnGPU(self, gpu_id=0, use_cudnn=False):
        """A convenience function to run everything on the GPU."""
        device_option = caffe2_pb2.DeviceOption()
        device_option.device_type = caffe2_pb2.CUDA
        device_option.cuda_gpu_id = gpu_id
        self._net.device_option.CopyFrom(device_option)
        if use_cudnn:
            for op in self._net.op:
                op.engine = "CUDNN"

    def RunAllOnMKL(self):
        """A convenience function to run everything using MKLDNN."""
        device_option = caffe2_pb2.DeviceOption()
        device_option.device_type = caffe2_pb2.MKLDNN
        self._net.device_option.CopyFrom(device_option)

    def _CreateAndAddToSelf(self, op_type, inputs, outputs=None, **kwargs):
        """A helper function to create an operator and add it to self."""
        inputs = _RectifyInputOutput(inputs)
        for input in inputs:
            if not self.BlobIsDefined(input):
                assert input.Net() != self
                self.AddExternalInput(input)
        if outputs is None:
            # If no output is specified, assume that this op produces exactly
            # one output.
            outputs = self.NextName(prefix=op_type)
        elif type(outputs) is int:
            # In this case, auto-fill the given number of outputs with
            # auto-generated names.
            outputs = [
                self.NextName(prefix=op_type, output_id=i)
                for i in range(outputs)]
        outputs = _RectifyInputOutput(outputs, net=self)
        op = CreateOperator(op_type, inputs, outputs, **kwargs)
        self._ExtendOps([op])

        workspace.operator_tracebacks[self.Name()][
            len(self._net.op) - 1] = _extract_stacktrace()

        if len(op.output) == 0:
            return
        elif len(op.output) == 1:
            return BlobReference(op.output[0], self)
        else:
            return tuple(BlobReference(o, self) for o in op.output)

    def __getattr__(self, op_type):
        if op_type.startswith('__'):
            raise AttributeError('Attribute {} not found.'.format(op_type))
        if not IsOperator(op_type) and \
                not IsOperatorWithEngine(op_type, "CUDNN"):
            raise AttributeError(
                'Method ' + op_type + ' is not a registered operator.' +
                ' Did you mean: [' +
                ",".join(workspace.C.nearby_opnames(op_type)) + ']'
            )
        return lambda *args, **kwargs: self._CreateAndAddToSelf(
            op_type, *args, **kwargs)

    def __dir__(self):
        additional_methods = [
            op
            for op in _REGISTERED_OPERATORS
            if '_ENGINE_' not in op]
        return sorted(set(chain(
            dir(type(self)),
            viewkeys(self.__dict__),
            additional_methods
        )))

    def Python(
        self,
        f,
        grad_f=None,
        python_func_type=None,
        pass_workspace=False,
        grad_output_indices=None,
        grad_input_indices=None
    ):
        """
        Registers and returns a python operator.

        `f` and `grad_f` can be one of the following:
            - a function with signature (inputs, outputs), where inputs and
              outputs are a list of CPUTensor objects. This function will be
              called from C++ every time the operator is executed.
            - a tuple (func, args, kwargs); here `func` is a callable, args
              is an argument list, and kwargs is a dict. The call:
                  f = func(*args, **kwargs)
              will be performed locally at node initialization time, on all
              of the nodes of the job, returning `f`, a callable that will be
              used as the python operator function during Net execution. This
              is to be used when using python operators in a distributed
              context, and allows one to create and keep local python state
              across calls to the operator.

        `python_func_type` is a type of an object that is constructed as
        python_func_type(f) and provides an implementation of the forward and
        backward functions. It is useful when users need a stateful PythonOp
        (e.g. using autograd for computing grad_f).

        If `pass_workspace` is True, the signature is changed to
        (inputs, outputs, workspace) where `workspace` is the workspace the
        op is going to run on. This is potentially dangerous (as the op can
        manipulate the workspace directly), so use at your own risk.

        If a gradient function is specified (`grad_f`), by default its inputs
        will be: (1) all inputs to `f`, (2) followed by all outputs of `f`,
        (3) and then all gradient outputs of `f`. The outputs of `grad_f`
        will be (by default) all gradient inputs to `f`. If a subset of the
        gradient outputs or gradient inputs is desired instead, then the
        subsets can be specified by providing `grad_output_indices` and/or
        `grad_input_indices`, which identify the indices of `f`'s inputs and
        outputs that have gradients.
        """
        assert(IsOperator('Python'))

        def make_builder(t):
            if not isinstance(t, tuple):
                return ''
            assert len(t) == 3, 'Expected builder tuple (func, args, kwargs)'
            func, args, kwargs = t
            normalized = (func, tuple(args), dict(kwargs))
            return pickle.dumps(normalized)

        f_builder = make_builder(f)
        grad_f_builder = make_builder(grad_f)

        assert (not grad_f) or ((not f_builder) == (not grad_f_builder)), (
            'A tuple has to be passed to both f and grad_f or neither.')

        core_kwargs = {}
        if f_builder:
            core_kwargs['pickled_builder'] = f_builder
            core_kwargs['pickled_grad_builder'] = grad_f_builder
            core_kwargs['pass_workspace'] = pass_workspace
        else:
            core_kwargs['token'] = _RegisterPythonImpl(
                f, grad_f, python_func_type, pass_workspace=pass_workspace)

        grad_output_indices = grad_output_indices or []
        grad_input_indices = grad_input_indices or []
        return lambda *args, **kwargs: self._CreateAndAddToSelf(
            'Python',
            grad_output_indices=grad_output_indices,
            grad_input_indices=grad_input_indices,
            *args,
            **dict(chain(viewitems(kwargs), viewitems(core_kwargs)))
        )
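
    # Illustrative usage sketch (an editor's example, not original source):
    # a Python op that doubles its input. `inputs` and `outputs` are lists of
    # tensors; outputs are written via feed(). Names are hypothetical.
    #
    #   def double(inputs, outputs):
    #       outputs[0].feed(inputs[0].data * 2)
    #
    #   net = Net('py_example')
    #   net.Python(double)(['x'], ['x_doubled'])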

    def is_external_input(self, blob):
        name = str(blob)
        return name in self._external_input_map

    def extend_ops(self, new_ops):
        return self._ExtendOps(new_ops)


def copy_func_between_devices(src, dst):
    CPU = caffe2_pb2.CPU
    CUDA = caffe2_pb2.CUDA

    if src.device_type == CPU and dst.device_type == CPU:
        return None

    if src.device_type == CUDA and dst.device_type == CUDA:
        if src.cuda_gpu_id == dst.cuda_gpu_id:
            return None
        else:
            def fun(net, *args, **kw):
                with DeviceScope(dst):
                    return net.Copy(*args, **kw)
            return fun

    if src.device_type == CUDA and dst.device_type == CPU:
        def fun(net, *args, **kw):
            with DeviceScope(src):
                return net.CopyGPUToCPU(*args, **kw)
        return fun

    if src.device_type == CPU and dst.device_type == CUDA:
        def fun(net, *args, **kw):
            with DeviceScope(dst):
                return net.CopyCPUToGPU(*args, **kw)
        return fun

    raise ValueError('Non-supported devices: %s and %s' % (src, dst))


def device_equal(src, dst):
    """
    We use this function instead of the == operator because optional-value
    comparison between an empty device_option and {device_type:0,
    cuda_gpu_id:0} returns not-equal in some cases.
    """
    return src.device_type == dst.device_type and \
        src.cuda_gpu_id == dst.cuda_gpu_id


def update_placeholder_op_output(op, blob_to_device):
    """
    Placeholder ops (e.g. Recv) always run on CPU, so ensure that their
    output blobs reside on CPU.
    """
    outputs = []
    for output in op.output:
        blob_dev = blob_to_device[output]
        if blob_dev.device_type != caffe2_pb2.CPU:
            output += '_cpu'
        outputs.append(output)
    del op.output[:]
    op.output.extend(outputs)


class RemapEntry:
    def __init__(self, blob, device):
        self.blob = blob
        self.device = device

    def __eq__(self, other):
        return self.blob == other.blob and self.device == other.device

    def __hash__(self):
        return hash(self.blob + str(self.device))
2209 def InjectCrossDeviceCopies(net, blob_to_device=None, blob_remap=None,
2210 placeHolderOps=
None):
    Injects Copy functions between devices within a net. Users can provide
    a net in which some operators use different device_options. This method
    will automatically create a new net with Copy ops inserted where needed.

    Inputs:
      blob_to_device: If not None, it is a map of blobs and their device locations.
      blob_remap: If not None, it is a map from a pair (blob, device) to
                  the name of the blob in the given device. Blobs found in this
                  map are assumed to be cached and don't need to be copied.
    Outputs:
      new_net: A new net with CopyCPUToGPU inserted with correct device option

      required_external_to_device:
               A mapping between unresolved external inputs and their
               required device options.
    Assumptions:
      1. every external input of this net is already in blob_to_device!
      2. if not, this function will use net device option
    '''
    new_net = net.Clone(net._net.name + '_cross_device', keep_schema=True)
    del new_net._net.op[:]
    if blob_to_device is None:
        blob_to_device = {}
    # remapping of input blobs for each op.
    if blob_remap is None:
        blob_remap = {}
    net_option = net._net.device_option or caffe2_pb2.DeviceOption()

    # If external inputs have device remappings generated by previous nets,
    # then add those remapped names as external inputs as well.
    all_remaps = defaultdict(list)
    for entry, mapped_blob in blob_remap.items():
        all_remaps[entry.blob].append(mapped_blob)
    mapped_external_inputs = []
    for input in new_net._net.external_input:
        mapped_external_inputs.extend(all_remaps.get(input) or [])
    new_net._net.external_input.extend(mapped_external_inputs)

    for op in net._net.op:
        temp_remap = {}
        # Get where inputs and outputs should be. If the op is a placeholder
        # (fake) op with no registered operator, infer devices from the op's
        # own device option instead.
        if placeHolderOps is not None and op.type in placeHolderOps:
            input_dev, output_dev = InferOpDeviceAsBlobDevices(op)
        else:
            input_dev, output_dev = InferOpBlobDevices(op)

        for dev, input in zip(input_dev, op.input):
            assert net.BlobIsDefined(input), \
                "input {} should be defined in the net.".format(input)
            if input not in blob_to_device:
                if net.is_external_input(input):
                    blob_to_device[input] = net_option
                else:
                    raise AttributeError(
                        "No device information found for blob {}.".
                        format(input)
                    )

            if not device_equal(blob_to_device[input], dev):
                # Reuse an already-moved copy of this input, if one exists.
                if (RemapEntry(input, dev) in blob_remap and
                        blob_to_device[blob_remap[RemapEntry(input, dev)]] == dev):
                    temp_remap[input] = blob_remap[RemapEntry(input, dev)]
                else:
                    # Need to make the input available on the correct device.
                    copy_func = copy_func_between_devices(
                        blob_to_device[input], dev
                    )

                    def _gen_new_name(blob, device_option):
                        CPU = caffe2_pb2.CPU
                        CUDA = caffe2_pb2.CUDA
                        if device_option.device_type == CPU:
                            suffix = '_cpu'
                        elif device_option.device_type == CUDA:
                            suffix = '_cuda_' + str(device_option.cuda_gpu_id)
                        else:
                            raise RuntimeError(
                                "Unknown device type: {}".
                                format(device_option.device_type)
                            )
                        return blob + suffix

                    new_name = _gen_new_name(input, dev)
                    copy_func(new_net, input, new_name)
                    blob_remap[RemapEntry(input, dev)] = new_name
                    temp_remap[input] = new_name
                    blob_to_device[new_name] = dev

        if placeHolderOps is not None and op.type in placeHolderOps:
            update_placeholder_op_output(op, blob_to_device)

        # Enforce that a blob is not reused in-place across operators that
        # disagree on its device; in-place use within a single op is allowed.
        for dev, output in zip(output_dev, op.output):
            if output in blob_to_device and (
                output not in op.input and
                not device_equal(blob_to_device[output], dev)
            ):
                raise RuntimeError(
                    "In-place blob: {} is not supported between operators "
                    "with different device option previous:{} now: {}. "
                    "Failed op:\n {}".format(
                        output, blob_to_device[output], dev, op
                    )
                )
        new_op = caffe2_pb2.OperatorDef()
        new_op.CopyFrom(op)

        new_list = [temp_remap.get(b, b) for b in new_op.input]
        del new_op.input[:]
        new_op.input.extend(new_list)

        # keep inplace blobs inplace
        original_inputs = list(op.input)
        for i, out in enumerate(new_op.output):
            try:
                input_idx = original_inputs.index(out)
                new_op.output[i] = new_op.input[input_idx]
            except ValueError:
                pass

        blob_to_device.update(
            {o: d for d, o in zip(output_dev, new_op.output)})
        new_net.extend_ops([new_op])

    return new_net, blob_to_device

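
# A minimal usage sketch for InjectCrossDeviceCopies (illustrative only: the
# net and blob names below are made up for this example). A CPU-produced blob
# consumed by a CUDA op gets an automatic CopyCPUToGPU into a '_cuda_0' blob.
def _example_inject_cross_device_copies():
    example_net = Net("cross_device_example")
    data = example_net.AddExternalInput("data")
    example_net.Relu(data, "relu_out")  # created with the default (CPU) device
    with DeviceScope(DeviceOption(caffe2_pb2.CUDA, 0)):
        example_net.Relu("relu_out", "gpu_out")  # needs "relu_out" on GPU 0
    new_net, blob_to_device = InjectCrossDeviceCopies(example_net)
    # new_net now contains a CopyCPUToGPU op producing "relu_out_cuda_0".
    return new_net, blob_to_device
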
def InjectDeviceCopiesAmongNets(nets, blob_to_device_init=None):
    """
    Takes in a list of nets. They usually represent your whole execution graph.
    This function will insert cross-device copy functions into all nets, and
    resolve inter-net external input dependencies. It will insert Copy
    functions if an external input of a net is produced on a different device
    than the one on which it is required.

    Inputs:
      nets: a list of nets

    Outputs:
      new_nets: a list of new nets with device differences resolved.

    Some notes from wyiming:
      1. You MUST pass nets in execution order. e.g. [train_init, train]
    """
    assert isinstance(nets, list), \
        "nets {} should be a list of nets.".format(str(nets))
    assert all(isinstance(net, Net) for net in nets), \
        "nets {} should be a list of nets.".format(str(nets))
    # A holistic blob-to-device mapping shared across all nets.
    blob_to_device = blob_to_device_init or {}
    blob_remap = {}
    new_nets = []

    for net in nets:
        new_net, blob_to_device = InjectCrossDeviceCopies(
            net,
            blob_to_device=blob_to_device,
            blob_remap=blob_remap,
        )
        new_nets.append(new_net)

    return new_nets, blob_to_device

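
# A minimal usage sketch for InjectDeviceCopiesAmongNets (illustrative only).
# Passing [init_net, train_net] in execution order lets blobs produced by the
# init net be copied to whichever devices the train net requires them on.
def _example_inject_device_copies_among_nets(init_net, train_net):
    new_nets, blob_to_device = InjectDeviceCopiesAmongNets(
        [init_net, train_net])
    new_init_net, new_train_net = new_nets
    return new_init_net, new_train_net, blob_to_device
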
def InjectDeviceCopiesAmongNetsWithoutB2D(nets, blob_to_device_init=None):
    new_nets, _ = InjectDeviceCopiesAmongNets(nets, blob_to_device_init)
    return new_nets

def get_net_name(netlike):
    if isinstance(netlike, Net):
        return netlike.Proto().name
    elif isinstance(netlike, caffe2_pb2.NetDef):
        return netlike.name
    else:
        return netlike

def output_to_list(op_output):
    """
    Ensures that the output of an operator is a list.
    Use when an operator has a variable number of outputs, but a list of
    outputs is desired even when number of outputs is 1.

    Args:
        op_output: Either a BlobReference or an iterable of BlobReferences.

    Returns:
        A list of BlobReferences.
    """
    assert type(op_output) in (list, tuple, BlobReference)
    return (
        [op_output]
        if isinstance(op_output, BlobReference) else list(op_output)
    )

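
# A minimal sketch of why output_to_list is handy (illustrative only):
# single-output ops return a BlobReference while multi-output ops return a
# tuple, and output_to_list normalizes both cases to a list.
def _example_output_to_list():
    example_net = Net("output_to_list_example")
    single = example_net.Relu("x", "y")           # one BlobReference
    multi = example_net.Split("x", ["y0", "y1"])  # tuple of BlobReferences
    assert output_to_list(single) == [single]
    assert output_to_list(multi) == list(multi)
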
def _add_net_to_dict(net_dict, net):
    name = get_net_name(net)
    if name in net_dict:
        assert net_dict[name] is None or net == net_dict[name], (
            'Different nets with same name: ' + name)
        return False
    else:
        net_dict[name] = net if isinstance(net, Net) else None
        return True


class ExecutionStep(object):
    _step_names_used = set()

    @staticmethod
    def _get_next_step_name(basename):
        name = basename
        next_idx = 1
        while name in ExecutionStep._step_names_used:
            name = basename + '_' + str(next_idx)
            next_idx += 1
        ExecutionStep._step_names_used |= set([name])
        return name

    def __init__(self, name, nets=None, num_iter=None):
        self._step = caffe2_pb2.ExecutionStep()
        self._step.name = name or ExecutionStep._get_next_step_name('step')
        self._net_dict = OrderedDict()
        self._is_used = False
        self._substeps = []
        if nets is not None:
            if type(nets) is Net:
                nets = [nets]
            for net in nets:
                if _add_net_to_dict(self._net_dict, net):
                    self._step.network.extend([get_net_name(net)])
        if num_iter is not None:
            self._step.num_iter = num_iter

    def get_net(self, name):
        return self._net_dict[name]

    def Name(self):
        return self._step.name

    def __str__(self):
        return self._step.name

    def _assert_can_mutate(self):
        assert not self._is_used, (
            'Cannot mutate a step that has already been added to a plan/step.')

    def _notify_is_used(self):
        self._is_used = True

    def Proto(self):
        return self._step

    def HasNets(self):
        return self._step.network is not None and (
            len(self._step.network) > 0)

    def HasSubsteps(self):
        return self._step.substep is not None and (
            len(self._step.substep) > 0)

    def Nets(self):
        return list(viewvalues(self._net_dict))

    def Substeps(self):
        return self._substeps

    def SetIter(self, num_iter):
        self._assert_can_mutate()
        self._step.num_iter = num_iter

    def SetCreateWorkspace(self, create_workspace):
        self._assert_can_mutate()
        self._step.create_workspace = create_workspace

    def SetNumConcurrentInstances(self, num_concurrent_instances):
        self._assert_can_mutate()
        self._step.num_concurrent_instances = num_concurrent_instances

    def SetOnlyOnce(self, only_once):
        self._assert_can_mutate()
        self._step.only_once = only_once

    def SetShouldStopBlob(self, should_stop_blob):
        assert isinstance(should_stop_blob, BlobReference), (
            "expects BlobReference here, got {}".format(type(should_stop_blob)))
        self._assert_can_mutate()
        self._step.should_stop_blob = str(should_stop_blob)

    def RunEveryMillis(self, interval):
        """
        Run this step every interval milliseconds, as long as its
        siblings are still running. It is guaranteed that, after all
        siblings finish, this step will run at least once.

        This property is ignored for top-level ExecutionSteps.
        """
        self._step.run_every_ms = interval

2517 """ DEPRECATED. Use RunEveryMillis instead. """ 2519 _add_net_to_dict(self.
_net_dict, report_net)
2520 self._step.report_net = get_net_name(report_net)
2521 self._step.report_interval = report_interval
    def AddSubstep(self, substep):
        self._assert_can_mutate()
        assert not self.HasNets(), 'Cannot have both network and substeps.'
        if isinstance(substep, ExecutionStep):
            substep._notify_is_used()
            if not substep.HasNets() and not substep.HasSubsteps():
                return self
            for net in substep.Nets():
                _add_net_to_dict(self._net_dict, net)
            self._substeps.append(substep)
            proto = substep.Proto()
        else:
            proto = substep
        self._step.substep.add().CopyFrom(proto)
        return self

    def SetConcurrentSubsteps(self, concurrent_substeps):
        self._assert_can_mutate()
        assert not self.HasNets(), 'Cannot have both network and substeps.'
        self._step.concurrent_substeps = concurrent_substeps

    def AddNet(self, net):
        self._assert_can_mutate()
        assert not self.HasSubsteps(), 'Cannot have both network and substeps.'
        assert isinstance(net, Net)
        _add_net_to_dict(self._net_dict, net)
        self._step.network.extend([get_net_name(net)])
        return self

    def get_all_attributes(self, name):
        """
        Return the list of all attributes under the given `name`, present in
        all of the nets used in this execution step and its children.
        """
        return [
            attr
            for net in viewvalues(self._net_dict)
            for attr in net.get_attributes(name)
        ]

    @classmethod
    def create_from_proto(cls, step_proto, net_obj_dict, net_proto_dict):
        """
        Create ExecutionStep from ExecutionStep protobuf recursively
        """
        assert isinstance(step_proto, caffe2_pb2.ExecutionStep)
        assert (len(step_proto.network) > 0 and len(step_proto.substep) == 0) or \
            (len(step_proto.network) == 0 and len(step_proto.substep) > 0)

        steps_or_nets = []
        if len(step_proto.substep) > 0:
            for substep_proto in step_proto.substep:
                steps_or_nets.append(ExecutionStep.create_from_proto(
                    substep_proto, net_obj_dict, net_proto_dict))
        else:
            for net_name in step_proto.network:
                if net_name not in net_obj_dict:
                    assert net_name in net_proto_dict
                    net = Net(net_proto_dict[net_name])
                    net_obj_dict[net_name] = net
                net = net_obj_dict[net_name]
                assert isinstance(net, Net)
                steps_or_nets.append(net)

        num_iter = step_proto.num_iter if step_proto.HasField('num_iter') else None
        concurrent_substeps = step_proto.concurrent_substeps if\
            step_proto.HasField('concurrent_substeps') else None
        should_stop_blob = BlobReference(step_proto.should_stop_blob) if\
            step_proto.HasField('should_stop_blob') else None
        only_once = step_proto.only_once if\
            step_proto.HasField('only_once') else None
        num_concurrent_instances = step_proto.num_concurrent_instances if\
            step_proto.HasField('num_concurrent_instances') else None
        create_workspace = step_proto.create_workspace if\
            step_proto.HasField('create_workspace') else None
        run_every_ms = step_proto.run_every_ms if\
            step_proto.HasField('run_every_ms') else None

        return execution_step(
            step_proto.name,
            steps_or_nets,
            num_iter=num_iter,
            report_net=None,        # DEPRECATED
            report_interval=None,   # DEPRECATED
            concurrent_substeps=concurrent_substeps,
            should_stop_blob=should_stop_blob,
            only_once=only_once,
            num_concurrent_instances=num_concurrent_instances,
            create_workspace=create_workspace,
            run_every_ms=run_every_ms)

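
# A minimal round-trip sketch for ExecutionStep.create_from_proto
# (illustrative only: names are made up). A step is serialized to its proto
# and rebuilt from that proto plus the protos of the nets it references.
def _example_execution_step_roundtrip():
    example_net = Net("roundtrip_net")
    example_net.ConstantFill([], ["counter"], shape=[1], value=0.0)
    step = execution_step("roundtrip_step", example_net, num_iter=10)
    net_proto_dict = {get_net_name(example_net): example_net.Proto()}
    rebuilt = ExecutionStep.create_from_proto(step.Proto(), {}, net_proto_dict)
    assert rebuilt.Proto().num_iter == 10
    return rebuilt
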
def add_nets_in_order(step, net_list):
    proto = step.Proto()
    for substep in step.Substeps():
        add_nets_in_order(substep, net_list)
    for net in proto.network:
        if net not in net_list:
            net_list.append(net)
    # The report net is added last; the Reporter currently depends on this
    # ordering.
    if proto.report_net and proto.report_net not in net_list:
        net_list.append(proto.report_net)

class Plan(object):

    def __init__(self, name_or_step):
        self._plan = caffe2_pb2.PlanDef()
        self._net_dict = OrderedDict()
        self._steps = []    # A list of ExecutionStep
        if isinstance(name_or_step, ExecutionStep):
            self._plan.name = name_or_step.Name()
            self.AddStep(name_or_step)
        elif isinstance(name_or_step, basestring):
            self._plan.name = name_or_step
        else:
            raise ValueError('name_or_step must be a string or ExecutionStep')

    def __str__(self):
        return self._plan.name

    def Proto(self):
        return self._plan

    def AddNets(self, nets):
        for net in nets:
            if _add_net_to_dict(self._net_dict, net):
                assert isinstance(net, Net)
                self._plan.network.add().CopyFrom(net.Proto())

    def Nets(self):
        return list(viewvalues(self._net_dict))

    def AddStep(self, step):
        assert isinstance(step, ExecutionStep)
        step._notify_is_used()
        if not step.HasNets() and not step.HasSubsteps():
            return
        self._plan.execution_step.add().CopyFrom(step.Proto())
        self._steps.append(step)
        # Nets need to be added to the plan in order of usage.
        net_list = []
        add_nets_in_order(step, net_list)
        self.AddNets([step.get_net(n) for n in net_list])

    def Steps(self):
        return self._steps

    def get_all_attributes(self, name):
        """
        Return the list of all attributes under the given `name`, present in
        all of the nets used in this plan.
        """
        return [
            attr
            for net in viewvalues(self._net_dict)
            for attr in net.get_attributes(name)
        ]

    @classmethod
    def create_from_proto(cls, plan_proto):
        assert isinstance(plan_proto, caffe2_pb2.PlanDef)
        plan = Plan(plan_proto.name)
        plan._plan.CopyFrom(plan_proto)
        del plan._plan.network[:]
        del plan._plan.execution_step[:]

        net_obj_dict = {}
        net_proto_dict = {}
        for net_proto in plan_proto.network:
            assert net_proto.name not in net_proto_dict
            net_proto_dict[net_proto.name] = net_proto

        for step_proto in plan_proto.execution_step:
            step = ExecutionStep.create_from_proto(
                step_proto, net_obj_dict, net_proto_dict)
            plan.AddStep(step)

        return plan

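
# A minimal sketch of assembling and running a Plan (illustrative only:
# net and blob names are made up). The init step runs once to create "x";
# the train step then runs its net five times, doubling "x" each iteration.
def _example_build_and_run_plan():
    init_net = Net("plan_example_init")
    init_net.ConstantFill([], ["x"], shape=[1], value=1.0)
    train_net = Net("plan_example_train")
    train_net.Add(["x", "x"], "x")  # in-place doubling of "x"
    plan = Plan("plan_example")
    plan.AddStep(execution_step("init_step", init_net))
    plan.AddStep(execution_step("train_step", train_net, num_iter=5))
    workspace.RunPlan(plan)
    return workspace.FetchBlob("x")
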
def to_execution_step(step_or_nets, default_name=None):
    from caffe2.python.net_builder import NetBuilder
    if isinstance(step_or_nets, ExecutionStep):
        return step_or_nets

    stop_blob = None
    if not default_name and hasattr(step_or_nets, 'name'):
        default_name = step_or_nets.name
    if isinstance(step_or_nets, NetBuilder):
        stop_blob = step_or_nets._stop_blob
        step_or_nets = step_or_nets.get()
    return execution_step(
        default_name, step_or_nets, should_stop_blob=stop_blob)

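
# A minimal sketch of converting a NetBuilder into a step (illustrative
# only). Ops created inside the builder's scope are collected and wrapped
# into a single ExecutionStep, honoring the builder's stop blob if any.
def _example_net_builder_to_step():
    from caffe2.python.net_builder import NetBuilder, ops
    with NetBuilder(name="builder_example") as nb:
        ops.ConstantFill([], ["one"], shape=[1], value=1.0)
    return to_execution_step(nb)
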
def execution_step(default_name,
                   steps_or_nets,
                   num_iter=None,
                   report_net=None,
                   report_interval=None,
                   concurrent_substeps=None,
                   should_stop_blob=None,
                   only_once=None,
                   num_concurrent_instances=None,
                   create_workspace=False,
                   run_every_ms=None):
    """
    Helper for creating an ExecutionStep.
    - steps_or_nets can be:
      - None
      - Net
      - ExecutionStep
      - list<Net>
      - list<ExecutionStep>
    - should_stop_blob is either None or a scalar boolean blob.
      - This blob is checked AFTER every substep/subnet.
      - If specified and true, then this step will return immediately.
      - Be sure to handle race conditions if setting from concurrent threads.
    - if no should_stop_blob or num_iter is provided, defaults to num_iter=1
    """
    assert should_stop_blob is None or num_iter is None, (
        'Cannot set both should_stop_blob and num_iter.')
    if should_stop_blob is None and num_iter is None:
        num_iter = 1

    step = ExecutionStep(default_name)
    if should_stop_blob is not None:
        step.SetShouldStopBlob(should_stop_blob)
    if num_iter is not None:
        step.SetIter(num_iter)
    if only_once is not None:
        step.SetOnlyOnce(only_once)
    if concurrent_substeps is not None:
        step.SetConcurrentSubsteps(concurrent_substeps)
    if report_net is not None:
        assert report_interval is not None
        step.SetReportNet(report_net, report_interval)
    if num_concurrent_instances is not None:
        step.SetNumConcurrentInstances(num_concurrent_instances)
    if create_workspace:
        step.SetCreateWorkspace(True)
    if run_every_ms:
        step.RunEveryMillis(run_every_ms)

    if isinstance(steps_or_nets, ExecutionStep):
        step.AddSubstep(steps_or_nets)
    elif isinstance(steps_or_nets, Net):
        step.AddNet(steps_or_nets)
    elif isinstance(steps_or_nets, list):
        if all(isinstance(x, Net) for x in steps_or_nets):
            for x in steps_or_nets:
                step.AddNet(x)
        else:
            for x in steps_or_nets:
                step.AddSubstep(to_execution_step(x))
    elif steps_or_nets:
        raise ValueError(
            'steps_or_nets must be a step, a net, or a list of nets or steps.')
    return step

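
# A minimal sketch of composing nested steps with execution_step
# (illustrative only: step names are made up). The inner step runs its nets
# in order; the outer step repeats the inner step num_iter times.
def _example_nested_execution_steps(net_a, net_b):
    inner = execution_step("inner_step", [net_a, net_b])
    outer = execution_step("outer_step", inner, num_iter=100)
    return outer
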
def scoped_execution_step(name, *args, **kwargs):
    """Same as execution_step() except that the step name is scoped."""
    default_name = ScopedName(name) if name else name
    return execution_step(default_name, *args, **kwargs)

def _extract_stacktrace():
    '''
    This function extracts the stacktrace without file system access,
    by purely using sys._getframe(), and removes the part that belongs to
    this file (core.py). We are not using the inspect module because
    it is just a wrapper on top of sys._getframe() whose
    logic is based on accessing source files on disk - exactly what
    we are trying to avoid here. The same stands for the traceback module.

    The reason for avoiding file system access is that
    if the code is located on an NFS, file access might be slow.

    Function returns a list of tuples (file_name, line_number, function)
    '''
    result = []
    # Skip the top frames, which belong to this file (core.py).
    frame = sys._getframe(3)
    # Walk down the frame stack.
    while frame:
        # Extract the information from the frame now, since the frame's
        # current line will most probably change later.
        result.append((frame.f_code.co_filename, frame.f_lineno, frame.f_code.co_name))
        frame = frame.f_back
    return result

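
# A minimal sketch of rendering the tuples returned by _extract_stacktrace()
# for logging (illustrative only).
def _example_format_stacktrace(frames):
    return '\n'.join(
        '{}:{} in {}'.format(file_name, line_number, function)
        for file_name, line_number, function in frames
    )
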
SetPerOpEnginePref = C.set_per_op_engine_pref
SetGlobalEnginePref = C.set_global_engine_pref
SetEnginePref = C.set_engine_pref
SetOpEnginePref = C.set_op_engine_pref