Caffe2 - Python API
A deep learning, cross-platform ML framework
mobile_exporter.py
1 ## @package mobile_exporter
2 # Module caffe2.python.mobile_exporter
3 
4 from __future__ import absolute_import
5 from __future__ import division
6 from __future__ import print_function
7 from __future__ import unicode_literals
8 from caffe2.python import core, utils
9 from caffe2.proto import caffe2_pb2
10 import numpy as np
11 
12 
def add_tensor(net, name, blob):
    """Append to `net` a fill operator that produces the tensor `blob`.

    The operator is only created and added to ``net.op``; it is not run
    here. The operator type is selected from the dtype of `blob`. A uint8
    blob is stored as a string tensor with a single element (shape ``[1]``)
    holding the raw bytes of the array.

    Args:
        net: net definition whose ``op`` field is extended with the new
            operator.
        name: name of the output blob the operator fills.
        blob: numpy array holding the tensor contents.

    Raises:
        KeyError: if the dtype of `blob` is not one of float32, int32,
            int64 or uint8.
    """
    kTypeNameMapper = {
        np.dtype('float32'): "GivenTensorFill",
        np.dtype('int32'): "GivenTensorIntFill",
        np.dtype('int64'): "GivenTensorInt64Fill",
        np.dtype('uint8'): "GivenTensorStringFill",
    }

    if blob.dtype not in kTypeNameMapper:
        raise KeyError(
            "add_tensor: unsupported dtype {} for blob '{}'".format(
                blob.dtype, name))

    shape = blob.shape
    values = blob
    # pass array of uint8 as a string to save storage
    # storing uint8_t has a large overhead for now
    if blob.dtype == np.dtype('uint8'):
        shape = [1]
        # tobytes() returns the raw buffer contents on both Python 2 and 3.
        # str(blob.data) only did so on Python 2; on Python 3 it yields the
        # repr of the memoryview ("<memory at 0x...>") instead of the data.
        values = [blob.tobytes()]

    op = core.CreateOperator(
        kTypeNameMapper[blob.dtype],
        [], [name],
        arg=[
            utils.MakeArgument("shape", shape),
            utils.MakeArgument("values", values),
        ]
    )
    net.op.extend([op])
42 
43 
def Export(workspace, net, params):
    """Returns init_net and predict_net suitable for writing to disk
    and loading into a Predictor.

    Args:
        workspace: object providing ``FetchBlob(name)``, used to read the
            current value of every parameter blob.
        net: either a ``caffe2_pb2.NetDef`` or an object exposing
            ``Proto()`` that returns one.
        params: iterable of parameter blobs (names or blob references);
            ``str()`` of each entry is used as the blob name.

    Returns:
        Tuple ``(init_net, predict_net)`` of ``caffe2_pb2.NetDef``.
        ``init_net`` holds one fill operator per parameter plus a
        placeholder fill for every other input blob. ``predict_net`` is a
        copy of the given net with ``external_input`` and
        ``external_output`` rewritten.
    """
    proto = net if isinstance(net, caffe2_pb2.NetDef) else net.Proto()
    predict_net = caffe2_pb2.NetDef()
    predict_net.CopyFrom(proto)
    init_net = caffe2_pb2.NetDef()

    # Populate the init_net.
    ssa, blob_versions = core.get_ssa(net)
    # Every blob name that some operator reads; a set keeps the membership
    # tests below constant-time.
    consumed = set()
    for versioned_inputs, _ in ssa:
        consumed.update(name for name, _ in versioned_inputs)

    # Resolve parameter names once; they are needed both for filtering the
    # inputs and for fetching the values.
    param_names = [str(blob_ref) for blob_ref in params]
    param_name_set = set(param_names)

    # Version 0 means the blob is never written by the net, so it must be
    # supplied from outside. Parameters are handled separately.
    input_blobs = [
        blob_name for blob_name, version in blob_versions.items()
        if version == 0 and blob_name not in param_name_set
    ]
    # Blobs that are never used as an input to another layer,
    # i.e. strictly output blobs.
    output_blobs = [
        blob_name for blob_name, version in blob_versions.items()
        if version != 0 and blob_name not in consumed
    ]

    for blob_name in param_names:
        blob = workspace.FetchBlob(blob_name)
        add_tensor(init_net, blob_name, blob)

    # We have to make sure the blob exists in the namespace
    # and we can do so with fake data. (Which is immediately overwritten
    # by any typical usage)
    for blob_name in input_blobs:
        init_net.op.extend(
            [
                core.CreateOperator(
                    "GivenTensorFill", [], [blob_name],
                    arg=[
                        utils.MakeArgument("shape", [1, 1]),
                        utils.MakeArgument("values", [0.0])
                    ]
                )
            ]
        )

    # Now we make input/output_blobs line up with what Predictor expects:
    # the data inputs come first, followed by the original external inputs
    # (for populating weights). A blob present in both lists is emitted
    # only once so external_input contains no duplicate entries.
    external_inputs = list(input_blobs)
    seen = set(external_inputs)
    for blob_name in proto.external_input:
        if blob_name not in seen:
            seen.add(blob_name)
            external_inputs.append(blob_name)

    del predict_net.external_input[:]
    predict_net.external_input.extend(external_inputs)
    # Ensure the output is also consistent with what we want
    del predict_net.external_output[:]
    predict_net.external_output.extend(output_blobs)
    return init_net, predict_net