1 from __future__
import absolute_import
2 from __future__
import division
3 from __future__
import print_function
4 from __future__
import unicode_literals
7 from collections
import namedtuple
8 from six
import string_types
# Short module-level alias for the operator-schema binding exposed by the
# workspace's C extension; operator wrappers in this module fetch per-operator
# schemas through it (e.g. ``OpSchema.get(op_type)``).
OpSchema = getattr(workspace.C, 'OpSchema')
def namedtupledict(typename, field_names, *args, **kwargs):
    """Create a namedtuple class whose instances can also be indexed by name.

    Instances keep normal positional and slice indexing, and additionally
    accept ``instance['field']`` using the *original* field names. This
    matters because ``rename`` defaults to ``True`` here, so invalid
    identifiers such as ``"0"`` become ``_0`` as attributes. The original
    string still works as a key.

    Args:
        typename: Name of the generated class.
        field_names: Iterable of field names, or a single string of names
            separated by whitespace and/or commas (the same forms
            ``collections.namedtuple`` accepts).
        *args: Forwarded positionally to ``collections.namedtuple``.
        **kwargs: Forwarded to ``collections.namedtuple``; ``rename``
            defaults to ``True``.

    Returns:
        The generated namedtuple class.

    Indexing an instance with a string that is not one of the original
    field names raises ``KeyError``.
    """
    if isinstance(field_names, string_types):
        # Parse like namedtuple does, so each *name* (not each character)
        # is mapped to its index.
        field_names = field_names.replace(',', ' ').split()
    else:
        # Materialise once: a one-shot iterator would otherwise be consumed
        # by the map below and leave namedtuple with no fields.
        field_names = list(field_names)
    field_names_map = {name: index for index, name in enumerate(field_names)}
    # Some names (e.g. "0") are not valid Python identifiers; let namedtuple
    # rename them unless the caller explicitly chose otherwise.
    kwargs.setdefault('rename', True)
    data = namedtuple(typename, field_names, *args, **kwargs)

    def getitem(self, key):
        if isinstance(key, string_types):
            key = field_names_map[key]
        # Anchor super() on the generated class, not type(self): for a
        # subclass, super(type(self), self) resolves back to this very
        # function and recurses without end.
        return super(data, self).__getitem__(key)

    data.__getitem__ = getitem
    return data
29 def __getattribute__(self, op_type):
30 def op_func(*inputs, **args):
31 ws = workspace.C.Workspace()
32 schema = OpSchema.get(op_type)
33 input_prefix =
'input_' 34 output_prefix =
'output_' 36 def get_name_list(prefix, num, max_num):
37 return [prefix + str(x)
for x
in range(min(num, max_num))]
39 input_names, output_names = [], []
40 input_names = get_name_list(
41 input_prefix, len(inputs), schema.max_input
45 num_input = len(input_names)
46 if num_input > schema.max_input
or num_input < \
47 schema.min_input
or not schema.num_inputs_allowed(num_input):
49 "Functional C2: Number of inputs not in \ 50 range: {} - {} or not allowed." 51 .format(schema.min_input, schema.max_input)
54 if 'num_output' in args:
55 num_output = args[
'num_output']
56 if num_output > schema.max_output
or \
57 num_output < schema.min_output
or \
58 not schema.num_outputs_allowed(num_output)
or \
59 not schema.num_inputs_outputs_allowed(num_input,
62 "Functional C2: Number of output \ 63 not in range: {} - {} or not allowed" 64 .format(schema.min_output, schema.max_output)
66 output_names = get_name_list(
67 output_prefix, num_output, schema.max_output
69 args.pop(
'num_output')
70 calculated = schema.CalculateOutput(num_input)
71 if not output_names
and calculated != -1:
72 output_names = get_name_list(
73 output_prefix, calculated, schema.max_output
77 max_output = schema.max_output
81 if schema.inf == max_output:
83 "For operators with max_output == inf,\ 84 user should pass num_output explicity." 86 output_names = get_name_list(
87 output_prefix, max_output, max_output
89 for i, input_blob
in enumerate(inputs):
90 ws.create_blob(input_names[i]).feed(input_blob)
92 op = core.CreateOperator(
93 op_type, input_names, output_names, **args
95 ws._run_operator(op.SerializeToString())
97 output_values = [ws.fetch_blob(x)
for x
in output_names]
98 return namedtupledict(
'output', output_names)(*output_values)