3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
def create_predict_net(predictor_export_meta):
    """
    Return the input prediction net.

    A fresh ``core.Net`` is constructed (instead of reusing
    ``predictor_export_meta.predict_net`` directly) and populated with the
    ops and args of the exported predict net, its external inputs
    (inputs followed by parameters) and external outputs, plus the optional
    net type and worker count when they are set.

    Args:
        predictor_export_meta: export metadata exposing ``predict_net``,
            ``inputs``, ``parameters``, ``outputs``, ``net_type`` and
            ``num_workers``.

    Returns:
        The proto (``net.Proto()``) of the newly built net.
    """
    meta = predictor_export_meta
    net = core.Net(meta.predict_net.name or "predict")
    net.Proto().op.extend(meta.predict_net.op)
    net.Proto().external_input.extend(meta.inputs + meta.parameters)
    net.Proto().external_output.extend(meta.outputs)
    net.Proto().arg.extend(meta.predict_net.arg)
    # Optional settings are copied only when explicitly provided.
    if meta.net_type is not None:
        net.Proto().type = meta.net_type
    if meta.num_workers is not None:
        net.Proto().num_workers = meta.num_workers
    # Previously nothing was returned, so callers always received None.
    return net.Proto()
def create_predict_init_net(ws, predictor_export_meta):
    """
    Return an initialization net that zero-fills all the input and
    output blobs, using the shapes from the provided workspace. This is
    necessary as there is no shape inference functionality in Caffe2.

    Args:
        ws: workspace consulted for the shape of any blob whose shape is not
            recorded in ``predictor_export_meta.shapes``.
        predictor_export_meta: export metadata providing ``inputs``,
            ``outputs``, ``shapes`` and ``extra_init_net``; its predict net's
            ``model_id`` argument is copied onto the init net.

    Returns:
        The proto (``net.Proto()``) of the init net.

    Raises:
        Exception: if a blob has no recorded shape and is not present in
            ``ws``.
    """
    net = core.Net("predict-init")

    def zero_fill(blob):
        # Prefer the shape recorded in the export metadata; only fall back
        # to the workspace when no shape was recorded for this blob.
        shape = predictor_export_meta.shapes.get(blob)
        if shape is None:
            if blob not in ws.blobs:
                raise Exception(
                    "{} not in workspace but needed for shape: {}".format(
                        blob, ws.blobs))
            shape = ws.blobs[blob].fetch().shape

        # NOTE(review): EmptyDeviceScope presumably keeps these fill ops from
        # inheriting whatever device scope the caller has active - confirm.
        with scope.EmptyDeviceScope():
            net.ConstantFill([], blob, shape=shape, value=0.0)

    external_blobs = predictor_export_meta.inputs + \
        predictor_export_meta.outputs
    for blob in external_blobs:
        zero_fill(blob)

    net.Proto().external_input.extend(external_blobs)
    if predictor_export_meta.extra_init_net:
        net.AppendNet(predictor_export_meta.extra_init_net)

    # Propagate the predict net's model_id (if any) onto the init net.
    AddModelIdArg(predictor_export_meta, net.Proto())

    return net.Proto()
def get_comp_name(string, name):
    """Return ``string`` suffixed with ``'_' + name``.

    When ``name`` is empty or None, ``string`` is returned unchanged. This
    avoids a dangling trailing underscore and a TypeError on None.
    """
    if name:
        return string + '_' + name
    return string
def _ProtoMapGet(field, key):
    """
    Given the key, get the value of the repeated field.

    Helper used because these protobuf fields emulate a map as a repeated
    list of ``key``/``value`` entries instead of a native map type.

    Args:
        field: iterable of entries, each exposing ``.key`` and ``.value``.
        key: key to look up.

    Returns:
        The ``value`` of the first entry whose ``key`` equals ``key``, or
        None when no entry matches.
    """
    for entry in field:
        if entry.key == key:
            return entry.value
    return None


def GetPlan(meta_net_def, key):
    """Return the plan stored under ``key`` in ``meta_net_def.plans``, or None."""
    return _ProtoMapGet(meta_net_def.plans, key)
def GetPlanOriginal(meta_net_def, key):
    """Look up the plan keyed ``key`` among ``meta_net_def.plans``.

    Thin wrapper over ``_ProtoMapGet``; its result is returned unchanged.
    """
    plans = meta_net_def.plans
    return _ProtoMapGet(plans, key)
def GetBlobs(meta_net_def, key):
    """Return the blobs stored under ``key`` in ``meta_net_def.blobs``.

    When the lookup yields None (presumably meaning no entry exists for
    ``key``), an empty list is returned so callers can iterate the result
    unconditionally.
    """
    blobs = _ProtoMapGet(meta_net_def.blobs, key)
    # Previously the lookup result was computed but never returned.
    if blobs is None:
        return []
    return blobs
def GetNet(meta_net_def, key):
    """Look up the net keyed ``key`` among ``meta_net_def.nets``.

    Thin wrapper over ``_ProtoMapGet``; its result is returned unchanged.
    """
    nets = meta_net_def.nets
    return _ProtoMapGet(nets, key)
def GetNetOriginal(meta_net_def, key):
    """Look up the net keyed ``key`` among ``meta_net_def.nets``.

    Thin wrapper over ``_ProtoMapGet``; its result is returned unchanged.
    """
    nets = meta_net_def.nets
    return _ProtoMapGet(nets, key)
def GetApplicationSpecificInfo(meta_net_def, key):
    """Look up the entry keyed ``key`` in ``applicationSpecificInfo``.

    Thin wrapper over ``_ProtoMapGet``; its result is returned unchanged.
    """
    info = meta_net_def.applicationSpecificInfo
    return _ProtoMapGet(info, key)
def AddBlobs(meta_net_def, blob_name, blob_def):
    """Append every item of ``blob_def`` to the blob list keyed ``blob_name``.

    If ``meta_net_def.blobs`` has no entry for ``blob_name`` yet (the lookup
    yields None), a new entry is added with that key and the items are
    appended to its ``value`` list.
    """
    blobs = _ProtoMapGet(meta_net_def.blobs, blob_name)
    if blobs is None:
        entry = meta_net_def.blobs.add()
        entry.key = blob_name
        # Append to the entry's value list, not to the entry message itself.
        blobs = entry.value
    blobs.extend(blob_def)
def AddPlan(meta_net_def, plan_name, plan_def):
    """Add a new ``plans`` entry mapping ``plan_name`` to ``plan_def``.

    No check is made for an existing entry with the same name.
    """
    plans = meta_net_def.plans
    plans.add(value=plan_def, key=plan_name)
def AddNet(meta_net_def, net_name, net_def):
    """Add a new ``nets`` entry mapping ``net_name`` to ``net_def``.

    No check is made for an existing entry with the same name.
    """
    nets = meta_net_def.nets
    nets.add(value=net_def, key=net_name)
def GetArgumentByName(net_def, arg_name):
    """Return the first argument of ``net_def`` named ``arg_name``, or None.

    Args:
        net_def: net definition exposing a repeated ``arg`` field whose
            entries have a ``name`` attribute.
        arg_name: name of the argument to find.

    Returns:
        The matching argument entry itself (so callers can mutate it in
        place), or None when no argument has that name.
    """
    for arg in net_def.arg:
        if arg.name == arg_name:
            return arg
    return None
137 def AddModelIdArg(meta_net_def, net_def):
138 """Takes the model_id from the predict_net of meta_net_def (if it is 139 populated) and adds it to the net_def passed in. This is intended to be 140 called on init_nets, as their model_id is not populated by default, but 141 should be the same as that of the predict_net 144 model_id = GetArgumentByName(meta_net_def.predict_net,
"model_id")
147 model_id = model_id.i
150 old_id = GetArgumentByName(net_def,
"model_id")
151 if old_id
is not None:
156 arg = net_def.arg.add()
157 arg.name =
"model_id"