3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
9 from caffe2.proto
import caffe2_pb2
10 from onnx.backend.base
import BackendRep, namedtupledict
# Caffe2Rep.__init__: construct the backend representation and initialise
# the BackendRep base class via super().
#
# Parameters (how they are stored is on lines missing from this fragment):
#   init_net      -- a net; run() later reads self.init_net.name, so it is
#                    presumably stored as self.init_net (TODO confirm).
#   predict_net   -- a net; run() and _name_scope read self.predict_net
#                    (device_option, name, external_output), so it is
#                    presumably stored as self.predict_net (TODO confirm).
#   workspace     -- NOTE(review): run() calls a bare `workspace.*`, not
#                    `self.workspace`; confirm how this argument is used.
#   uninitialized -- presumably the names of graph inputs to be fed
#                    positionally (inferred from run()'s error message);
#                    verify against the missing lines.
#
# NOTE(review): this is an extraction fragment. Original lines 15-24 (the
# rest of this body) are not present, and the leading integers on each line
# come from a numbered listing and are not code.
13 def __init__(self, init_net, predict_net, workspace, uninitialized):
14 super(Caffe2Rep, self).__init__()
# _name_scope: name-scope prefix derived from the predict net's device.
#
# When self.predict_net.device_option.device_type is caffe2_pb2.CUDA,
# returns 'gpu_<cuda_gpu_id>', using the net's device_option.cuda_gpu_id.
#
# NOTE(review): the non-CUDA result lives on lines missing from this fragment
# (original lines 28-29). As visible, the function would fall through and
# return None, so confirm the real fallback value. Whether a decorator
# precedes this def is also not visible (original lines 15-24 are missing).
# The leading integers are listing residue, not code.
25 def _name_scope(self):
26 if self.predict_net.device_option.device_type == caffe2_pb2.CUDA:
27 return 'gpu_{}'.format(self.predict_net.device_option.cuda_gpu_id)
# run: feed `inputs`, execute the nets, and return the predict net's outputs.
#
# Parameters:
#   inputs -- either a dict, which is fed blob-by-blob as name -> value, or a
#             list/tuple, which is fed positionally. The positional pairing
#             is on missing lines; it presumably pairs with the
#             `uninitialized` names (TODO confirm). Handling of any other
#             input type is not visible.
#   kwargs -- forwarded unchanged to BackendRep.run.
#
# Returns: an instance of the namedtupledict type 'Outputs'. Its fields are
#   self.predict_net.external_output, and each value is fetched with
#   workspace.FetchBlob in that same order.
#
# Raises: RuntimeError in the list/tuple branch. The triggering condition
#   (original line 39) is not visible; the message suggests a mismatch
#   between the number of values given and the expected uninitialized
#   graph inputs.
#
# NOTE(review): extraction fragment. Original lines 32, 35, 39, 42-44,
# 46-55 and 57 are missing, and statements are split across lines. The
# leading integers are listing residue, so this is not valid Python as-is.
# NOTE(review): `workspace` below is a bare name, not `self.workspace`.
# Confirm it refers to the intended workspace (e.g. a module import plus a
# workspace switch on a missing line).
30 def run(self, inputs, **kwargs):
31 super(Caffe2Rep, self).run(inputs, **kwargs)
# Ops/blobs in this scope use the predict net's device_option (`core` is
# imported on lines not visible here).
33 with core.DeviceScope(self.predict_net.device_option):
34 if isinstance(inputs, dict):
# Dict inputs: feed each value under its key as the blob name.
36 for key, value
in inputs.items():
37 workspace.FeedBlob(key, value)
38 elif isinstance(inputs, list)
or isinstance(inputs, tuple):
# NOTE(review): the `if` guarding this raise (original line 39) is missing.
40 raise RuntimeError(
'Expected {} values for uninitialized ' 41 'graph inputs ({}), but got {}.'.format(
# NOTE(review): the format arguments (original lines 42-44) are missing.
45 for i, value
in enumerate(inputs):
# NOTE(review): the positional feed body and any remaining branches
# (original lines 46-55) are missing. Any guard around the init_net run
# below is therefore not visible; as shown, init_net would run on every
# call.
56 workspace.RunNet(self.init_net.name)
# NOTE(review): original line 57 is missing between these two RunNet calls.
58 workspace.RunNet(self.predict_net.name)
# Fetch every declared external output in declaration order.
59 output_values = [workspace.FetchBlob(name)
60 for name
in self.predict_net.external_output]
# Build an 'Outputs' record type keyed by the output names and populate it
# positionally.
61 return namedtupledict(
'Outputs',
62 self.predict_net.external_output)(*output_values)