Caffe2 - Python API
A deep learning, cross platform ML framework
backend_rep.py
1 ## @package onnx
2 # Module caffe2.python.onnx.backend_rep
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from caffe2.python import core, workspace
9 from caffe2.proto import caffe2_pb2
10 from onnx.backend.base import BackendRep, namedtupledict
11 
class Caffe2Rep(BackendRep):
    """ONNX ``BackendRep`` that executes a prepared model with Caffe2.

    Holds an init net, a predict net, and the workspace object they are run
    under. ``run`` feeds the given inputs, creates both nets on first use,
    runs the init net only on the first call, then runs the predict net and
    returns the values of its external outputs.
    """

    def __init__(self, init_net, predict_net, workspace, uninitialized):
        """Store the nets and the workspace they execute in.

        Args:
            init_net: net that is created and run on the first ``run`` call
                only, before the predict net.
            predict_net: net run on every ``run`` call; its
                ``external_output`` lists the blobs that are returned, and
                its ``device_option`` selects the device scope.
            workspace: object entered as a context manager around each run
                (presumably a Caffe2 workspace wrapper that makes itself the
                active workspace on enter -- confirm). NOTE: this parameter
                shadows the ``caffe2.python.workspace`` module, but only
                inside ``__init__``; ``run`` uses the module.
            uninitialized: ordered names of the graph inputs the caller must
                supply; list/tuple and single inputs are matched to these
                names by position.
        """
        super(Caffe2Rep, self).__init__()
        self.init_net = init_net
        self.predict_net = predict_net
        self.workspace = workspace
        # The list of uninitialized external_inputs in workspace, we need this
        # to pair the name with given sequence inputs.
        self.uninitialized = uninitialized
        # Both nets are created lazily, on the first call to run().
        self.nets_created = False
        # The init net is run only on the first call to run().
        self.ran_init_net = False

    @property
    def _name_scope(self):
        """Name scope used when feeding dict inputs.

        ``'gpu_<id>'`` when the predict net's device option is CUDA,
        otherwise the empty string (no scope).
        """
        if self.predict_net.device_option.device_type == caffe2_pb2.CUDA:
            return 'gpu_{}'.format(self.predict_net.device_option.cuda_gpu_id)
        return ''

    def _check_input_count(self, count):
        """Raise RuntimeError unless ``count`` matches ``self.uninitialized``."""
        if count != len(self.uninitialized):
            raise RuntimeError('Expected {} values for uninitialized '
                               'graph inputs ({}), but got {}.'.format(
                                   len(self.uninitialized),
                                   ', '.join(self.uninitialized),
                                   count))

    def run(self, inputs, **kwargs):
        """Run the model on ``inputs`` and return its outputs.

        Args:
            inputs: one of
                * a dict mapping blob name -> value, fed inside
                  ``core.NameScope(self._name_scope)``;
                * a list or tuple with exactly one value per name in
                  ``self.uninitialized``, matched by position;
                * any other object, treated as the value of the single
                  uninitialized graph input.
                Values are passed straight to ``workspace.FeedBlob``
                (presumably numpy arrays -- not checked here).
            **kwargs: forwarded to ``BackendRep.run``.

        Returns:
            A ``namedtupledict('Outputs', ...)`` instance whose fields are the
            predict net's external outputs, in order, holding the fetched
            blob values.

        Raises:
            RuntimeError: if a list/tuple input, or a single non-sequence
                input, does not supply exactly one value per uninitialized
                graph input.
        """
        super(Caffe2Rep, self).run(inputs, **kwargs)
        with self.workspace:
            with core.DeviceScope(self.predict_net.device_option):
                if isinstance(inputs, dict):
                    with core.NameScope(self._name_scope):
                        for key, value in inputs.items():
                            workspace.FeedBlob(key, value)
                elif isinstance(inputs, (list, tuple)):
                    self._check_input_count(len(inputs))
                    # namescope already baked into protobuf
                    for name, value in zip(self.uninitialized, inputs):
                        workspace.FeedBlob(name, value)
                else:
                    # Single input: only valid when the graph has exactly one
                    # uninitialized input. Previously an empty list raised a
                    # bare IndexError and extra inputs were silently left
                    # unfed by this call.
                    self._check_input_count(1)
                    workspace.FeedBlob(self.uninitialized[0], inputs)
                if not self.nets_created:
                    workspace.CreateNet(self.init_net)
                    workspace.CreateNet(self.predict_net)
                    self.nets_created = True
                if not self.ran_init_net:
                    workspace.RunNet(self.init_net.name)
                    self.ran_init_net = True
                workspace.RunNet(self.predict_net.name)
            # Fetched while still inside `with self.workspace` (presumably so
            # FetchBlob reads from this representation's workspace).
            output_values = [workspace.FetchBlob(name)
                             for name in self.predict_net.external_output]
            return namedtupledict('Outputs',
                                  self.predict_net.external_output)(*output_values)