Caffe2 - Python API
A deep learning, cross-platform ML framework
utils.py
1 ## @package utils
2 # Module caffe2.python.utils
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from caffe2.proto import caffe2_pb2
9 from future.utils import viewitems
10 from google.protobuf.message import DecodeError, Message
11 from google.protobuf import text_format
12 
13 import sys
14 import copy
15 import collections
16 import functools
17 import numpy as np
18 from six import integer_types, binary_type, text_type
19 
20 
def OpAlmostEqual(op_a, op_b, ignore_fields=None):
    '''
    Return True if two ops are equal once the fields named in
    `ignore_fields` have been cleared on copies of both.

    Args:
        op_a, op_b: protobuf messages to compare (presumably caffe2
            OperatorDefs). They are deep-copied, so the inputs are never
            modified.
        ignore_fields: a single field name, or a list of field names, to
            clear before comparing. None (or empty) compares every field.

    Raises:
        AssertionError: if any field name is not a text string.
    '''
    ignore_fields = ignore_fields or []
    if not isinstance(ignore_fields, list):
        ignore_fields = [ignore_fields]

    assert all(isinstance(f, text_type) for f in ignore_fields), (
        'Expect each field is text type, but got {}'.format(ignore_fields))

    def clean_op(op):
        op = copy.deepcopy(op)
        for field in ignore_fields:
            # ClearField is a no-op on an unset field and works for both
            # singular and repeated fields. HasField() raises ValueError for
            # repeated fields, so it must not be used as a guard here.
            op.ClearField(field)
        return op

    return clean_op(op_a) == clean_op(op_b)
42 
43 
def CaffeBlobToNumpyArray(blob):
    """Convert a Caffe blob message into a float32 numpy array.

    Two blob layouts are handled:
      * old style: a non-zero ``num`` together with ``channels``,
        ``height`` and ``width`` describes a 4-D shape;
      * new style: ``num`` is 0 and the shape is ``blob.shape.dim``.

    The flat ``blob.data`` values are reshaped to whichever shape applies.
    """
    is_legacy_layout = blob.num != 0
    values = np.asarray(blob.data, dtype=np.float32)
    if is_legacy_layout:
        target_shape = (blob.num, blob.channels, blob.height, blob.width)
    else:
        target_shape = blob.shape.dim
    return values.reshape(target_shape)
53 
54 
def Caffe2TensorToNumpyArray(tensor):
    """Deserialize a caffe2 TensorProto into a numpy array.

    The values are read from the repeated field that matches
    ``tensor.data_type`` and reshaped to ``tensor.dims``. INT64 values are
    stored in ``int64_data``. The narrow integer types (INT16, UINT16, INT8,
    UINT8) and BOOL are stored in ``int32_data``.

    Raises:
        RuntimeError: if ``tensor.data_type`` is not supported.
    """
    TensorProto = caffe2_pb2.TensorProto
    # data_type -> (repeated field holding the values, numpy dtype to decode to)
    decoders = {
        TensorProto.FLOAT: ('float_data', np.float32),
        TensorProto.DOUBLE: ('double_data', np.float64),
        TensorProto.INT64: ('int64_data', np.int64),
        # Builtin `int` is exactly what the `np.int` alias (removed in
        # NumPy 1.24) referred to, so INT32 keeps decoding to numpy's
        # default integer dtype.
        TensorProto.INT32: ('int32_data', int),
        TensorProto.INT16: ('int32_data', np.int16),
        TensorProto.UINT16: ('int32_data', np.uint16),
        TensorProto.INT8: ('int32_data', np.int8),
        TensorProto.UINT8: ('int32_data', np.uint8),
        TensorProto.BOOL: ('int32_data', np.bool_),
    }
    decoder = decoders.get(tensor.data_type)
    if decoder is None:
        # TODO: complete the data type: float16, byte, string
        raise RuntimeError(
            "Tensor data type not supported yet: " + str(tensor.data_type))
    field_name, dtype = decoder
    return np.asarray(
        getattr(tensor, field_name), dtype=dtype).reshape(tensor.dims)
81 
82 
def NumpyArrayToCaffe2Tensor(arr, name=None):
    """Serialize a numpy array into a caffe2 TensorProto.

    Args:
        arr: the array to serialize (anything ``np.asarray`` accepts).
        name: optional name stored in the proto's ``name`` field; skipped
            when falsy.

    Returns:
        A TensorProto whose ``dims`` is the array shape and whose values are
        the flattened array, stored in the field matching its dtype. The
        platform default int and int32 map to INT32. The narrow integer
        types and bool are stored in ``int32_data``.

    Raises:
        RuntimeError: if the dtype has no supported TensorProto mapping.
    """
    arr = np.asarray(arr)
    TensorProto = caffe2_pb2.TensorProto
    tensor = TensorProto()
    tensor.dims.extend(arr.shape)
    if name:
        tensor.name = name
    # ravel().tolist() hands protobuf native python scalars rather than numpy
    # scalar objects, for every dtype branch alike.
    flat = arr.ravel()
    if arr.dtype == np.float32:
        tensor.data_type = TensorProto.FLOAT
        tensor.float_data.extend(flat.tolist())
    elif arr.dtype == np.float64:
        tensor.data_type = TensorProto.DOUBLE
        tensor.double_data.extend(flat.tolist())
    elif arr.dtype == int or arr.dtype == np.int32:
        # `int` replaces the `np.int` alias (removed in NumPy 1.24); it is
        # the same type, so the set of dtypes mapped to INT32 is unchanged.
        tensor.data_type = TensorProto.INT32
        tensor.int32_data.extend(flat.tolist())
    elif arr.dtype == np.int16:
        tensor.data_type = TensorProto.INT16
        tensor.int32_data.extend(flat.tolist())
    elif arr.dtype == np.uint16:
        tensor.data_type = TensorProto.UINT16
        tensor.int32_data.extend(flat.tolist())
    elif arr.dtype == np.int8:
        tensor.data_type = TensorProto.INT8
        tensor.int32_data.extend(flat.tolist())
    elif arr.dtype == np.uint8:
        tensor.data_type = TensorProto.UINT8
        tensor.int32_data.extend(flat.tolist())
    elif arr.dtype == np.bool_:
        tensor.data_type = TensorProto.BOOL
        tensor.int32_data.extend(flat.astype(np.int32).tolist())
    else:
        # TODO: complete the data type: float16, byte, int64, string
        raise RuntimeError(
            "Numpy data type not supported yet: " + str(arr.dtype))
    return tensor
114 
115 
def MakeArgument(key, value):
    """Makes an argument based on the value type.

    Args:
        key: name stored in the Argument's ``name`` field.
        value: the argument value.
            Scalars map to a singular field:
              * float -> ``f``
              * int/bool -> ``i``
              * bytes, text, or a non-NetDef protobuf message -> ``s``
              * NetDef -> ``n``
            Iterables map to the matching repeated field (``floats``,
            ``ints``, ``strings``, ``nets``). numpy arrays and numpy scalars
            are first converted to native python values.

    Returns:
        A caffe2_pb2.Argument.

    Raises:
        ValueError: if the value, or the element types of an iterable value,
            cannot be mapped to an Argument field.
    """
    try:
        from collections.abc import Iterable
    except ImportError:  # Python 2 keeps the ABCs directly in `collections`.
        from collections import Iterable

    argument = caffe2_pb2.Argument()
    argument.name = key
    iterable = isinstance(value, Iterable)

    # Fast tracking common use case where a float32 array of tensor parameters
    # needs to be serialized. The entire array is guaranteed to have the same
    # dtype, so no per-element checking necessary and no need to convert each
    # element separately.
    if isinstance(value, np.ndarray) and value.dtype.type is np.float32:
        argument.floats.extend(value.flatten().tolist())
        return argument

    if isinstance(value, np.ndarray):
        value = value.flatten().tolist()
    elif isinstance(value, np.generic):
        # Convert numpy scalar to native python type. `.item()` replaces
        # np.asscalar, which NumPy 1.23 removed.
        value = value.item()

    if type(value) is float:
        argument.f = value
    elif type(value) in integer_types or type(value) is bool:
        # We make a relaxation that a boolean variable will also be stored as
        # int.
        argument.i = value
    elif isinstance(value, binary_type):
        argument.s = value
    elif isinstance(value, text_type):
        argument.s = value.encode('utf-8')
    elif isinstance(value, caffe2_pb2.NetDef):
        argument.n.CopyFrom(value)
    elif isinstance(value, Message):
        argument.s = value.SerializeToString()
    elif iterable:
        # Materialize once: a one-shot iterator (e.g. a generator) would
        # otherwise be exhausted by the type checks below before its items
        # are stored.
        values = list(value)
        # np.float64 is the type the removed np.float_ alias referred to.
        if all(type(v) in (float, np.float64) for v in values):
            argument.floats.extend(
                v.item() if type(v) is np.float64 else v for v in values
            )
        elif all(
            type(v) in integer_types or type(v) in (bool, np.int_)
            for v in values
        ):
            argument.ints.extend(
                v.item() if type(v) is np.int_ else v for v in values
            )
        elif all(isinstance(v, (binary_type, text_type)) for v in values):
            argument.strings.extend(
                v.encode('utf-8') if isinstance(v, text_type) else v
                for v in values
            )
        elif all(isinstance(v, caffe2_pb2.NetDef) for v in values):
            argument.nets.extend(values)
        elif all(isinstance(v, Message) for v in values):
            argument.strings.extend(v.SerializeToString() for v in values)
        else:
            raise ValueError(
                "Unknown iterable argument type: key={} value={}, value "
                "type={}[{}]".format(
                    key, value, type(value), set(type(v) for v in values)
                )
            )
    else:
        raise ValueError(
            "Unknown argument type: key={} value={}, value type={}".format(
                key, value, type(value)
            )
        )
    return argument
186 
187 
def TryReadProtoWithClass(cls, s):
    """Reads a protobuffer with the given proto class.

    The text format is tried first. If that fails, `s` is parsed as a binary
    serialization.

    Inputs:
      cls: a protobuffer class.
      s: a string of either binary or text protobuffer content.

    Outputs:
      proto: the protobuffer of cls

    Throws:
      google.protobuf.message.DecodeError: if we cannot decode the message.
    """
    obj = cls()
    try:
        text_format.Parse(s, obj)
        return obj
    except (text_format.ParseError, UnicodeDecodeError):
        # On Python 3 text_format decodes bytes input as UTF-8, which binary
        # payloads usually are not; treat that like any other text-parse
        # failure and fall through to binary parsing.
        pass
    if isinstance(s, text_type):
        # ParseFromString only accepts bytes; passing text would raise
        # TypeError instead of the documented DecodeError.
        s = s.encode('utf-8')
    obj.ParseFromString(s)
    return obj
208 
209 
def GetContentFromProto(obj, function_map):
    """Apply the extractor registered for the exact class of `obj`.

    Args:
        obj: a protocol buffer message.
        function_map: dict mapping proto classes to callables that take a
            message of that class.

    Returns:
        The result of the first extractor whose class is exactly
        ``type(obj)`` (subclasses do not match), or None when no entry
        matches.
    """
    obj_class = type(obj)
    for proto_class, extractor in viewitems(function_map):
        if proto_class is obj_class:
            return extractor(obj)
    return None
216 
217 
def GetContentFromProtoString(s, function_map):
    """Decode `s` with the first proto class that accepts it and extract content.

    Args:
        s: protobuf content, text or binary.
        function_map: dict mapping proto classes to extractor callables.
            Classes are tried in the dict's iteration order.

    Returns:
        Whatever the extractor returns for the first class that `s` decodes
        as.

    Raises:
        DecodeError: if `s` cannot be decoded as any class in the map.
    """
    for cls, func in viewitems(function_map):
        try:
            obj = TryReadProtoWithClass(cls, s)
        except DecodeError:
            continue
        # Call the extractor outside the try so a DecodeError it raises is
        # not mistaken for "s is not a `cls`" and silently skipped.
        return func(obj)
    raise DecodeError("Cannot find a fit protobuffer class.")
227 
228 
def ConvertProtoToBinary(proto_class, filename, out_filename):
    """Convert a text file of the given protobuf class to binary.

    Args:
        proto_class: protobuf message class that the file contains.
        filename: path of the input file (protobuf text format).
        out_filename: path that receives the binary serialization.
    """
    # Read as bytes so the content is not decoded with the locale encoding;
    # TryReadProtoWithClass hands it to the protobuf parsers directly.
    with open(filename, 'rb') as fin:
        content = fin.read()
    proto = TryReadProtoWithClass(proto_class, content)
    # SerializeToString() returns bytes, which requires a binary-mode file.
    with open(out_filename, 'wb') as fout:
        fout.write(proto.SerializeToString())
234 
235 
def GetGPUMemoryUsageStats():
    """Get GPU memory usage stats from CUDAContext. This requires flag
    --caffe2_gpu_memory_tracking to be enabled

    Runs a GetGPUMemoryUsage operator on CUDA device 0 and reads back the
    2-row result blob: row 0 holds the current total per GPU and row 1 the
    maximum per GPU.

    Returns:
        dict with keys 'total_by_gpu', 'max_by_gpu' (per-GPU rows) and
        'total', 'max_total' (their sums across GPUs).
    """
    from caffe2.python import workspace, core
    mem_blob = "____mem____"
    usage_op = core.CreateOperator(
        "GetGPUMemoryUsage",
        [],
        [mem_blob],
        device_option=core.DeviceOption(caffe2_pb2.CUDA, 0),
    )
    workspace.RunOperatorOnce(usage_op)
    usage = workspace.FetchBlob(mem_blob)
    total_by_gpu = usage[0, :]
    max_by_gpu = usage[1, :]
    return {
        'total_by_gpu': total_by_gpu,
        'max_by_gpu': max_by_gpu,
        'total': np.sum(total_by_gpu),
        'max_total': np.sum(max_by_gpu),
    }
255 
256 
def ResetBlobs(blobs):
    """Free the given workspace blobs by running a "Free" operator on CPU.

    Args:
        blobs: iterable of blobs to free. It is iterated exactly once, so a
            generator is accepted.
    """
    from caffe2.python import workspace, core
    # Materialize once: the op takes the same blobs as inputs and outputs,
    # and iterating a one-shot iterator twice would leave the outputs empty.
    blob_list = list(blobs)
    workspace.RunOperatorOnce(
        core.CreateOperator(
            "Free",
            blob_list,
            list(blob_list),
            device_option=core.DeviceOption(caffe2_pb2.CPU),
        ),
    )
267 
268 
class DebugMode(object):
    '''
    Runs a function and, if it raises an unhandled exception, drops you
    into an interactive post-mortem pdb session.

    Example of usage:

        def main():
            # your code here
            pass

        if __name__ == '__main__':
            from caffe2.python.utils import DebugMode
            DebugMode.run(main)
    '''

    @classmethod
    def run(cls, func):
        """Call ``func()`` and return its result.

        KeyboardInterrupt propagates unchanged. For any other Exception, the
        error is printed, ``pdb.post_mortem()`` is opened, and the process
        exits with status 1 once the debugger returns.
        """
        try:
            return func()
        except KeyboardInterrupt:
            raise
        except Exception:
            import pdb

            print(
                'Entering interactive debugger. Type "bt" to print '
                'the full stacktrace. Type "help" to see command listing.')
            print(sys.exc_info()[1])
            # A bare `print` is a no-op expression once print_function is in
            # effect; call it so the intended blank line is emitted.
            print()

            pdb.post_mortem()
            sys.exit(1)
            # Not reached in normal operation: sys.exit raises SystemExit.
            raise
303 
304 
def raiseIfNotEqual(a, b, msg):
    """Raise a plain Exception describing the mismatch when ``a != b``.

    The message has the form ``"<msg>. <a> != <b>"``.
    """
    mismatch = a != b
    if mismatch:
        message = "{}. {} != {}".format(msg, a, b)
        raise Exception(message)
308 
309 
def debug(f):
    '''
    Decorator that runs the decorated function under DebugMode.run, so an
    unhandled exception opens a post-mortem debugger. The decorated
    function's return value is passed through to the caller.

    Example:

        @debug
        def test_foo(self):
            raise Exception("Bar")

    '''

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # Return the result so decorating a function does not turn its
        # return value into None.
        return DebugMode.run(functools.partial(f, *args, **kwargs))

    return wrapper