from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import collections
import copy
import functools
import sys

import numpy as np

from caffe2.proto import caffe2_pb2
from future.utils import viewitems
from google.protobuf import text_format
from google.protobuf.message import DecodeError, Message
from six import integer_types, binary_type, text_type
def OpAlmostEqual(op_a, op_b, ignore_fields=None):
    '''
    Two ops are identical except for each field in the `ignore_fields`.

    Args:
        op_a: an OperatorDef protobuf message.
        op_b: an OperatorDef protobuf message.
        ignore_fields: a field name (text) or a list of field names that are
            cleared on deep copies of both ops before comparison.

    Returns:
        True if the two ops compare equal once the ignored fields are cleared.
    '''
    ignore_fields = ignore_fields or []
    if not isinstance(ignore_fields, list):
        # Allow a single field name to be passed without wrapping it in a list.
        ignore_fields = [ignore_fields]

    assert all(isinstance(f, text_type) for f in ignore_fields), (
        'Expect each field is text type, but got {}'.format(ignore_fields))

    def clean_op(op):
        # Deep-copy so the caller's ops are never mutated.
        op = copy.deepcopy(op)
        for field in ignore_fields:
            if op.HasField(field):
                op.ClearField(field)
        return op

    return clean_op(op_a) == clean_op(op_b)
def CaffeBlobToNumpyArray(blob):
    """Converts a (legacy) Caffe blob proto to a float32 numpy array.

    Old-style blobs carry explicit num/channels/height/width fields; new-style
    blobs carry a shape proto with repeated `dim`. A non-zero `num` marks the
    old style.
    """
    if blob.num != 0:
        # old style (caffe.proto > 2.0.0) where blob explicitly stores NCHW.
        return (np.asarray(blob.data, dtype=np.float32)
                .reshape(blob.num, blob.channels, blob.height, blob.width))
    else:
        # new style (caffe.proto >= 2.0.0) with a BlobShape `shape` field.
        return (np.asarray(blob.data, dtype=np.float32)
                .reshape(blob.shape.dim))
def Caffe2TensorToNumpyArray(tensor):
    """Converts a caffe2 TensorProto to a numpy array of the matching dtype.

    Note that all sub-32-bit integer types are transported in the proto's
    `int32_data` field and narrowed here via the numpy dtype.

    Raises:
        RuntimeError: if the proto's data_type has no conversion implemented.
    """
    if tensor.data_type == caffe2_pb2.TensorProto.FLOAT:
        return np.asarray(
            tensor.float_data, dtype=np.float32).reshape(tensor.dims)
    elif tensor.data_type == caffe2_pb2.TensorProto.DOUBLE:
        return np.asarray(
            tensor.double_data, dtype=np.float64).reshape(tensor.dims)
    elif tensor.data_type == caffe2_pb2.TensorProto.INT32:
        # np.int (plain python int alias) was removed in NumPy 1.24; INT32
        # payloads decode as int32.
        return np.asarray(
            tensor.int32_data, dtype=np.int32).reshape(tensor.dims)
    elif tensor.data_type == caffe2_pb2.TensorProto.INT16:
        return np.asarray(
            tensor.int32_data, dtype=np.int16).reshape(tensor.dims)  # pb with int32
    elif tensor.data_type == caffe2_pb2.TensorProto.UINT16:
        return np.asarray(
            tensor.int32_data, dtype=np.uint16).reshape(tensor.dims)  # pb with int32
    elif tensor.data_type == caffe2_pb2.TensorProto.INT8:
        return np.asarray(
            tensor.int32_data, dtype=np.int8).reshape(tensor.dims)  # pb with int32
    elif tensor.data_type == caffe2_pb2.TensorProto.UINT8:
        return np.asarray(
            tensor.int32_data, dtype=np.uint8).reshape(tensor.dims)  # pb with int32
    else:
        # TODO: complete the data type: bool, float16, byte, int64, string
        raise RuntimeError(
            "Tensor data type not supported yet: " + str(tensor.data_type))
def NumpyArrayToCaffe2Tensor(arr, name=None):
    """Converts a numpy array to a caffe2 TensorProto.

    Args:
        arr: a numpy ndarray of a supported dtype.
        name: optional blob name to store in the proto.

    Returns:
        A caffe2_pb2.TensorProto with dims, data_type and the flattened data.

    Raises:
        RuntimeError: if arr.dtype has no conversion implemented.
    """
    tensor = caffe2_pb2.TensorProto()
    tensor.dims.extend(arr.shape)
    if name:
        tensor.name = name
    if arr.dtype == np.float32:
        tensor.data_type = caffe2_pb2.TensorProto.FLOAT
        tensor.float_data.extend(list(arr.flatten().astype(float)))
    elif arr.dtype == np.float64:
        tensor.data_type = caffe2_pb2.TensorProto.DOUBLE
        tensor.double_data.extend(list(arr.flatten().astype(np.float64)))
    elif arr.dtype == int or arr.dtype == np.int32:
        # `np.int` (an alias of the builtin int) was removed in NumPy 1.24;
        # the builtin keeps the original "platform default int" semantics.
        tensor.data_type = caffe2_pb2.TensorProto.INT32
        tensor.int32_data.extend(arr.flatten().astype(int).tolist())
    elif arr.dtype == np.int16:
        tensor.data_type = caffe2_pb2.TensorProto.INT16
        tensor.int32_data.extend(list(arr.flatten().astype(np.int16)))
    elif arr.dtype == np.uint16:
        tensor.data_type = caffe2_pb2.TensorProto.UINT16
        tensor.int32_data.extend(list(arr.flatten().astype(np.uint16)))
    elif arr.dtype == np.int8:
        tensor.data_type = caffe2_pb2.TensorProto.INT8
        tensor.int32_data.extend(list(arr.flatten().astype(np.int8)))
    elif arr.dtype == np.uint8:
        tensor.data_type = caffe2_pb2.TensorProto.UINT8
        tensor.int32_data.extend(list(arr.flatten().astype(np.uint8)))
    else:
        # TODO: complete the data type: bool, float16, byte, int64, string
        raise RuntimeError(
            "Numpy data type not supported yet: " + str(arr.dtype))
    return tensor
def MakeArgument(key, value):
    """Makes an argument based on the value type.

    Scalars map to the f/i/s fields, protobuf messages are either embedded
    (NetDef) or serialized, and homogeneous iterables map to the repeated
    floats/ints/strings/nets fields.

    Raises:
        ValueError: if the value (or its element types) is not supported.
    """
    argument = caffe2_pb2.Argument()
    argument.name = key
    # collections.Iterable was removed in Python 3.10; the ABC lives in
    # collections.abc since 3.3.
    iterable = isinstance(value, collections.abc.Iterable)

    # Fast tracking common use case where a float32 array of tensor parameters
    # needs to be serialized. The entire array is guaranteed to have the same
    # dtype, so no per-element type checks are needed.
    if isinstance(value, np.ndarray) and value.dtype.type is np.float32:
        argument.floats.extend(value.flatten().tolist())
        return argument

    if isinstance(value, np.ndarray):
        value = value.flatten().tolist()
    elif isinstance(value, np.generic):
        # convert numpy scalar to native python type
        # (np.asscalar was removed in NumPy 1.23; .item() is the replacement)
        value = value.item()

    if type(value) is float:
        argument.f = value
    elif type(value) in integer_types or type(value) is bool:
        argument.i = value
    elif isinstance(value, binary_type):
        argument.s = value
    elif isinstance(value, text_type):
        argument.s = value.encode('utf-8')
    elif isinstance(value, caffe2_pb2.NetDef):
        argument.n.CopyFrom(value)
    elif isinstance(value, Message):
        argument.s = value.SerializeToString()
    # np.float_ was removed in NumPy 2.0; np.float64 is the same type object.
    elif iterable and all(type(v) in [float, np.float64] for v in value):
        argument.floats.extend(
            v.item() if type(v) is np.float64 else v for v in value
        )
    elif iterable and all(
        type(v) in integer_types or type(v) in [bool, np.int_] for v in value
    ):
        argument.ints.extend(
            v.item() if type(v) is np.int_ else v for v in value
        )
    elif iterable and all(
        isinstance(v, binary_type) or isinstance(v, text_type) for v in value
    ):
        argument.strings.extend(
            v.encode('utf-8') if isinstance(v, text_type) else v
            for v in value
        )
    elif iterable and all(isinstance(v, caffe2_pb2.NetDef) for v in value):
        argument.nets.extend(value)
    elif iterable and all(isinstance(v, Message) for v in value):
        argument.strings.extend(v.SerializeToString() for v in value)
    else:
        if iterable:
            raise ValueError(
                "Unknown iterable argument type: key={} value={}, value "
                "type={}[{}]".format(
                    key, value, type(value), set(type(v) for v in value)
                )
            )
        else:
            raise ValueError(
                "Unknown argument type: key={} value={}, value type={}".format(
                    key, value, type(value)
                )
            )
    return argument
def TryReadProtoWithClass(cls, s):
    """Reads a protobuffer with the given proto class.

    Inputs:
        cls: a protobuffer class.
        s: a string of either binary or text protobuffer content.

    Outputs:
        proto: the protobuffer of cls

    Throws:
        google.protobuf.message.DecodeError: if we cannot decode the message.
    """
    obj = cls()
    try:
        # Try the human-readable text format first; fall back to binary.
        text_format.Parse(s, obj)
        return obj
    except text_format.ParseError:
        obj.ParseFromString(s)
        return obj
def GetContentFromProto(obj, function_map):
    """Gets a specific field from a protocol buffer that matches the given class.

    Args:
        obj: a protobuffer (or any object) whose exact type selects the handler.
        function_map: dict mapping a class to a one-argument function.

    Returns:
        The handler's result, or None if no class in the map matches exactly.
    """
    # Plain dict.items() iterates identically to future.utils.viewitems here.
    for cls, func in function_map.items():
        if type(obj) is cls:
            return func(obj)
def GetContentFromProtoString(s, function_map):
    """Decodes `s` with each proto class in `function_map` until one succeeds,
    then dispatches the decoded proto to that class's handler.

    Raises:
        DecodeError: if no class in the map can decode the string.
    """
    for cls, func in function_map.items():
        try:
            obj = TryReadProtoWithClass(cls, s)
            return func(obj)
        except DecodeError:
            # This class could not parse the payload; try the next one.
            continue
    raise DecodeError("Cannot find a fit protobuffer class.")
def ConvertProtoToBinary(proto_class, filename, out_filename):
    """Convert a text file of the given protobuf class to binary."""
    # Use context managers so file handles are always closed.
    with open(filename) as fid:
        proto = TryReadProtoWithClass(proto_class, fid.read())
    # SerializeToString() returns bytes, so the output must be opened in
    # binary mode ('w' raises TypeError on Python 3).
    with open(out_filename, 'wb') as fid:
        fid.write(proto.SerializeToString())
def GetGPUMemoryUsageStats():
    """Get GPU memory usage stats from CUDAContext. This requires flag
       --caffe2_gpu_memory_tracking to be enabled"""
    # Imported locally to avoid a circular import with caffe2.python.workspace.
    from caffe2.python import workspace, core
    # NOTE(review): operator name reconstructed from Caffe2 upstream — confirm.
    workspace.RunOperatorOnce(
        core.CreateOperator(
            "GetGPUMemoryUsage",
            [],
            ["____mem____"],
            device_option=core.DeviceOption(caffe2_pb2.CUDA, 0),
        ),
    )
    # Blob layout: row 0 is current usage per GPU, row 1 is the peak per GPU.
    b = workspace.FetchBlob("____mem____")
    return {
        'total_by_gpu': b[0, :],
        'max_by_gpu': b[1, :],
        'total': np.sum(b[0, :]),
        'max_total': np.sum(b[1, :])
    }
def ResetBlobs(blobs):
    """Frees the memory held by the given workspace blobs (CPU context)."""
    # Imported locally to avoid a circular import with caffe2.python.workspace.
    from caffe2.python import workspace, core
    # NOTE(review): operator name reconstructed from Caffe2 upstream — confirm.
    workspace.RunOperatorOnce(
        core.CreateOperator(
            "Free",
            list(blobs),
            list(blobs),
            device_option=core.DeviceOption(caffe2_pb2.CPU),
        ),
    )
class DebugMode(object):
    '''
    This class allows to drop you into an interactive debugger
    if there is an unhandled exception in your python script

    Example of usage:

    def main():
        # your code here
        pass

    if __name__ == '__main__':
        from caffe2.python.utils import DebugMode
        DebugMode.run(main)
    '''

    @classmethod
    def run(cls, func):
        """Runs `func()`; on an unhandled exception drops into pdb post-mortem
        and exits. KeyboardInterrupt is re-raised untouched."""
        try:
            return func()
        except KeyboardInterrupt:
            raise
        except Exception:
            # Import lazily so pdb is only loaded on the failure path.
            import pdb

            print(
                'Entering interactive debugger. Type "bt" to print '
                'the full stacktrace. Type "help" to see command listing.')
            print(sys.exc_info()[1])

            pdb.post_mortem()
            sys.exit(1)
def raiseIfNotEqual(a, b, msg):
    """Raises an Exception formatted with `msg` if a != b; otherwise no-op."""
    if a != b:
        raise Exception("{}. {} != {}".format(msg, a, b))
def debug(f):
    '''
    Use this method to decorate your function with DebugMode's functionality

    Example:

    @debug
    def test_foo(self):
        raise Exception("Bar")
    '''

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # Defer the call into DebugMode.run so unhandled exceptions drop
        # into the interactive debugger.
        return DebugMode.run(lambda: f(*args, **kwargs))

    return wrapper