3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
8 from google.protobuf.message
import Message
9 from multiprocessing
import Process
11 from collections
import defaultdict
14 from past.builtins
import basestring
19 from caffe2.proto
import caffe2_pb2
# --- Module-level setup: logger and PascalCase aliases over the pybind C
# extension ("C"). NOTE(review): the `logging` and `C` imports are not visible
# in this chunk; lines are mangled, with original line numbers fused in.
24 logger = logging.getLogger(__name__)
# Thin aliases re-exporting workspace operations from the C extension.
27 CreateBlob = C.create_blob
28 CurrentWorkspace = C.current_workspace
29 DeserializeBlob = C.deserialize_blob
30 GlobalInit = C.global_init
32 RegisteredOperators = C.registered_operators
33 SerializeBlob = C.serialize_blob
34 SwitchWorkspace = C.switch_workspace
35 RootFolder = C.root_folder
36 Workspaces = C.workspaces
37 BenchmarkNet = C.benchmark_net
38 GetStats = C.get_stats
# Maps net name -> {op index: Python creation traceback}; used by
# CallWithExceptionIntercept to report where a failing op was built.
40 operator_tracebacks = defaultdict(dict)
43 has_gpu_support = C.has_gpu_support
# NOTE(review): the CUDA bindings below are presumably inside an
# `if has_gpu_support:` guard in the full file — the guard line is elided here.
45 NumCudaDevices = C.num_cuda_devices
46 GetCUDAVersion = C.get_cuda_version
47 GetCuDNNVersion = C.get_cudnn_version
def GetCudaPeerAccessPattern():
    """Return the CUDA peer-access matrix reported by the C extension.

    Returns:
        numpy.ndarray describing which GPU pairs can access each other.
    """
    pattern = C.get_cuda_peer_access_pattern()
    return np.asarray(pattern)
# GPU-backed binding (see note above re: the elided has_gpu_support guard).
52 GetDeviceProperties = C.get_device_properties
# CPU-only fallback stubs — presumably the else-branch of the elided
# `if has_gpu_support:` guard.
54 NumCudaDevices =
lambda: 0
# NOTE(review): GetCuDNNVersion is stubbed twice below; the first assignment
# was likely meant to be GetCUDAVersion — confirm against upstream source.
55 GetCuDNNVersion =
lambda: 0
56 GetCuDNNVersion =
lambda: 0
57 GetCudaPeerAccessPattern =
lambda: np.array([])
58 GetDeviceProperties =
lambda x:
# NUMA introspection bindings from the C extension.
None 60 IsNUMAEnabled = C.is_numa_enabled
61 GetNumNUMANodes = C.get_num_numa_nodes
62 GetBlobNUMANode = C.get_blob_numa_node
64 def _GetFreeFlaskPort():
# Probes whether the default Flask port (5000) is free by attempting a
# connection. NOTE(review): several lines are elided here — the branch that
# chooses 5000 vs. an ephemeral port, and the binding of socket `s`.
65 """Get a free flask port.""" 67 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
68 result = sock.connect_ex((
'127.0.0.1', 5000))
# `s` is presumably a socket bound to port 0; getsockname()[1] then yields
# the OS-assigned free port — confirm against the full file.
74 port = s.getsockname()[1]
83 def StartMint(root_folder=None, port=None):
# Launches the mint visualization web app in a separate Process.
# NOTE(review): several lines are elided here — the port-default check, the
# Process(...) construction around the argument list, and process.start().
84 """Start a mint instance. 86 TODO(Yangqing): this does not work well under ipython yet. According to 87 https://github.com/ipython/ipython/issues/5862 88 writing up some fix is a todo item. 90 from caffe2.python.mint
import app
91 if root_folder
is None:
# Default to the workspace's current root folder.
93 root_folder = C.root_folder()
95 port = _GetFreeFlaskPort()
# Command-line style arguments handed to the mint app.
99 [
'-p', str(port),
'-r', root_folder],
103 print(
'Mint running at http://{}:{}'.format(socket.getfqdn(), port))
107 def StringifyProto(obj):
# Serializes a protobuf-like object to a string. Accepts a raw string (passed
# through), a protobuf Message, or any object exposing Proto().
# NOTE(review): the `return obj` for the string case and the tail of the
# ValueError message (the offending type name) appear elided in this chunk.
108 """Stringify a protocol buffer object. 111 obj: a protocol buffer object, or a Pycaffe2 object that has a Proto() 114 string: the output protobuf string. 116 AttributeError: if the passed in object does not have the right attribute. 118 if isinstance(obj, basestring):
121 if isinstance(obj, Message):
# Protobuf Messages serialize directly.
124 return obj.SerializeToString()
125 elif hasattr(obj,
'Proto'):
126 return obj.Proto().SerializeToString()
# Anything else is rejected.
128 raise ValueError(
"Unexpected argument to StringifyProto of type " +
def ResetWorkspace(root_folder=None):
    """Reset the active workspace.

    Args:
        root_folder (optional): folder to reset the workspace under; created
            if it does not exist. When None, the workspace keeps its current
            root folder.
    Returns:
        The result of C.reset_workspace.
    """
    if root_folder is None:
        # Keep whatever root folder the workspace is currently using.
        return C.reset_workspace(C.root_folder())
    if not os.path.exists(root_folder):
        os.makedirs(root_folder)
    return C.reset_workspace(root_folder)
142 def CreateNet(net, overwrite=False, input_blobs=None):
# Creates a net in the current workspace, pre-creating its input blobs.
# NOTE(review): the `input_blobs = []` default for the None branch and the
# remaining CallWithExceptionIntercept arguments appear elided here.
143 if input_blobs
is None:
# Ensure each declared input blob exists before net creation.
145 for input_blob
in input_blobs:
146 C.create_blob(input_blob)
# Create via the C extension; failures are mapped back to the Python
# traceback recorded for the failing operator.
147 return CallWithExceptionIntercept(
149 C.Workspace.current._last_failed_op_net_position,
151 StringifyProto(net), overwrite,
def Predictor(init_net, predict_net):
    """Build a C.Predictor from serialized init and predict nets."""
    init_str = StringifyProto(init_net)
    predict_str = StringifyProto(predict_net)
    return C.Predictor(init_str, predict_str)
def GetOperatorCost(operator, blobs):
    """Query the C extension for the cost estimate of `operator` over `blobs`."""
    op_str = StringifyProto(operator)
    return C.get_operator_cost(op_str, blobs)
def RunOperatorOnce(operator):
    """Serialize `operator` and execute it once in the current workspace."""
    serialized = StringifyProto(operator)
    return C.run_operator_once(serialized)
167 def RunOperatorsOnce(operators):
# Runs each operator once in sequence. NOTE(review): the loop header and the
# handling/aggregation of `success` are elided — only the per-op call remains.
169 success = RunOperatorOnce(op)
175 def CallWithExceptionIntercept(func, op_id_fetcher, net_name, *args, **kwargs):
# Calls `func`; on failure, looks up the creation traceback of the failing op
# (via op_id_fetcher and operator_tracebacks) and prints it before re-raising.
# NOTE(review): the try/except scaffolding and the final `raise` are elided.
177 return func(*args, **kwargs)
179 op_id = op_id_fetcher()
180 net_tracebacks = operator_tracebacks.get(net_name,
None)
181 print(
'Original python traceback for operator {} in network `{}` in ' 182 'exception above (most recent call last):'.format(
184 if net_tracebacks
and op_id
in net_tracebacks:
185 tb = net_tracebacks[op_id]
# Print creation-site frames in reverse, most recent call last.
186 for line
in reversed(tb):
187 print(
' File "{}", line {}, in {}'.format(
188 line[0], line[1], line[2]))
# NOTE(review): the lines below are the fused remains of RunNetOnce — its
# def header and remaining arguments are elided from this chunk.
193 return CallWithExceptionIntercept(
195 C.Workspace.current._last_failed_op_net_position,
201 def RunNet(name, num_iter=1, allow_fail=False):
# Runs the named net num_iter times with traceback interception.
# NOTE(review): the runnable argument (presumably C.Workspace.current._run_net)
# passed to CallWithExceptionIntercept is elided here.
205 name: the name of the net, or a reference to the net. 206 num_iter: number of iterations to run 207 allow_fail: if True, does not assert on net exec failure but returns False 209 True or an exception. 211 return CallWithExceptionIntercept(
213 C.Workspace.current._last_failed_op_net_position,
215 StringifyNetName(name), num_iter, allow_fail,
219 def RunPlan(plan_or_step):
# Runs a Plan in the current workspace; a bare ExecutionStep is first wrapped
# into a single-step Plan. NOTE(review): lines between the def and this check
# are elided — `core` is presumably caffe2.python.core, likely imported there.
222 if isinstance(plan_or_step, core.ExecutionStep):
223 plan_or_step = core.Plan(plan_or_step)
224 return C.run_plan(StringifyProto(plan_or_step))
227 def InferShapesAndTypes(nets, blob_dimensions=None):
228 """Infers the shapes and types for the specified nets. 231 nets: the list of nets 232 blob_dimensions (optional): a dictionary of blobs and their dimensions. 233 If not specified, the workspace blobs are used. 235 A tuple of (shapes, types) dictionaries keyed by blob name. 237 net_protos = [StringifyProto(n.Proto())
for n
in nets]
# Without explicit dimensions, infer from the blobs already in the workspace.
238 if blob_dimensions
is None:
239 blobdesc_prototxt = C.infer_shapes_and_types_from_workspace(net_protos)
# NOTE(review): the `else:` introducing this branch appears elided.
241 blobdesc_prototxt = C.infer_shapes_and_types_from_map(
242 net_protos, blob_dimensions
244 blobdesc_proto = caffe2_pb2.TensorShapes()
245 blobdesc_proto.ParseFromString(blobdesc_prototxt)
# NOTE(review): the `shapes = {}` / `types = {}` initializers appear elided.
248 for ts
in blobdesc_proto.shapes:
# Only record shapes the analysis actually resolved.
249 if not ts.unknown_shape:
250 shapes[ts.name] = list(ts.dims)
251 types[ts.name] = ts.data_type
253 return (shapes, types)
256 def _StringifyName(name, expected_type):
# Returns `name` as a plain string; non-string inputs must be instances of
# expected_type (checked by class name). NOTE(review): both return statements
# (`return name` and the stringified fallback) appear elided from this chunk.
257 if isinstance(name, basestring):
259 assert type(name).__name__ == expected_type, \
260 "Expected a string or %s" % expected_type
def StringifyBlobName(name):
    """Convert a blob name (string or BlobReference) to a plain string."""
    expected = "BlobReference"
    return _StringifyName(name, expected)
def StringifyNetName(name):
    """Convert a net name (string or Net object) to a plain string."""
    expected = "Net"
    return _StringifyName(name, expected)
# NOTE(review): the lines below are the body of GetNetName(net) — its def
# header and the return statements of each branch are elided from this chunk.
# It accepts a string, a Net object, or a NetDef, rejecting anything else.
273 if isinstance(net, basestring):
275 if type(net).__name__ ==
"Net":
277 if isinstance(net, caffe2_pb2.NetDef):
# Fallback: not a recognized net representation.
279 raise Exception(
"Not a Net object: {}".format(str(net)))
282 def FeedBlob(name, arr, device_option=None):
283 """Feeds a blob into the workspace. 286 name: the name of the blob. 287 arr: either a TensorProto object or a numpy array object to be fed into 289 device_option (optional): the device option to feed the data with. 291 True or False, stating whether the feed is successful. 293 if type(arr)
is caffe2_pb2.TensorProto:
# Convert protobuf tensors to NumPy before feeding.
294 arr = utils.Caffe2TensorToNumpyArray(arr)
295 if type(arr)
is np.ndarray
and arr.dtype.kind
in 'SU':
# Fixed-width NumPy string arrays ('S'/'U') are fed as object arrays instead.
297 arr = arr.astype(np.object)
# Fall back to the ambient device scope when no option is given.
299 if device_option
is None:
300 device_option = scope.CurrentDeviceScope()
302 if device_option
and device_option.device_type == caffe2_pb2.CUDA:
303 if arr.dtype == np.dtype(
'float64'):
# NOTE(review): the call emitting this message (presumably logger.warning)
# appears elided here — only its string arguments remain.
305 "CUDA operators do not support 64-bit doubles, " +
306 "please use arr.astype(np.float32) or np.int32 for ints." +
307 " Blob: {}".format(name) +
308 " type: {}".format(str(arr.dtype))
311 name = StringifyBlobName(name)
312 if device_option
is not None:
313 return C.feed_blob(name, arr, StringifyProto(device_option))
315 return C.feed_blob(name, arr)
def FetchBlobs(names):
    """Fetch a list of blobs from the workspace.

    Args:
        names: list of blob names — strings or BlobReferences.
    Returns:
        list of fetched blobs.
    """
    return list(map(FetchBlob, names))
330 """Fetches a blob from the workspace. 333 name: the name of the blob - a string or a BlobReference 335 Fetched blob (numpy array or string) if successful 337 return C.fetch_blob(StringifyBlobName(name))
340 def ApplyTransform(transform_key, net):
341 """Apply a Transform to a NetDef protobuf object, and returns the new 345 transform_key: the name of the transform, as it is stored in the registry 346 net: a NetDef protobuf object 348 Transformed NetDef protobuf object. 350 transformed_net = caffe2_pb2.NetDef()
351 transformed_str = C.apply_transform(
# Registry keys cross into C++ as UTF-8 bytes.
352 str(transform_key).encode(
'utf-8'),
353 net.SerializeToString(),
# NOTE(review): the closing parenthesis of the apply_transform call appears
# elided here.
355 transformed_net.ParseFromString(transformed_str)
356 return transformed_net
359 def ApplyTransformIfFaster(transform_key, net, init_net, **kwargs):
360 """Apply a Transform to a NetDef protobuf object, and returns the new 361 transformed NetDef, only if it runs faster than the original. 363 The runs are performed on the current active workspace (gWorkspace). 364 You should initialize that workspace before making a call to this function. 367 transform_key: the name of the transform, as it is stored in the registry 368 net: a NetDef protobuf object 369 init_net: The net to initialize the workspace. 370 warmup_runs (optional): 371 Determines how many times the net is run before testing. 372 Will be 5 by default. 373 main_runs (optional): 374 Determines how many times the net is run during testing. 375 Will be 10 by default. 376 improvement_threshold (optional): 377 Determines the factor which the new net needs to be faster 378 in order to replace the old. Will be 1.01 by default. 381 Either a Transformed NetDef protobuf object, or the original netdef. 384 warmup_runs = kwargs[
# Manual kwargs defaults: warmup_runs=5, main_runs=10, threshold=1.01.
'warmup_runs']
if 'warmup_runs' in kwargs
else 5
385 main_runs = kwargs[
'main_runs']
if 'main_runs' in kwargs
else 10
386 improvement_threshold = kwargs[
'improvement_threshold'] \
387 if 'improvement_threshold' in kwargs
else 1.01
389 transformed_net = caffe2_pb2.NetDef()
390 transformed_str = C.apply_transform_if_faster(
# Registry keys cross into C++ as UTF-8 bytes.
391 str(transform_key).encode(
'utf-8'),
392 net.SerializeToString(),
393 init_net.SerializeToString(),
# NOTE(review): the int(warmup_runs)/int(main_runs) arguments and the call's
# closing parenthesis appear elided here.
396 float(improvement_threshold),
398 transformed_net.ParseFromString(transformed_str)
399 return transformed_net
403 """Return the current namescope string. To be used to fetch blobs""" 404 return scope.CurrentNameScope()
408 """Provides python dict compatible way to do fetching and feeding""" 410 def __getitem__(self, key):
411 return FetchBlob(key)
413 def __setitem__(self, key, value):
414 return FeedBlob(key, value)
417 return len(C.blobs())
420 return C.blobs().__iter__()
422 def __contains__(self, item):
423 return C.has_blob(item)
# Immediate-mode module state: enabled flag, the name of the dedicated
# scratch workspace, and the temp folder backing it.
449 _immediate_mode =
False 450 _immediate_workspace_name =
"_CAFFE2_IMMEDIATE" 451 _immediate_root_folder =
# NOTE(review): `455 return _immediate_mode` is the fused body of
# IsImmediate(); its def header is elided from this chunk.
'' 455 return _immediate_mode
458 @contextlib.contextmanager
459 def WorkspaceGuard(workspace_name):
# Temporarily switch to `workspace_name` (the True flag presumably creates it
# if missing — confirm against C.switch_workspace), restoring the previously
# current workspace on exit.
460 current = CurrentWorkspace()
461 SwitchWorkspace(workspace_name,
True)
# NOTE(review): the `yield` (and any try/finally around it) appears elided.
463 SwitchWorkspace(current)
466 def StartImmediate(i_know=False):
467 global _immediate_mode
468 global _immediate_root_folder
# NOTE(review): the early-return when already immediate and the i_know
# warning gate appear elided here.
473 _immediate_mode =
True 474 with WorkspaceGuard(_immediate_workspace_name):
# Back the immediate workspace with a fresh temporary directory.
475 _immediate_root_folder = tempfile.mkdtemp()
476 ResetWorkspace(_immediate_root_folder)
# NOTE(review): the giant line below is the mangled user-facing warning text
# printed by StartImmediate, fused with the start of FinishImmediate(); the
# def header of FinishImmediate is elided.
481 Enabling immediate mode in caffe2 python is an EXTREMELY EXPERIMENTAL 482 feature and may very easily go wrong. This is because Caffe2 uses a 483 declarative way of defining operators and models, which is essentially 484 not meant to run things in an interactive way. Read the following carefully 485 to make sure that you understand the caveats. 487 (1) You need to make sure that the sequences of operators you create are 488 actually runnable sequentially. For example, if you create an op that takes 489 an input X, somewhere earlier you should have already created X. 491 (2) Caffe2 immediate uses one single workspace, so if the set of operators 492 you run are intended to be under different workspaces, they will not run. 493 To create boundaries between such use cases, you can call FinishImmediate() 494 and StartImmediate() manually to flush out everything no longer needed. 496 (3) Underlying objects held by the immediate mode may interfere with your 497 normal run. For example, if there is a leveldb that you opened in immediate 498 mode and did not close, your main run will fail because leveldb does not 499 support double opening. Immediate mode may also occupy a lot of memory esp. 500 on GPUs. Call FinishImmediate() as soon as possible when you no longer 503 (4) Immediate is designed to be slow. Every immediate call implicitly 504 creates a temp operator object, runs it, and destroys the operator. This 505 slow-speed run is by design to discourage abuse. For most use cases other 506 than debugging, do NOT turn on immediate mode. 508 (5) If there is anything FATAL happening in the underlying C++ code, the 509 immediate mode will immediately (pun intended) cause the runtime to crash. 511 Thus you should use immediate mode with extra care. If you still would 512 like to, have fun [https://xkcd.com/149/]. 517 """Stops an immediate mode run.""" 519 global _immediate_mode
520 global _immediate_root_folder
# No-op when immediate mode is not active.
521 if not IsImmediate():
# Tear down the immediate workspace and its temp folder.
523 with WorkspaceGuard(_immediate_workspace_name):
525 shutil.rmtree(_immediate_root_folder)
526 _immediate_root_folder =
# NOTE(review): ImmediateBlobs below is incomplete — its body (presumably
# `return Blobs()`) is elided from this chunk.
'' 527 _immediate_mode =
False 530 def ImmediateBlobs():
531 with WorkspaceGuard(_immediate_workspace_name):
535 def RunOperatorImmediate(op):
# Runs a single operator inside the immediate-mode workspace.
# NOTE(review): the body that actually runs `op` inside the guard is elided.
536 with WorkspaceGuard(_immediate_workspace_name):
def FetchImmediate(*args, **kwargs):
    """Fetch a blob from the dedicated immediate-mode workspace."""
    with WorkspaceGuard(_immediate_workspace_name):
        return FetchBlob(*args, **kwargs)
def FeedImmediate(*args, **kwargs):
    """Feed a blob into the dedicated immediate-mode workspace."""
    with WorkspaceGuard(_immediate_workspace_name):
        return FeedBlob(*args, **kwargs)
552 def _Workspace_create_net_with_exception_intercept(ws, net, overwrite=False):
# Wraps Workspace._create_net with traceback interception.
# NOTE(review): the _create_net callable and the net-name argument passed to
# CallWithExceptionIntercept appear elided here.
553 return CallWithExceptionIntercept(
555 ws._last_failed_op_net_position,
557 StringifyProto(net), overwrite,
# Monkey-patch onto the pybind Workspace class.
561 C.Workspace.create_net = _Workspace_create_net_with_exception_intercept
564 def _Workspace_run(ws, obj):
# Dispatches Workspace.run() on the protobuf type of `obj`: PlanDef, NetDef,
# or OperatorDef. NOTE(review): the Proto() unwrapping line and the start of
# the final ValueError (its raise/constructor) appear elided here.
565 if hasattr(obj,
'Proto'):
567 if isinstance(obj, caffe2_pb2.PlanDef):
568 return ws._run_plan(obj.SerializeToString())
569 if isinstance(obj, caffe2_pb2.NetDef):
# Nets get traceback interception so failures map back to creation sites.
570 return CallWithExceptionIntercept(
572 ws._last_failed_op_net_position,
574 obj.SerializeToString(),
577 if isinstance(obj, caffe2_pb2.OperatorDef):
578 return ws._run_operator(obj.SerializeToString())
580 "Don't know how to do Workspace.run() on {}".format(type(obj)))
# Monkey-patch onto the pybind Workspace class.
583 C.Workspace.run = _Workspace_run
586 def _Blob_feed(blob, arg, device_option=None):
587 if device_option
is not None:
588 device_option = StringifyProto(device_option)
589 return blob._feed(arg, device_option)
# Install _Blob_feed as the Python-side feed() method on the pybind Blob class.
592 C.Blob.feed = _Blob_feed