3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
11 from past.builtins
import basestring
33 'arg_scope': arg_scope,
35 'packed_fc': packed_fc,
36 'fc_decomp': fc_decomp,
37 'fc_sparse': fc_sparse,
41 'average_pool': average_pool,
42 'max_pool_with_index' : max_pool_with_index,
45 'instance_norm': instance_norm,
46 'spatial_bn': spatial_bn,
51 'depth_concat': depth_concat,
53 'transpose': transpose,
58 'conv_transpose': conv_transpose,
59 'group_conv': group_conv,
60 'group_conv_deprecated': group_conv_deprecated,
61 'image_input': image_input,
62 'video_input': video_input,
63 'add_weight_decay': add_weight_decay,
64 'elementwise_linear': elementwise_linear,
65 'layer_norm': layer_norm,
66 'batch_mat_mul' : batch_mat_mul,
69 'db_input' : db_input,
72 def __init__(self, wrapped):
75 def __getattr__(self, helper_name):
78 "Helper function {} not " 79 "registered.".format(helper_name)
82 def scope_wrapper(*args, **kwargs):
84 if helper_name !=
'arg_scope':
85 if len(args) > 0
and isinstance(args[0], ModelHelper):
87 elif 'model' in kwargs:
88 model = kwargs[
'model']
91 "The first input of helper function should be model. " \
92 "Or you can provide it in kwargs as model=<your_model>.")
93 new_kwargs = copy.deepcopy(model.arg_scope)
95 var_names, _, varkw, _= inspect.getargspec(func)
99 var_name: new_kwargs[var_name]
100 for var_name
in var_names
if var_name
in new_kwargs
103 cur_scope = get_current_scope()
104 new_kwargs.update(cur_scope.get(helper_name, {}))
105 new_kwargs.update(kwargs)
106 return func(*args, **new_kwargs)
108 scope_wrapper.__name__ = helper_name
111 def Register(self, helper):
112 name = helper.__name__
114 raise AttributeError(
115 "Helper {} already exists. Please change your " 116 "helper name.".format(name)
120 def has_helper(self, helper_or_helper_name):
122 helper_or_helper_name
123 if isinstance(helper_or_helper_name, basestring)
else 124 helper_or_helper_name.__name__