Caffe2 - Python API
A deep learning, cross-platform ML framework
predictor_py_utils.py
1 ## @package predictor_py_utils
2 # Module caffe2.python.predictor.predictor_py_utils
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from caffe2.python import core, scope
9 
10 
def create_predict_net(predictor_export_meta):
    """Build a fresh predict net proto from the exported predictor metadata.

    A new ``core.Net`` is created (named after the exported predict net, or
    "predict" when that name is falsy) so that no settings of the exported
    net carry over implicitly.  The ops and args of
    ``predictor_export_meta.predict_net`` are copied in, the external inputs
    are the exported inputs followed by the parameters, and the external
    outputs are the exported outputs.  ``net_type`` and ``num_workers`` are
    applied only when they are not None.

    Returns the proto of the newly built net.
    """
    exported_net = predictor_export_meta.predict_net

    # Start from a brand-new net to clear the existing settings.
    fresh_net = core.Net(exported_net.name or "predict")

    fresh_net.Proto().op.extend(exported_net.op)

    external_inputs = (
        predictor_export_meta.inputs + predictor_export_meta.parameters
    )
    fresh_net.Proto().external_input.extend(external_inputs)
    fresh_net.Proto().external_output.extend(predictor_export_meta.outputs)
    fresh_net.Proto().arg.extend(exported_net.arg)

    # Optional execution settings: only override when explicitly provided.
    net_type = predictor_export_meta.net_type
    if net_type is not None:
        fresh_net.Proto().type = net_type

    num_workers = predictor_export_meta.num_workers
    if num_workers is not None:
        fresh_net.Proto().num_workers = num_workers

    return fresh_net.Proto()
27 
28 
def create_predict_init_net(ws, predictor_export_meta):
    """Build the "predict-init" net that zero-fills all input/output blobs.

    For every blob in ``inputs + outputs`` a ``ConstantFill`` op with value
    0.0 is added.  The fill shape is taken from
    ``predictor_export_meta.shapes`` when it has an entry for the blob;
    otherwise it is read from the blob currently held in workspace ``ws``.
    This is necessary as there is no shape inference functionality in Caffe2.

    The same blobs are registered as external inputs of the init net, the
    optional ``extra_init_net`` is appended, and the model_id of the predict
    net is propagated via ``AddModelIdArg``.

    Raises:
        Exception: a blob has no exported shape and is not in ``ws.blobs``.

    Returns the proto of the init net.
    """
    init_net = core.Net("predict-init")
    blob_names = predictor_export_meta.inputs + predictor_export_meta.outputs

    for blob_name in blob_names:
        fill_shape = predictor_export_meta.shapes.get(blob_name)
        if fill_shape is None:
            # No exported shape: fall back to the blob living in the workspace.
            if blob_name not in ws.blobs:
                raise Exception(
                    "{} not in workspace but needed for shape: {}".format(
                        blob_name, ws.blobs))
            fill_shape = ws.blobs[blob_name].fetch().shape

        # Explicitly null-out the scope so users (e.g. PredictorGPU)
        # can control (at a Net-global level) the DeviceOption of
        # these filling operators.
        with scope.EmptyDeviceScope():
            init_net.ConstantFill(
                [], blob_name, shape=fill_shape, value=0.0)

    init_net.Proto().external_input.extend(blob_names)

    extra_init_net = predictor_export_meta.extra_init_net
    if extra_init_net:
        init_net.AppendNet(extra_init_net)

    # Add the model_id in the predict_net to the init_net
    AddModelIdArg(predictor_export_meta, init_net.Proto())

    return init_net.Proto()
66 
67 
def get_comp_name(string, name):
    """Return ``string + '_' + name``, or ``string`` alone if ``name`` is falsy.

    A falsy ``name`` (for example an empty string or None) leaves ``string``
    unchanged.
    """
    return string + '_' + name if name else string
72 
73 
def _ProtoMapGet(field, key):
    """Emulate a map lookup on a repeated field of key/value entries.

    ``field`` is iterated in order; each entry must expose ``key`` and
    ``value`` attributes.  Returns the ``value`` of the first entry whose
    ``key`` equals ``key``, or None when no entry matches.
    """
    matching_values = (entry.value for entry in field if entry.key == key)
    return next(matching_values, None)
83 
84 
def GetPlan(meta_net_def, key):
    """Look up ``key`` among the key/value entries of ``meta_net_def.plans``.

    Returns the value of the first matching entry, or None if there is none.
    """
    plan_entries = meta_net_def.plans
    return _ProtoMapGet(plan_entries, key)
87 
88 
def GetPlanOriginal(meta_net_def, key):
    """Look up ``key`` among the key/value entries of ``meta_net_def.plans``.

    Returns the value of the first matching entry, or None if there is none.
    """
    plan_entries = meta_net_def.plans
    return _ProtoMapGet(plan_entries, key)
91 
92 
def GetBlobs(meta_net_def, key):
    """Return the blobs stored under ``key`` in ``meta_net_def.blobs``.

    When no entry with that key exists, an empty list is returned instead of
    None.
    """
    found = _ProtoMapGet(meta_net_def.blobs, key)
    return [] if found is None else found
98 
99 
def GetNet(meta_net_def, key):
    """Look up ``key`` among the key/value entries of ``meta_net_def.nets``.

    Returns the value of the first matching entry, or None if there is none.
    """
    net_entries = meta_net_def.nets
    return _ProtoMapGet(net_entries, key)
102 
103 
def GetNetOriginal(meta_net_def, key):
    """Look up ``key`` among the key/value entries of ``meta_net_def.nets``.

    Returns the value of the first matching entry, or None if there is none.
    """
    net_entries = meta_net_def.nets
    return _ProtoMapGet(net_entries, key)
106 
107 
def GetApplicationSpecificInfo(meta_net_def, key):
    """Look up ``key`` in ``meta_net_def.applicationSpecificInfo``.

    Returns the value of the first entry whose key matches, or None if there
    is none.
    """
    info_entries = meta_net_def.applicationSpecificInfo
    return _ProtoMapGet(info_entries, key)
110 
111 
def AddBlobs(meta_net_def, blob_name, blob_def):
    """Append every item of ``blob_def`` to the blobs stored under ``blob_name``.

    If ``meta_net_def.blobs`` has no entry keyed ``blob_name`` yet, a new
    entry is added with that key and the items are appended to its value.
    """
    target = _ProtoMapGet(meta_net_def.blobs, blob_name)
    if target is None:
        # First time this key is seen: create the entry, then fill its value.
        new_entry = meta_net_def.blobs.add()
        new_entry.key = blob_name
        target = new_entry.value

    for item in blob_def:
        target.append(item)
120 
121 
def AddPlan(meta_net_def, plan_name, plan_def):
    """Add an entry to ``meta_net_def.plans`` keyed by ``plan_name``.

    The entry's value is ``plan_def``.  No check is made for an existing
    entry with the same key.
    """
    plan_entries = meta_net_def.plans
    plan_entries.add(key=plan_name, value=plan_def)
124 
125 
def AddNet(meta_net_def, net_name, net_def):
    """Add an entry to ``meta_net_def.nets`` keyed by ``net_name``.

    The entry's value is ``net_def``.  No check is made for an existing
    entry with the same key.
    """
    net_entries = meta_net_def.nets
    net_entries.add(key=net_name, value=net_def)
128 
129 
def GetArgumentByName(net_def, arg_name):
    """Return the first argument of ``net_def`` whose name is ``arg_name``.

    ``net_def.arg`` is scanned in order; None is returned when no argument
    has a matching ``name``.
    """
    candidates = (arg for arg in net_def.arg if arg.name == arg_name)
    return next(candidates, None)
135 
136 
def AddModelIdArg(meta_net_def, net_def):
    """Copy the model_id of ``meta_net_def.predict_net`` onto ``net_def``.

    Intended for init nets, whose model_id is not populated by default but
    should match that of the predict net.  If the predict net has no
    "model_id" argument, ``net_def`` is left untouched.  Otherwise an
    existing "model_id" argument on ``net_def`` is overwritten, or a new one
    is added when none exists.
    """
    source_arg = GetArgumentByName(meta_net_def.predict_net, "model_id")
    if source_arg is None:
        # Nothing to propagate.
        return

    # The model_id is assumed to be stored as an integer argument.
    new_model_id = source_arg.i

    target_arg = GetArgumentByName(net_def, "model_id")
    if target_arg is None:
        # No model_id on the net yet: create the argument before setting it.
        target_arg = net_def.arg.add()
        target_arg.name = "model_id"
    target_arg.i = new_model_id