"""Caffe2 Protobuf to ONNX converter

To run this, you will need to have Caffe2 installed as well.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import collections.abc
import itertools
import logging
import re

import numpy as np

from caffe2.proto import caffe2_legacy_pb2
from onnx import (defs, checker, helper, numpy_helper, mapping,
                  ModelProto, GraphProto, NodeProto, AttributeProto,
                  TensorProto, OperatorSetIdProto)
from onnx.helper import make_tensor, make_tensor_value_info, make_attribute, make_model
# NOTE(review): logging.basicConfig at import time configures the root logger
# as a module side effect -- unusual for a library; kept as-is.
32 logging.basicConfig(level=logging.INFO)
33 logger = logging.getLogger(__name__)
# Default ONNX opset version targeted by this converter.
41 target_opset_version = 6
# Caffe2 op type -> ONNX op type rename table.
# NOTE(review): extraction garbled this file (original line numbers are fused
# into the text and several lines are missing, e.g. original 45-47); entries
# such as Conv1D/2D/3D renames may have been dropped -- confirm against the
# full file.
43 _renamed_operators = {
44 'SpatialBN':
'BatchNormalization',
48 'ConvTranspose1D':
'ConvTranspose',
49 'ConvTranspose2D':
'ConvTranspose',
50 'ConvTranspose3D':
'ConvTranspose',
51 'MaxPool1D':
'MaxPool',
52 'MaxPool2D':
'MaxPool',
53 'MaxPool3D':
'MaxPool',
54 'AveragePool1D':
'AveragePool',
55 'AveragePool2D':
'AveragePool',
56 'AveragePool3D':
'AveragePool',
# Caffe2-only arguments to drop during conversion, keyed by the values
# for which dropping is safe. NOTE(review): original line 61 is missing;
# this table may have more entries.
60 _blacklist_caffe2_args = {
62 'cudnn_exhaustive_search': {0, 1},
# Argument renames applied to every operator.
66 _global_renamed_args = {
67 'kernels':
'kernel_shape',
# Per-operator argument renames (take precedence over the global table).
70 _per_op_renamed_args = {
71 'Squeeze': {
'dims':
'axes'},
72 'Transpose': {
'axes':
'perm'},
# Operators with dedicated translator methods; empty in the visible source.
75 _special_operators = {}
# Translate one Caffe2 Argument proto into an ONNX AttributeProto.
# NOTE(review): interior lines are missing from this extraction (original
# 79-84, 86-89, 91, 93-100, 102-106). The visible fragments establish: the
# argument name is renamed via _global_renamed_args, the payload is probed
# with HasField('i') / HasField('s') among others, a ValueError is raised
# when no data field is present, and the attribute is built with
# helper.make_attribute(name, value).
78 def _common_caffe2_arg_to_onnx_attr(cls, op_def, arg):
85 name = cls._global_renamed_args.get(arg.name, arg.name)
# Integer-valued argument.
90 elif arg.HasField(
'i'):
# String/bytes-valued argument.
92 elif arg.HasField(
's'):
# No recognized data field on the Argument proto.
101 raise ValueError(
'Could not find data field in arg: {}'.format(arg))
107 return helper.make_attribute(name, value)
# Dispatch wrapper: convert one Caffe2 Argument to an ONNX attribute.
# Body missing from this extraction (original lines 111-113); presumably it
# applies blacklist / per-op renames before delegating to
# _common_caffe2_arg_to_onnx_attr -- TODO confirm against the full file.
110 def caffe2_arg_to_onnx_attr(cls, op_def, arg):
# Generic OperatorDef -> NodeProto translation: copy the op name, rename the
# op type via _renamed_operators, copy inputs/outputs, and convert each
# Caffe2 Argument into an ONNX attribute.
# NOTE(review): the head of the attrs list-comprehension (original line 124)
# and the return statement are missing from this extraction.
114 def _common_caffe2_op_to_onnx_node(cls, op_def, shapes):
115 node_def = NodeProto()
116 node_def.name = op_def.name
118 node_def.op_type = cls._renamed_operators.get(op_def.type, op_def.type)
120 node_def.input.extend(op_def.input)
121 node_def.output.extend(op_def.output)
124 for arg
in op_def.arg])
125 node_def.attribute.extend(attrs)
@classmethod
def caffe2_op_to_onnx_node(cls, op_def, shapes):
    """Convert a single Caffe2 OperatorDef into ONNX node(s).

    Prefers the C++ exporter when the op type supports direct ONNX export;
    otherwise dispatches to a special-case translator registered in
    ``cls._special_operators``, falling back to the generic
    ``_common_caffe2_op_to_onnx_node``.

    Args:
        op_def: caffe2_pb2.OperatorDef to translate.
        shapes: dict of blob name -> shape, consumed by the translators.

    Returns:
        Tuple ``(nodes, const_tensors)``: a list of NodeProto and a list of
        constant TensorProto produced alongside them.
    """
    # Local import: the original used `collections.Iterable`, an alias that
    # was removed in Python 3.10; collections.abc.Iterable is correct.
    from collections.abc import Iterable

    if C.support_onnx_export(op_def.type):
        # The C++ exporter returns serialized protos; deserialize them here.
        shape_list = list(shapes.values())
        node_strs, tensor_strs = C.export_to_onnx(op_def.SerializeToString(),
                                                  shapes)
        nodes = []
        for node_str in node_strs:
            node = NodeProto()
            node.ParseFromString(node_str)
            nodes.append(node)
        const_tensors = []
        for tensor_str in tensor_strs:
            tensor = TensorProto()
            tensor.ParseFromString(tensor_str)
            const_tensors.append(tensor)
        return nodes, const_tensors

    # NOTE(review): this dispatch reconstructs lines lost in extraction
    # (original lines 145-150) -- confirm against the full file.
    if op_def.type in cls._special_operators:
        translator = getattr(cls, cls._special_operators[op_def.type])
    else:
        translator = cls._common_caffe2_op_to_onnx_node
    nodes = translator(op_def, shapes)
    const_tensors = []
    if isinstance(nodes, tuple):
        # Translator returned (nodes, const_tensors).
        nodes, const_tensors = nodes
    # BUG FIX: was `collections.Iterable`, which no longer exists on
    # Python >= 3.10; a single bare node is wrapped into a list.
    if not isinstance(nodes, Iterable):
        nodes = [nodes]
    return nodes, const_tensors
# Collect every blob name referenced by a NetDef: external inputs/outputs
# plus all op inputs/outputs.
# NOTE(review): the set construction, the `for op in net.op` loop header,
# and the return statement are missing from this extraction.
158 def _all_names_in_net(net):
163 names.update(net.external_input)
164 names.update(net.external_output)
166 names.update(op.input)
167 names.update(op.output)
# Build an ONNX ValueInfoProto from a TensorProto (name / elem_type / shape).
# NOTE(review): the name= and shape= keyword lines are missing from this
# extraction; only elem_type= is visible.
171 def _extract_value_info(tensor):
172 return make_tensor_value_info(
174 elem_type=tensor.data_type,
# Build an ONNX GraphProto from a Caffe2 predict_net (+ optional init_net).
# NOTE(review): many interior lines were lost in extraction (remaining
# signature parameters, init_net handling, the node translation loop, and
# the return statement); comments below describe only what the visible
# fragments establish.
178 def caffe2_net_to_onnx_graph(cls,
# Default value_info when the caller passes none; must be a dict of
# name -> (type, shape).
182 if value_info
is None:
184 if not isinstance(value_info, dict):
185 raise ValueError(
'Please pass value_info as a ' 186 'name -> (type, shape) dictionary')
# Record type/shape info for every initializer blob.
193 value_info.update({init.name: (init.data_type, init.dims)
194 for init
in initializer})
# Every external input must have value_info, or the graph cannot be built.
199 missing = (set(list(predict_net.external_input)) -
200 set(value_info.keys()))
202 raise RuntimeError(
'Could not find value info of inputs: {}'.format(
# Create random sample inputs so the net can be run for shape inference.
206 for name
in predict_net.external_input:
207 elem_type, shape = value_info[name]
208 inputs[name] = np.random.randn(*shape).astype(
209 mapping.TENSOR_TYPE_TO_NP_TYPE[elem_type])
# Execute the Caffe2 net natively to learn output dtypes/shapes.
211 ws, outputs = c2_native_run_net(
216 for name
in predict_net.external_output:
217 output = outputs[name]
218 elem_type = mapping.NP_TYPE_TO_TENSOR_TYPE[output.dtype]
220 value_info[name] = (elem_type, shape)
# Assemble the GraphProto: name, initializers, then typed graph inputs.
222 graph_def = GraphProto()
223 graph_def.name = predict_net.name
224 graph_def.initializer.extend(initializer)
226 graph_def.input.extend(
227 make_tensor_value_info(
229 elem_type=value_info[name][0],
230 shape=value_info[name][1])
231 for name
in predict_net.external_input)
# Translate each operator, harvesting blob shapes from the workspace.
236 for op
in predict_net.op:
238 for name
in itertools.chain(op.input, op.output):
239 blob = ws.FetchBlob(name)
240 if hasattr(blob,
'shape'):
241 shapes[name] = blob.shape
243 graph_def.node.extend(nodes)
244 graph_def.initializer.extend(const_tensors)
# Drop declared outputs that no node or initializer actually produces.
247 all_output = set(sum((list(node.output)
for node
in graph_def.node),
248 [init.name
for init
in graph_def.initializer]))
249 redundant_output = set(vi.name
for vi
in graph_def.output) - all_output
252 'There are graph output not produced by any node or initializer: {}' 253 '! Will drop them.'.format(
', '.join(redundant_output)))
254 graph_def.output.extend(
255 make_tensor_value_info(
257 elem_type=value_info[name][0],
258 shape=value_info[name][1])
259 for name
in predict_net.external_output
260 if name
in all_output)
# Validate the assembled graph before returning it.
262 checker.check_graph(graph_def)
# Convert a Caffe2 init_net (GivenTensor*Fill ops) into a list of ONNX
# initializer TensorProtos.
# NOTE(review): interior lines are missing from this extraction (result-list
# creation, the dict lookup's KeyError/error branch, the raw-bytes packing
# head, remaining make_tensor kwargs, and the return statement).
266 def caffe2_init_net_to_initializer(cls, init_net):
268 for op
in init_net.op:
# Map each GivenTensor*Fill op type to its ONNX dtype and the Argument
# field that carries its payload.
271 data_type, field_name = {
272 'GivenTensorFill': (TensorProto.FLOAT,
'floats'),
273 'GivenTensorInt64Fill': (TensorProto.INT64,
'ints'),
274 'GivenTensorIntFill': (TensorProto.INT32,
'ints'),
275 'GivenTensorBoolFill': (TensorProto.BOOL,
'ints'),
276 'GivenTensorStringFill': (TensorProto.STRING,
'strings'),
280 "Can not translate init_net with operator '{}' " 281 "to initializer".format(op.type)
# Strings are stored as-is; every other dtype is packed as raw bytes.
283 raw = (data_type != TensorProto.STRING)
284 args = {a.name: a
for a
in op.arg}
285 vals = getattr(args[
'values'], field_name)
289 dtype=mapping.TENSOR_TYPE_TO_NP_TYPE[data_type]).tobytes()
290 initializer.append(make_tensor(
293 dims=args[
'shape'].ints,
# Remove "fake" initializer ops (GivenTensor*Fill / ConstantFill whose sole
# output already has caller-provided value_info) from init_net, so real
# graph inputs are not shadowed by dummy fill operators. Mutates init_net
# in place.
# NOTE(review): original line 301 (likely a guard) is missing from this
# extraction.
300 def _filter_fake_init(cls, init_net, value_info):
302 fake_inits = [op
for op
in init_net.op
303 if len(op.output) == 1
and op.output[0]
in value_info
and 304 re.match(
'GivenTensor.*Fill|ConstantFill', op.type)]
305 for fake_init
in fake_inits:
306 init_net.op.remove(fake_init)
# Rewrite both nets into SSA form by suffixing every blob name with its
# version ("name_version"), so repeated writes to the same blob become
# distinct ONNX values. Mutates net, init_net, and value_info in place.
# NOTE(review): a few lines are missing from this extraction (original
# 314-315, 324, 327 -- likely blank lines / guards / value_info.clear()).
311 def _ssa_rewrite(cls, net, init_net, value_info):
312 def ssa_name(name, version):
313 return '{}_{}'.format(name, version)
# init_net ops must all be GivenTensor*Fill producers with exactly one
# output, which gets version 0.
# NOTE(review): these asserts vanish under `python -O`; if they validate
# caller input (rather than internal invariants), explicit raises would
# be safer.
316 for op
in init_net.op:
317 assert re.match(
'GivenTensor.*Fill', op.type),
"type is {}, \n{}".format(op.type, op)
318 assert len(op.output) == 1
319 op.output[0] = ssa_name(op.output[0], 0)
320 init_net.external_input[:] = [ssa_name(name, 0)
321 for name
in init_net.external_input]
322 init_net.external_output[:] = [ssa_name(name, 0)
323 for name
in init_net.external_output]
# Re-key value_info with the version-0 SSA names.
325 ssa_value_info = {ssa_name(name, 0): value
326 for name, value
in value_info.items()}
328 value_info.update(ssa_value_info)
329 net.external_input[:] = [ssa_name(name, 0)
330 for name
in net.external_input]
# Let Caffe2 compute per-op blob versions, then apply them to every op.
331 ssa, blob_versions = caffe2_core.get_ssa(net)
332 assert len(net.op) == len(ssa)
333 for op, (versioned_inputs, versioned_outputs)
in zip(net.op, ssa):
334 op.input[:] = [ssa_name(name, version)
335 for name, version
in versioned_inputs]
336 op.output[:] = [ssa_name(name, version)
337 for name, version
in versioned_outputs]
# External outputs take the final version of each blob.
338 net.external_output[:] = [ssa_name(name, blob_versions[name])
339 for name
in net.external_output]
# Build a complete ModelProto: graph via caffe2_net_to_onnx_graph plus an
# opset import, then run the ONNX model checker before returning.
# NOTE(review): the opset_id.domain/version assignments, the make_model
# call's first argument, and the return statement are missing from this
# extraction.
342 def caffe2_net_to_onnx_model(cls, *args, **kwargs):
343 opset_id = OperatorSetIdProto()
347 opset_imports=[opset_id],
348 producer_name=
'onnx-caffe2',
350 checker.check_model(model)
# Module-level convenience aliases so callers can use the converter entry
# points without naming the Caffe2Frontend class explicitly.
caffe2_net_to_onnx_graph = Caffe2Frontend.caffe2_net_to_onnx_graph
caffe2_net_to_onnx_model = Caffe2Frontend.caffe2_net_to_onnx_model
caffe2_init_net_to_initializer = Caffe2Frontend.caffe2_init_net_to_initializer
# --- extraction artifact: index of definitions in this file (not code) ---
# def _common_caffe2_op_to_onnx_node(cls, op_def, shapes)
# dictionary _special_operators
# dictionary _blacklist_caffe2_args
# def caffe2_op_to_onnx_node(cls, op_def, shapes)
# def _filter_fake_init(cls, init_net, value_info)
# def caffe2_net_to_onnx_graph(cls, predict_net, init_net=None, value_info=None)
# def _all_names_in_net(net)
# def caffe2_init_net_to_initializer(cls, init_net)
# dictionary _per_op_renamed_args
# def caffe2_arg_to_onnx_attr(cls, op_def, arg)
# def _ssa_rewrite(cls, net, init_net, value_info)
# def _common_caffe2_arg_to_onnx_attr(cls, op_def, arg)
# def _extract_value_info(tensor)