from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core, scope, workspace, helpers
from caffe2.python.modeling import parameter_info
from caffe2.python.modeling.parameter_sharing import (
    parameter_sharing_context,
)
from caffe2.python.optimizer_context import (
    OptimizerContext,
    DEFAULT_OPTIM,
)
from caffe2.python.regularizer_context import RegularizerContext

from future.utils import viewitems, viewkeys
from itertools import chain

import logging
import six
# Operators that ModelHelper recognizes as safe to call directly.
_known_working_ops = [
    # ... (elided list of operator names) ...
]
75 """A helper model so we can manange models more easily. It contains net def 76 and parameter storages. You can add an Operator yourself, e.g. 78 model = model_helper.ModelHelper(name="train_net") 79 # init your weight and bias as w and b 80 w = model.param_init_net.XavierFill(...) 81 b = model.param_init_net.ConstantFill(...) 82 fc1 = model.FC([input, w, b], output, **kwargs) 84 or you can use helper functions in brew module without manually 85 defining parameter initializations and operators. 87 model = model_helper.ModelHelper(name="train_net") 88 fc1 = brew.fc(model, input, output, dim_in, dim_out, **kwargs) 92 def __init__(self, name=None, init_params=True, allow_not_known_ops=True,
93 skip_sparse_optim=
False, param_model=
None, arg_scope=
None):
        self.name = name or "model"
        self.net = core.Net(self.name)

        if param_model is not None:
            self.param_init_net = param_model.param_init_net
            self.param_to_grad = param_model.param_to_grad
            self.params = param_model.params
            self._parameters_info = param_model._parameters_info
            self._computed_params = param_model._computed_params
        else:
            self.param_init_net = core.Net(self.name + '_init')
            self.param_to_grad = {}
            self.params = []
            self._parameters_info = {}
            self._computed_params = []

        self._param_info_deprecated = []
        self._devices = []
        self.gradient_ops_added = False
        self.init_params = init_params
        self.allow_not_known_ops = allow_not_known_ops
        self.skip_sparse_optim = skip_sparse_optim
        self.weights = []
        self.biases = []
        self._arg_scope = {
            'order': "NCHW",
            'use_cudnn': True,
            'cudnn_exhaustive_search': False,
        }
        if arg_scope is not None:
            self._arg_scope.update(arg_scope)
    def _infer_param_shape(self, param):
        for op in self.param_init_net.Proto().op:
            if str(param) in op.output:
                for arg in op.arg:
                    if arg.name == "shape":
                        return list(arg.ints)
        return None
    def _update_param_info_deprecated(self):
        assert len(self._param_info_deprecated) <= len(self.params)
        for param in self.params[len(self._param_info_deprecated):]:
            if not isinstance(param, core.BlobReference):
                raise ValueError(
                    "Param %s must be a BlobReference!" % str(param))
            self._param_info_deprecated.append(parameter_info.ParameterInfo(
                param_id=len(self._param_info_deprecated),
                param=param,
                shape=self._infer_param_shape(param)))
        for info in self._param_info_deprecated:
            info.grad = self.param_to_grad.get(info.name)
    def _normalize_tags(self, tags):
        tags = tags or []
        return set(tags) if isinstance(tags, list) else set([tags])
    def create_param(self, param_name, shape, initializer, tags=None):
        """
        Creates parameter with a given name and initializer.

        If param_name is an instance of BlobReference - then this blob will be
        used to store the parameter (no sharing logic will affect its location).

        If param_name is an instance of a string type, then the final blob will
        be created in the CurrentNameScope with the respect of all parameter
        sharing logic, i.e. 'resolved_name_scope/param_name'.

        Parameter sharing logic is going to override CurrentNameScope according
        to the rules that are specified through ParameterSharing contexts,
        all ParameterSharing contexts are applied recursively until there are no
        extra overrides present, where on each step the best match will be
        applied first.

        The following examples should clarify the way ParameterSharing logic
        works:

        As an example if this function is called with parameter 'w':
        a. Call from some scope 'global_scope' with no Parameter sharing:
          'global_scope/w'
        b. Call from scope 'scope_b', with override {'scope_b': 'scope_a'}:
          'scope_a/w'
        c. Call from scope 'scope_a', with override {'scope_a': ''}:
          'w'
        d. Call from scope 'scope_b/shared', with overrides
          {'scope_b/shared': 'scope_b', 'scope_b': 'scope_a'}:
          'scope_a/w'
        e. Call from scope 'scope_b/unshared', with overrides
          {'scope_b/shared': 'scope_b', 'scope_b': 'scope_a'}:
          'scope_a/unshared/w'
        """
        if isinstance(param_name, core.BlobReference):
            param_name = str(param_name)
        elif isinstance(param_name, six.string_types):
            # The final name is resolved with respect to parameter sharing of
            # the enclosing scopes.
            param_name = parameter_sharing_context.get_parameter_name(
                param_name)
        else:
            raise TypeError("Unsupported type for param_name")

        param_info = initializer.create_param(
            param_name=core.BlobReference(param_name),
            init_net=self.param_init_net,
            shape=shape,
        )
        optim_context = OptimizerContext.current()
        for tag in self._normalize_tags(tags):
            if optim_context.has_optimizer(tag):
                # param_info will check that the optimizer has not been set yet
                param_info.optimizer = optim_context.get_optimizer(tag)
        if not param_info.optimizer and optim_context.has_optimizer(DEFAULT_OPTIM):
            param_info.optimizer = optim_context.get_optimizer(DEFAULT_OPTIM)

        reg_context = RegularizerContext.current()
        param_info.regularizer = reg_context

        self._parameters_info[param_name] = param_info
        # Add param to legacy structs as well, so all other functions for
        # parameters keep working.
        self.AddParameter(param_info.blob, tags)
        return param_info.blob
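    # A minimal usage sketch (hedged, not part of the original source). The
    # Initializer import and the fc dimensions below are illustrative
    # assumptions:
    #
    #   from caffe2.python.modeling.initializers import Initializer
    #   model = ModelHelper(name="example")
    #   w = model.create_param(
    #       param_name='fc_w',
    #       shape=[10, 4],
    #       initializer=Initializer("XavierFill"),
    #       tags=parameter_info.ParameterTags.WEIGHT,
    #   )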
    def get_param_info(self, param):
        assert isinstance(param, core.BlobReference), \
            "Param {} is not a BlobReference".format(param)
        return self._parameters_info.get(param, None)
    def add_param_DEPRECATED(self, param, key=None, shape=None, length=None):
        logging.warning("add_param method is DEPRECATED")
        self._update_param_info_deprecated()
        self.AddParameter(param)
        if key is not None and self.net.input_record() is not None:
            idx = self.net.input_record().field_blobs().index(key)
            key = self.net.input_record().field_names()[idx]
        shape = shape if shape is not None else self._infer_param_shape(param)
        if not isinstance(param, core.BlobReference):
            raise ValueError("Param %s must be a BlobReference!" % str(param))
        self._param_info_deprecated.append(parameter_info.ParameterInfo(
            param_id=len(self._param_info_deprecated),
            param=param,
            shape=shape,
            key=key,
            length=length,
        ))
        return self._param_info_deprecated[-1]
    def param_info(self, grad_type=None, id=None):
        logging.info("param_info method is DEPRECATED")
        self._update_param_info_deprecated()
        if id is not None:
            assert grad_type is None
            info = self._param_info_deprecated[id]
            assert info.param_id == id
            return info
        elif grad_type is not None:
            return [
                info for info in self._param_info_deprecated
                if info.grad_type() == grad_type]
        else:
            return self._param_info_deprecated
    def AddParameter(self, param, tags=None):
        assert isinstance(param, core.BlobReference)
        tags = self._normalize_tags(tags)
        if parameter_info.ParameterTags.COMPUTED_PARAM in tags:
            self._computed_params.append(param)
        else:
            self.params.append(param)

        if parameter_info.ParameterTags.WEIGHT in tags:
            self.weights.append(param)
        if parameter_info.ParameterTags.BIAS in tags:
            self.biases.append(param)
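    # Tag semantics sketch (hedged): WEIGHT and BIAS tags additionally index
    # the param into self.weights / self.biases, while COMPUTED_PARAM routes
    # it to self._computed_params instead of self.params, e.g.
    #
    #   model.AddParameter(w, tags=parameter_info.ParameterTags.WEIGHT)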
    @staticmethod
    def _NormalizeNamescope(namescope):
        if namescope is None:
            return scope.CurrentNameScope()
        elif namescope == '' or namescope.endswith(scope._NAMESCOPE_SEPARATOR):
            return namescope
        else:
            return namescope + scope._NAMESCOPE_SEPARATOR
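    # Behavior sketch (assuming the default namescope separator '/'):
    #   _NormalizeNamescope(None)     -> scope.CurrentNameScope()
    #   _NormalizeNamescope('')       -> ''
    #   _NormalizeNamescope('gpu_0/') -> 'gpu_0/'
    #   _NormalizeNamescope('gpu_0')  -> 'gpu_0/'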
    def GetParams(self, namescope=None, top_scope=False):
        '''
        Returns the params in current namescope
        '''
        namescope = ModelHelper._NormalizeNamescope(namescope)

        if namescope == '':
            return self.params[:]
        else:
            return [p for p in self.params if
                    p.GetNameScope().startswith(namescope)]
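    # Illustrative use (hedged blob names): with self.params containing
    # 'gpu_0/fc_w' and 'gpu_1/fc_w',
    #   model.GetParams('gpu_0') -> [BlobReference('gpu_0/fc_w')]
    #   model.GetParams()        -> params under scope.CurrentNameScope()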
    def Proto(self):
        return self.net.Proto()

    def InitProto(self):
        return self.param_init_net.Proto()
    def RunAllOnGPU(self, *args, **kwargs):
        self.param_init_net.RunAllOnGPU(*args, **kwargs)
        self.net.RunAllOnGPU(*args, **kwargs)
    def CreateDB(self, blob_out, db, db_type, **kwargs):
        dbreader = self.param_init_net.CreateDB(
            [], blob_out, db=db, db_type=db_type, **kwargs)
        return dbreader
    def AddGradientOperators(self, *args, **kwargs):
        if self.gradient_ops_added:
            raise RuntimeError("You cannot run AddGradientOperators twice.")
        self.Validate()

        self.gradient_ops_added = True
        self.grad_map = self.net.AddGradientOperators(*args, **kwargs)
        self.param_to_grad = self.get_param_to_grad(self.params)

        # Add gradient blob information to each param's ParameterInfo,
        # so optimizers can use it.
        for param, grad in self.param_to_grad.items():
            param_info = self.get_param_info(param)
            if param_info:
                param_info.grad = grad

        return self.grad_map
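    # A minimal training-setup sketch (hedged; the forward net and the 'loss'
    # blob are assumed to exist):
    #
    #   model = ModelHelper(name="train")
    #   # ... build forward ops ending in a scalar blob 'loss' ...
    #   grad_map = model.AddGradientOperators(["loss"])
    #   # grad_map maps blob names to gradient blobs, e.g. grad_map["fc_w"]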
    def get_param_to_grad(self, params):
        '''
        Given a list of parameters returns a dict from a parameter
        to a corresponding gradient
        '''
        param_to_grad = {}
        if not self.gradient_ops_added:
            raise RuntimeError("You need to run AddGradientOperators first.")
        for p in params:
            if str(p) in self.grad_map:
                param_to_grad[p] = self.grad_map[str(p)]
        return param_to_grad
    def GetOptimizationParamInfo(self, params=None):
        '''
        Returns a map for param => grad.
        If params is not specified, all parameters will be considered.
        '''
        if not self.gradient_ops_added:
            raise RuntimeError("Need to call AddGradientOperators first")

        param_to_grad = self.param_to_grad
        if params:
            param_to_grad = self.get_param_to_grad(params)

        return [
            self.get_param_info(param)
            for param, grad in viewitems(param_to_grad)
            if (not self.skip_sparse_optim or
                not isinstance(grad, core.GradientSlice))
        ]
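    # Sketch (hedged): optimizers typically iterate over the returned infos
    # once gradients have been added, e.g.
    #
    #   for info in model.GetOptimizationParamInfo():
    #       # info.blob, info.grad and info.optimizer are available here
    #       ...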
    def _Validate(self):
        '''
        Check for duplicate params
        '''
        params_list = [str(p) for p in self.params]
        params_set = set(params_list)

        dupes = []
        if len(params_set) != len(params_list):
            params_list = sorted(params_list)
            for j, p in enumerate(params_list):
                if j > 0 and params_list[j - 1] == p:
                    if p not in dupes:
                        dupes.append(p)
        return dupes

    def Validate(self):
        dupes = self._Validate()
        assert dupes == [], "Duplicate params: {}".format(dupes)
    def GetComputedParams(self, namescope=None):
        '''
        Returns the computed params in current namescope. 'Computed params'
        are such parameters that are not optimized via gradient descent but are
        directly computed from data, such as the running mean and variance
        of Spatial Batch Normalization.
        '''
        namescope = ModelHelper._NormalizeNamescope(namescope)

        if namescope == '':
            return self._computed_params[:]
        else:
            return [p for p in self._computed_params
                    if p.GetNameScope().startswith(namescope)]
    def GetAllParams(self, namescope=None):
        return self.GetParams(namescope) + self.GetComputedParams(namescope)
    def TensorProtosDBInput(
        self, unused_blob_in, blob_out, batch_size, db, db_type, **kwargs
    ):
        """TensorProtosDBInput."""
        assert len(unused_blob_in) == 0, \
            """You cannot pass reader to model_helper.TensorProtosDBInput.
               Use model.net.TensorProtosDBInput instead to create the op."""

        return helpers.db_input.db_input(
            self, blob_out, batch_size, db, db_type, **kwargs)
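    # Usage sketch (hedged; the db path and type are illustrative):
    #
    #   data, label = model.TensorProtosDBInput(
    #       [], ["data", "label"], batch_size=64,
    #       db="/tmp/train.minidb", db_type="minidb")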
    def GetDevices(self):
        assert len(self._devices) > 0, \
            "Use data_parallel_model to run model on multiple GPUs."
        return self._devices

    def __getattr__(self, op_type):
        """Catch-all for all other operators, mostly those without params."""
        if op_type.startswith('__'):
            raise AttributeError(op_type)

        if not core.IsOperator(op_type):
            raise AttributeError(
                'Method ' + op_type + ' is not a registered operator.' +
                ' Did you mean: [' +
                ','.join(workspace.C.nearby_opnames(op_type)) + ']'
            )
        if op_type not in _known_working_ops:
            if not self.allow_not_known_ops:
                raise AttributeError(
                    "Operator {} is not known to be safe".format(op_type))

            logging.warning("You are creating an op that the ModelHelper "
                            "does not recognize: {}.".format(op_type))
        return self.net.__getattr__(op_type)
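    # Because of this catch-all, any registered Caffe2 operator can be called
    # directly on the model (hedged sketch; 'Relu' assumed registered):
    #
    #   model = ModelHelper(name="example")
    #   model.Relu(["x"], ["y"])  # creates the op on model.net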
    def __dir__(self):
        return sorted(set(chain(
            dir(type(self)),
            viewkeys(self.__dict__),
            _known_working_ops
        )))
def ExtractPredictorNet(
    net_proto,
    input_blobs,
    output_blobs,
    device=None,
    renames=None,
    disabled_inputs=None,
):
    """
    Takes a model net for training and returns a net which can be
    used for prediction. For example, all gradient operators and
    input operators are removed.
    @param net_proto protobuf of the net you want to process (net.Proto())
    @param input_blobs list/set of blob names that are the inputs of predictor
    @param output_blobs list/set of blob names that are outputs of predictor
    @param device optional device option that is assigned
    @param renames dictionary of blob name to a new name (optional)
    @param disabled_inputs optional set of blobs that are 'switched off'. This
           will cause branches with those blobs as inputs to be removed
    """
    predict_net = core.Net(net_proto.name + "_predict")
    predict_proto = predict_net.Proto()

    orig_external_inputs = set(net_proto.external_input)
    orig_external_outputs = set(net_proto.external_output)
    input_blobs = {str(b) for b in input_blobs}
    known_blobs = set(orig_external_inputs).union(input_blobs)
    output_blobs = {str(b) for b in output_blobs}
    external_inputs = set(input_blobs)
    external_outputs = set(output_blobs)

    if renames is None:
        renames = {}

    if disabled_inputs is not None:
        known_blobs = known_blobs - set(disabled_inputs)

    ops = list(net_proto.op)
    # Find the range of ops that we should include
    try:
        first_op_with_input = min(
            [
                j for j in range(len(ops))
                if input_blobs.intersection(ops[j].input) and
                ops[j].type != 'StopGradient'
            ]
        )
    except ValueError:
        raise Exception("No ops with input={}".format(input_blobs))

    try:
        last_op_with_output = max(
            [
                j for j in range(len(ops))
                if output_blobs.intersection(ops[j].output)
            ]
        )
    except ValueError:
        raise Exception("No ops with output={}".format(output_blobs))

    def validate_op(op):
        # Check that the op does not have is_test = 0 set. This is a common
        # pitfall with the SpatialBN op, at least.
        for arg in op.arg:
            if arg.name == "is_test" and arg.i == 0:
                raise Exception(
                    "An operator had is_test=0, did you try to extract a " +
                    "predictor from a train model (instead of test model)?" +
                    " Op was: {}".format(str(op))
                )
    def rename_list(proto_list):
        # proto lists don't support direct assignment, so rebuild the list
        new_list = proto_list[:]
        for j, b in enumerate(new_list):
            if b in renames:
                new_list[j] = renames[b]

        del proto_list[:]
        proto_list.extend(new_list)
    # Iterate through the ops and only include those whose inputs
    # we can satisfy.
    for op in ops[first_op_with_input:(last_op_with_output + 1)]:
        if known_blobs.issuperset(op.input):

            # Special handling for recurrent nets
            if op.type == 'RecurrentNetwork':
                for arg in op.arg:
                    if arg.name == 'backward_step_net':
                        arg.ClearField(str('n'))
                    elif arg.name == 'step_net':
                        for step_op in arg.n.op:
                            rename_list(step_op.input)
                            rename_list(step_op.output)
                            if device is not None:
                                step_op.device_option.device_type = device.device_type
                                step_op.device_option.cuda_gpu_id = device.cuda_gpu_id

                        rename_list(arg.n.external_input)
                        rename_list(arg.n.external_output)

                        # Add additional external inputs
                        external_inputs.update(
                            set(arg.n.external_input).intersection(
                                orig_external_inputs
                            )
                        )

            if device is not None:
                op.device_option.device_type = device.device_type
                op.device_option.cuda_gpu_id = device.cuda_gpu_id
            validate_op(op)
            predict_proto.op.extend([op])
            known_blobs.update(op.output)
            external_inputs.update(
                set(op.input).intersection(orig_external_inputs)
            )
            external_outputs.update(
                set(op.output).intersection(orig_external_outputs)
            )
        else:
            logging.debug(
                "Op {} had unknown inputs: {}".format(
                    op.type, set(op.input).difference(known_blobs)
                )
            )
    # Predictor net's external inputs and outputs include only those
    # that are part of this net.
    predict_proto.external_input.extend(external_inputs)
    predict_proto.external_output.extend(external_outputs)

    rename_list(predict_proto.external_input)
    rename_list(predict_proto.external_output)

    renamed_input_blobs = []
    for b in input_blobs:
        if b in renames:
            renamed_input_blobs.append(renames[b])
        else:
            renamed_input_blobs.append(b)

    for op in predict_proto.op:
        rename_list(op.input)
        rename_list(op.output)

    return predict_net, list(
        set(predict_proto.external_input) - set(renamed_input_blobs)
    )
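# Usage sketch (hedged; blob names are illustrative): extract a deploy net
# that runs from 'data' to 'softmax' out of a trained model's net:
#
#   predict_net, export_blobs = ExtractPredictorNet(
#       net_proto=train_model.net.Proto(),
#       input_blobs=["data"],
#       output_blobs=["softmax"],
#       renames={"gpu_0/data": "data"},
#   )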