Caffe2 - Python API
A deep learning, cross-platform ML framework
predictor_exporter.py
1 ## @package predictor_exporter
2 # Module caffe2.python.predictor.predictor_exporter
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import collections
from builtins import bytes

from caffe2.proto import caffe2_pb2
from caffe2.proto import metanet_pb2
from caffe2.python import workspace, core, scope
from caffe2.python.predictor_constants import predictor_constants
import caffe2.python.predictor.predictor_py_utils as utils
import caffe2.python.predictor.serde as serde
16 
17 
def get_predictor_exporter_helper(submodelNetName):
    """Build a stub PredictorExportMeta for ``submodelNetName``.

    The stub wraps a freshly constructed ``core.Net`` and has no parameters,
    inputs or outputs. It only lets callers derive the component names of a
    submodel, e.g. via ``predict_net_name()``.

    Args:
        submodelNetName: name of the model. It is used both as the stub
            net's name and as the meta's ``name``.

    Returns:
        PredictorExportMeta: the stub meta.
    """
    return PredictorExportMeta(
        predict_net=core.Net(submodelNetName),
        parameters=[],
        inputs=[],
        outputs=[],
        shapes=None,
        name=submodelNetName,
        extra_init_net=None,
    )
34 
35 
class PredictorExportMeta(collections.namedtuple(
    'PredictorExportMeta',
    'predict_net, parameters, inputs, outputs, shapes, name, '
    'extra_init_net, net_type, num_workers')):
    """
    Metadata to be used for serializing a net.

    Fields:
        predict_net: the net to export. The constructor accepts a core.Net,
            core.Plan, NetDef or PlanDef. core.Net and core.Plan are replaced
            by their proto, so the stored value is always a NetDef or PlanDef.
        parameters, inputs, outputs: BlobReferences or blob names. Each is
            stored as a list of ``str``. Inputs must be unique, and
            parameters must be disjoint from both inputs and outputs.
        shapes: optional mapping (presumably blob name -> shape; not
            validated here). ``None`` is stored as ``{}``.
        name: optional name used to identify multiple prediction nets. It
            is combined with a component type to form each ``*_name()``.
        extra_init_net: optional extra init net, stored as given.
        net_type: the type field in caffe2 NetDef - can be 'simple', 'dag',
            etc.
        num_workers: for net type 'dag', how many threads should run ops.
    """
    def __new__(
        cls,
        predict_net,
        parameters,
        inputs,
        outputs,
        shapes=None,
        name="",
        extra_init_net=None,
        net_type=None,
        num_workers=None,
    ):
        inputs = [str(i) for i in inputs]
        outputs = [str(o) for o in outputs]
        assert len(set(inputs)) == len(inputs), (
            "All inputs to the predictor should be unique")
        parameters = [str(p) for p in parameters]
        # Build the parameter set once; both disjointness checks reuse it.
        parameter_set = set(parameters)
        assert parameter_set.isdisjoint(inputs), (
            "Parameters and inputs are required to be disjoint. "
            "Intersection: {}".format(parameter_set.intersection(inputs)))
        assert parameter_set.isdisjoint(outputs), (
            "Parameters and outputs are required to be disjoint. "
            "Intersection: {}".format(parameter_set.intersection(outputs)))
        shapes = shapes or {}

        # Normalize the Python wrappers to their underlying proto.
        if isinstance(predict_net, (core.Net, core.Plan)):
            predict_net = predict_net.Proto()

        assert isinstance(
            predict_net, (caffe2_pb2.NetDef, caffe2_pb2.PlanDef)), (
            "predict_net must be a core.Net, core.Plan, NetDef or PlanDef; "
            "got {}".format(type(predict_net)))
        return super(PredictorExportMeta, cls).__new__(
            cls, predict_net, parameters, inputs, outputs, shapes, name,
            extra_init_net, net_type, num_workers)

    def _comp_name(self, comp_type):
        # Each *_name() accessor below is utils.get_comp_name() applied to
        # one predictor_constants component type and this meta's `name`.
        return utils.get_comp_name(comp_type, self.name)

    def inputs_name(self):
        """Component name for the list of input blob names."""
        return self._comp_name(predictor_constants.INPUTS_BLOB_TYPE)

    def outputs_name(self):
        """Component name for the list of output blob names."""
        return self._comp_name(predictor_constants.OUTPUTS_BLOB_TYPE)

    def parameters_name(self):
        """Component name for the list of parameter blob names."""
        return self._comp_name(predictor_constants.PARAMETERS_BLOB_TYPE)

    def global_init_name(self):
        """Component name for the global init net."""
        return self._comp_name(predictor_constants.GLOBAL_INIT_NET_TYPE)

    def predict_init_name(self):
        """Component name for the predict init net."""
        return self._comp_name(predictor_constants.PREDICT_INIT_NET_TYPE)

    def predict_net_name(self):
        """Component name for the predict net."""
        return self._comp_name(predictor_constants.PREDICT_NET_TYPE)

    def train_init_plan_name(self):
        """Component name for the train init plan."""
        return self._comp_name(predictor_constants.TRAIN_INIT_PLAN_TYPE)

    def train_plan_name(self):
        """Component name for the train plan."""
        return self._comp_name(predictor_constants.TRAIN_PLAN_TYPE)
118 
119 
def prepare_prediction_net(filename, db_type, device_option=None):
    '''
    Load a predictor db and return its predict net, ready to run.

    The MetaNetDef is read via ``load_from_db``. Its global init net and
    predict init net are run once, in that order. The predict net is then
    wrapped in a ``core.Net`` and instantiated in the current workspace.

    Args:
        filename: db to read from.
        db_type: db type string.
        device_option: forwarded to ``load_from_db``.

    Returns:
        core.Net: the created predict net.
    '''
    meta_net_def = load_from_db(filename, db_type, device_option)

    init_net_types = (
        predictor_constants.GLOBAL_INIT_NET_TYPE,
        predictor_constants.PREDICT_INIT_NET_TYPE,
    )
    for init_net_type in init_net_types:
        workspace.RunNetOnce(utils.GetNet(meta_net_def, init_net_type))

    predict_net_def = utils.GetNet(
        meta_net_def, predictor_constants.PREDICT_NET_TYPE)
    prediction_net = core.Net(predict_net_def)
    workspace.CreateNet(prediction_net)
    return prediction_net
140 
141 
def _global_init_net(predictor_export_meta):
    """Build the global init NetDef for ``predictor_export_meta``.

    The net Loads every parameter blob from the
    ``predictor_constants.PREDICTOR_DBREADER`` blob. It declares that reader
    as its external input and the parameters as its external outputs, and
    is then passed to ``utils.AddModelIdArg``.

    Returns:
        The proto of the constructed net.
    """
    db_reader = predictor_constants.PREDICTOR_DBREADER
    params = predictor_export_meta.parameters

    init_net = core.Net("global-init")
    init_net.Load([db_reader], params)

    init_net.Proto().external_input.extend([db_reader])
    init_net.Proto().external_output.extend(params)

    # Add the model_id in the predict_net to the global_init_net
    utils.AddModelIdArg(predictor_export_meta, init_net.Proto())
    return init_net.Proto()
153 
154 
def get_meta_net_def(predictor_export_meta, ws=None):
    """Assemble a MetaNetDef describing ``predictor_export_meta``.

    The result holds three nets: the predict init net (built from ``ws``),
    the global init net and the predict net. It also holds the parameter,
    input and output blob-name lists. Each item is stored under the
    component name that the meta derives for it.

    Args:
        predictor_export_meta: PredictorExportMeta to export.
        ws: workspace handed to ``utils.create_predict_init_net``. A falsy
            value means ``workspace.C.Workspace.current``.

    Returns:
        metanet_pb2.MetaNetDef: the assembled definition.
    """
    ws = ws or workspace.C.Workspace.current
    meta = predictor_export_meta
    meta_net_def = metanet_pb2.MetaNetDef()

    utils.AddNet(meta_net_def, meta.predict_init_name(),
                 utils.create_predict_init_net(ws, meta))
    utils.AddNet(meta_net_def, meta.global_init_name(),
                 _global_init_net(meta))
    # The predict net is the core network used for prediction.
    utils.AddNet(meta_net_def, meta.predict_net_name(),
                 utils.create_predict_net(meta))

    blob_groups = (
        (meta.parameters_name(), meta.parameters),
        (meta.inputs_name(), meta.inputs),
        (meta.outputs_name(), meta.outputs),
    )
    for comp_name, blob_names in blob_groups:
        utils.AddBlobs(meta_net_def, comp_name, blob_names)
    return meta_net_def
176 
177 
def set_model_info(meta_net_def, project_str, model_class_str, version):
    """Fill the ``modelInfo`` section of a MetaNetDef in place.

    Args:
        meta_net_def: metanet_pb2.MetaNetDef to update.
        project_str: value for ``modelInfo.project``.
        model_class_str: value for ``modelInfo.modelClass``.
        version: value for ``modelInfo.version``.
    """
    assert isinstance(meta_net_def, metanet_pb2.MetaNetDef)
    info = meta_net_def.modelInfo
    info.project = project_str
    info.modelClass = model_class_str
    info.version = version
183 
184 
def save_to_db(db_type, db_destination, predictor_export_meta):
    """Serialize ``predictor_export_meta`` and its parameters into a db.

    The steps are:

    1. Serialize the MetaNetDef built by ``get_meta_net_def``.
    2. Feed it into the ``predictor_constants.META_NET_DEF`` blob under a
       CPU device scope.
    3. Run a single Save operator that writes that blob plus every parameter
       blob to ``db_destination`` (with ``absolute_path=True``) using
       ``db_type``.
    """
    meta_net_def = get_meta_net_def(predictor_export_meta)
    cpu_option = core.DeviceOption(caffe2_pb2.CPU)
    with core.DeviceScope(cpu_option):
        payload = serde.serialize_protobuf_struct(meta_net_def)
        workspace.FeedBlob(predictor_constants.META_NET_DEF, payload)

    # META_NET_DEF is listed ahead of the parameters; load_from_db assumes
    # it is stored before them.
    save_op = core.CreateOperator(
        "Save",
        [predictor_constants.META_NET_DEF] + predictor_export_meta.parameters,
        [],
        absolute_path=True,
        db=db_destination,
        db_type=db_type,
    )
    workspace.RunOperatorOnce(save_op)
203 
def load_from_db(filename, db_type, device_option=None):
    """Load the MetaNetDef stored in a predictor db.

    Creates a db reader blob named ``predictor_constants.PREDICTOR_DBREADER``;
    the global init net in the returned MetaNetDef loads parameters from it.
    Then Loads ``predictor_constants.META_NET_DEF`` through that reader and
    deserializes it.

    Args:
        filename: db to open (the ``db`` argument of CreateDB).
        db_type: db type string passed to CreateDB.
        device_option: DeviceOption copied onto every op of every net in the
            MetaNetDef. Defaults to the current device scope. When that is
            also None, the ops' device options are left untouched.

    Returns:
        metanet_pb2.MetaNetDef: the deserialized meta net def.

    Raises:
        AssertionError: (when asserts are enabled) if running CreateDB or
            Load returns a falsy result.
    """
    # global_init_net in meta_net_def will load parameters from
    # predictor_constants.PREDICTOR_DBREADER
    create_db = core.CreateOperator(
        'CreateDB', [],
        [core.BlobReference(predictor_constants.PREDICTOR_DBREADER)],
        db=filename, db_type=db_type)
    # Run the operator outside the assert. Under `python -O` assert
    # statements are stripped entirely, which would skip the side effect.
    created = workspace.RunOperatorOnce(create_db)
    assert created, 'Failed to create db {}'.format(filename)

    # predictor_constants.META_NET_DEF is always stored before the parameters
    load_meta_net_def = core.CreateOperator(
        'Load',
        [core.BlobReference(predictor_constants.PREDICTOR_DBREADER)],
        [core.BlobReference(predictor_constants.META_NET_DEF)])
    loaded = workspace.RunOperatorOnce(load_meta_net_def)
    assert loaded, 'Failed to load {} from db {}'.format(
        predictor_constants.META_NET_DEF, filename)

    blob = workspace.FetchBlob(predictor_constants.META_NET_DEF)
    meta_net_def = serde.deserialize_protobuf_struct(
        blob if isinstance(blob, bytes)
        else str(blob).encode('utf-8'),
        metanet_pb2.MetaNetDef)

    if device_option is None:
        device_option = scope.CurrentDeviceScope()

    if device_option is not None:
        # Set the device options of all loaded blobs
        for kv in meta_net_def.nets:
            net = kv.value
            for op in net.op:
                op.device_option.CopyFrom(device_option)

    return meta_net_def