3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
8 from caffe2.proto
import caffe2_pb2
9 from caffe2.proto
import metanet_pb2
14 from builtins
import bytes
# Wrap a freshly created core.Net named `submodelNetName` in a stub
# PredictorExportMeta. Per the docstring, the stub is only used to derive
# component names such as predict_net_name() for that model name.
# NOTE(review): the remaining PredictorExportMeta arguments and the return
# statement are not visible in this excerpt — confirm against the full file.
18 def get_predictor_exporter_helper(submodelNetName):
19 """ constructing stub for the PredictorExportMeta 20 Only used to construct names to subfields, 21 such as calling to predict_net_name 23 submodelNetName - name of the model 25 stub_net = core.Net(submodelNetName)
26 pred_meta = PredictorExportMeta(predict_net=stub_net,
# Immutable record (a collections.namedtuple subclass) describing a predictor
# to export. Its fields are predict_net, parameters, inputs, outputs, shapes,
# name, extra_init_net, net_type and num_workers.
36 class PredictorExportMeta(collections.namedtuple(
37 'PredictorExportMeta',
38 'predict_net, parameters, inputs, outputs, shapes, name, \ 39 extra_init_net, net_type, num_workers')):
41 Metadata to be used for serializing a net. 43 parameters, inputs, outputs could be either BlobReference or blob's names 45 predict_net can be either core.Net, NetDef, PlanDef or object 47 Override the named tuple to provide optional name parameter. 48 name will be used to identify multiple prediction nets. 50 net_type is the type field in caffe2 NetDef - can be 'simple', 'dag', etc. 52 num_workers specifies for net type 'dag' how many threads should run ops 66 inputs = [str(i)
for i
in inputs]
67 outputs = [str(o)
for o
in outputs]
# The statements from here to the `return` belong to the overridden __new__
# (it calls super().__new__ with `cls`). Its signature and defaults are not
# visible in this excerpt. Inputs and outputs have been normalized to str
# names above, and input names must be unique.
68 assert len(set(inputs)) == len(inputs), (
69 "All inputs to the predictor should be unique")
70 parameters = [str(p)
for p
in parameters]
# Parameters must not overlap with inputs or outputs. All of these checks are
# `assert`s, so they are skipped entirely under `python -O`.
71 assert set(parameters).isdisjoint(inputs), (
72 "Parameters and inputs are required to be disjoint. " 73 "Intersection: {}".format(set(parameters).intersection(inputs)))
74 assert set(parameters).isdisjoint(outputs), (
75 "Parameters and outputs are required to be disjoint. " 76 "Intersection: {}".format(set(parameters).intersection(outputs)))
# NOTE(review): the condition guarding this conversion is not in this excerpt.
# Presumably it applies when predict_net is a core.Net (see docstring) — confirm.
80 predict_net = predict_net.Proto()
# After normalization, only NetDef or PlanDef protos are accepted.
82 assert isinstance(predict_net, (caffe2_pb2.NetDef, caffe2_pb2.PlanDef))
# Build the underlying namedtuple from the normalized values.
83 return super(PredictorExportMeta, cls).__new__(
84 cls, predict_net, parameters, inputs, outputs, shapes, name,
85 extra_init_net, net_type, num_workers)
# Each *_name() method below returns utils.get_comp_name(...) for one
# predictor_constants component type. The remaining arguments of each call
# are not visible in this excerpt.
87 def inputs_name(self):
88 return utils.get_comp_name(predictor_constants.INPUTS_BLOB_TYPE,
91 def outputs_name(self):
92 return utils.get_comp_name(predictor_constants.OUTPUTS_BLOB_TYPE,
95 def parameters_name(self):
96 return utils.get_comp_name(predictor_constants.PARAMETERS_BLOB_TYPE,
99 def global_init_name(self):
100 return utils.get_comp_name(predictor_constants.GLOBAL_INIT_NET_TYPE,
103 def predict_init_name(self):
104 return utils.get_comp_name(predictor_constants.PREDICT_INIT_NET_TYPE,
107 def predict_net_name(self):
108 return utils.get_comp_name(predictor_constants.PREDICT_NET_TYPE,
111 def train_init_plan_name(self):
112 return utils.get_comp_name(predictor_constants.TRAIN_INIT_PLAN_TYPE,
115 def train_plan_name(self):
116 return utils.get_comp_name(predictor_constants.TRAIN_PLAN_TYPE,
# Load the MetaNetDef stored in `filename` (a db of type `db_type`) via
# load_from_db, forwarding `device_option`. Then run its global-init and
# predict-init nets once each, and create its predict net via `workspace`.
# NOTE(review): the return statement is not visible in this excerpt.
120 def prepare_prediction_net(filename, db_type, device_option=None):
122 Helper function which loads all required blobs from the db 123 and returns prediction net ready to be used 125 metanet_def = load_from_db(filename, db_type, device_option)
# Order matters here: the global-init net runs before the predict-init net.
127 global_init_net = utils.GetNet(
128 metanet_def, predictor_constants.GLOBAL_INIT_NET_TYPE)
129 workspace.RunNetOnce(global_init_net)
131 predict_init_net = utils.GetNet(
132 metanet_def, predictor_constants.PREDICT_INIT_NET_TYPE)
133 workspace.RunNetOnce(predict_init_net)
# NOTE(review): the start of the expression that turns this NetDef into
# `predict_net` is missing from this excerpt.
136 utils.GetNet(metanet_def, predictor_constants.PREDICT_NET_TYPE))
# The predict net is only created here, not run.
137 workspace.CreateNet(predict_net)
# Build the global-init net for `predictor_export_meta`. It declares the
# PREDICTOR_DBREADER blob as external input and every parameter name as an
# external output, and it tags the proto via utils.AddModelIdArg.
# NOTE(review): the net creation and the op that takes
# [PREDICTOR_DBREADER] -> parameters are missing from this excerpt. That op is
# presumably a load from the db reader — confirm. The return value is also
# not visible.
142 def _global_init_net(predictor_export_meta):
145 [predictor_constants.PREDICTOR_DBREADER],
146 predictor_export_meta.parameters)
147 net.Proto().external_input.extend([predictor_constants.PREDICTOR_DBREADER])
148 net.Proto().external_output.extend(predictor_export_meta.parameters)
151 utils.AddModelIdArg(predictor_export_meta, net.Proto())
# Assemble a metanet_pb2.MetaNetDef for `predictor_export_meta`.
# Three nets are added under the names from predict_init_name(),
# global_init_name() and predict_net_name(). Then the parameter, input and
# output name lists are added under parameters_name(), inputs_name() and
# outputs_name().
# NOTE(review): the return statement is not in this excerpt. It presumably
# returns meta_net_def — confirm.
155 def get_meta_net_def(predictor_export_meta, ws=None):
# Use workspace.C.Workspace.current whenever `ws` is None (or otherwise falsy).
159 ws = ws
or workspace.C.Workspace.current
160 meta_net_def = metanet_pb2.MetaNetDef()
# Only create_predict_init_net receives `ws`. The other two nets are built
# from the export meta alone.
163 utils.AddNet(meta_net_def, predictor_export_meta.predict_init_name(),
164 utils.create_predict_init_net(ws, predictor_export_meta))
165 utils.AddNet(meta_net_def, predictor_export_meta.global_init_name(),
166 _global_init_net(predictor_export_meta))
167 utils.AddNet(meta_net_def, predictor_export_meta.predict_net_name(),
168 utils.create_predict_net(predictor_export_meta))
169 utils.AddBlobs(meta_net_def, predictor_export_meta.parameters_name(),
170 predictor_export_meta.parameters)
171 utils.AddBlobs(meta_net_def, predictor_export_meta.inputs_name(),
172 predictor_export_meta.inputs)
173 utils.AddBlobs(meta_net_def, predictor_export_meta.outputs_name(),
174 predictor_export_meta.outputs)
def set_model_info(meta_net_def, project_str, model_class_str, version):
    """Stamp project, model class and version onto ``meta_net_def.modelInfo``.

    The message is modified in place; nothing is returned.

    Args:
        meta_net_def: the ``metanet_pb2.MetaNetDef`` to annotate.
        project_str: value written to ``modelInfo.project``.
        model_class_str: value written to ``modelInfo.modelClass``.
        version: value written to ``modelInfo.version``.
    """
    assert isinstance(meta_net_def, metanet_pb2.MetaNetDef)
    # Writing through the embedded-message reference updates meta_net_def
    # itself. The original assignment order is kept so that any field type
    # error surfaces at the same point.
    info = meta_net_def.modelInfo
    info.project = project_str
    info.modelClass = model_class_str
    info.version = version
# Export `predictor_export_meta` to `db_destination` (a db of type `db_type`).
# The MetaNetDef from get_meta_net_def (called without `ws`, so its default
# workspace is used) is serialized with serde under a CPU DeviceScope. Then one
# operator is built over META_NET_DEF plus every parameter name, with
# db=db_destination and db_type=db_type, and run once.
# NOTE(review): the call that receives the serialized MetaNetDef and the
# operator type/inputs are missing from this excerpt. They are presumably
# workspace.FeedBlob into META_NET_DEF and a Save op over blobs_to_save —
# confirm. The extent of the `with` block is also not recoverable here.
185 def save_to_db(db_type, db_destination, predictor_export_meta):
186 meta_net_def = get_meta_net_def(predictor_export_meta)
187 with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU)):
189 predictor_constants.META_NET_DEF,
190 serde.serialize_protobuf_struct(meta_net_def)
# The META_NET_DEF blob name comes first, followed by the parameter names.
193 blobs_to_save = [predictor_constants.META_NET_DEF] + \
194 predictor_export_meta.parameters
195 op = core.CreateOperator(
199 db=db_destination, db_type=db_type)
201 workspace.RunOperatorOnce(op)
# Open `filename` as a db of type `db_type` and run a second operator that
# loads META_NET_DEF. Then fetch that blob and deserialize it into a
# metanet_pb2.MetaNetDef. If `device_option` is None, the current device scope
# is used instead; when a device option is available, it is copied onto ops of
# the nets in meta_net_def.nets.
# NOTE(review): the operator types/arguments, the inner loop over ops, and the
# return value are not visible in this excerpt.
204 def load_from_db(filename, db_type, device_option=None):
207 create_db = core.CreateOperator(
210 db=filename, db_type=db_type)
# NOTE(review): the operator is run *inside* `assert`. Under `python -O` the
# whole statement is stripped, so the db is never created (and, below, the
# MetaNetDef is never loaded). Consider running the op unconditionally and
# checking its result explicitly.
211 assert workspace.RunOperatorOnce(create_db), (
212 'Failed to create db {}'.format(filename))
215 load_meta_net_def = core.CreateOperator(
219 assert workspace.RunOperatorOnce(load_meta_net_def)
221 blob = workspace.FetchBlob(predictor_constants.META_NET_DEF)
# `bytes` here is the name imported from `builtins` at the top of the file.
# Values that are not bytes are str()-ed and UTF-8 encoded before
# deserialization.
222 meta_net_def = serde.deserialize_protobuf_struct(
223 blob
if isinstance(blob, bytes)
224 else str(blob).encode(
'utf-8'),
225 metanet_pb2.MetaNetDef)
# Fall back to the ambient device scope when no device option was passed.
227 if device_option
is None:
228 device_option = scope.CurrentDeviceScope()
# Only rewrite device placement if some device option is in effect.
230 if device_option
is not None:
232 for kv
in meta_net_def.nets:
235 op.device_option.CopyFrom(device_option)