Caffe2 - Python API
A deep learning, cross platform ML framework
workspace.py
1 ## @package workspace
2 # Module caffe2.python.workspace
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import contextlib
import logging
import os
import shutil
import socket
import tempfile
from collections import defaultdict
from multiprocessing import Process

import numpy as np
from google.protobuf.message import Message
from past.builtins import basestring

from caffe2.proto import caffe2_pb2
from caffe2.python import scope, utils
import caffe2.python._import_c_extension as C
21 
23 
logger = logging.getLogger(__name__)

# Re-export the C extension's workspace primitives under the public
# CamelCase names used throughout the caffe2 python API.
Blobs = C.blobs
CreateBlob = C.create_blob
CurrentWorkspace = C.current_workspace
DeserializeBlob = C.deserialize_blob
GlobalInit = C.global_init
HasBlob = C.has_blob
RegisteredOperators = C.registered_operators
SerializeBlob = C.serialize_blob
SwitchWorkspace = C.switch_workspace
RootFolder = C.root_folder
Workspaces = C.workspaces
BenchmarkNet = C.benchmark_net
GetStats = C.get_stats

# Maps net name -> {op id -> python traceback frames recorded at op creation};
# consulted by CallWithExceptionIntercept to show where a failing op came from.
operator_tracebacks = defaultdict(dict)
41 
is_asan = C.is_asan
has_gpu_support = C.has_gpu_support
if has_gpu_support:
    NumCudaDevices = C.num_cuda_devices
    GetCUDAVersion = C.get_cuda_version
    GetCuDNNVersion = C.get_cudnn_version

    def GetCudaPeerAccessPattern():
        # Returns the peer-access matrix as a numpy array for convenience.
        return np.asarray(C.get_cuda_peer_access_pattern())

    GetDeviceProperties = C.get_device_properties
else:
    # CPU-only build: provide stubs matching the GPU bindings' signatures.
    # Bug fix: the original assigned GetCuDNNVersion twice and never defined
    # GetCUDAVersion, so calling GetCUDAVersion raised NameError here.
    NumCudaDevices = lambda: 0 # noqa
    GetCUDAVersion = lambda: 0 # noqa
    GetCuDNNVersion = lambda: 0 # noqa
    GetCudaPeerAccessPattern = lambda: np.array([]) # noqa
    GetDeviceProperties = lambda x: None # noqa

IsNUMAEnabled = C.is_numa_enabled
GetNumNUMANodes = C.get_num_numa_nodes
GetBlobNUMANode = C.get_blob_numa_node
63 
64 def _GetFreeFlaskPort():
65  """Get a free flask port."""
66  # We will prefer to use 5000. If not, we will then pick a random port.
67  sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
68  result = sock.connect_ex(('127.0.0.1', 5000))
69  if result == 0:
70  return 5000
71  else:
72  s = socket.socket()
73  s.bind(('', 0))
74  port = s.getsockname()[1]
75  s.close()
76  # Race condition: between the interval we close the socket and actually
77  # start a mint process, another process might have occupied the port. We
78  # don't do much here as this is mostly for convenience in research
79  # rather than 24x7 service.
80  return port
81 
82 
def StartMint(root_folder=None, port=None):
    """Launch a mint web UI server in a child process.

    Args:
        root_folder: folder for mint to serve; defaults to the current
            workspace's root folder.
        port: TCP port to listen on; defaults to a free port.
    Returns:
        The started multiprocessing.Process handle.

    TODO(Yangqing): this does not work well under ipython yet. According to
    https://github.com/ipython/ipython/issues/5862
    writing up some fix is a todo item.
    """
    from caffe2.python.mint import app
    root_folder = C.root_folder() if root_folder is None else root_folder
    port = _GetFreeFlaskPort() if port is None else port
    proc = Process(
        target=app.main,
        args=(['-p', str(port), '-r', root_folder],),
    )
    proc.start()
    print('Mint running at http://{}:{}'.format(socket.getfqdn(), port))
    return proc
105 
106 
def StringifyProto(obj):
    """Stringify a protocol buffer object.

    Inputs:
      obj: a protocol buffer object, or a Pycaffe2 object that has a Proto()
      function.
    Outputs:
      string: the output protobuf string.
    Raises:
      ValueError: if obj is neither a string, a protobuf Message, nor an
      object exposing a Proto() method. (The original docstring incorrectly
      claimed AttributeError, which this function never raises.)
    """
    if isinstance(obj, basestring):
        # Already serialized (or a plain name): pass through unchanged.
        return obj
    if isinstance(obj, Message):
        # A real protobuf message: serialize directly.
        return obj.SerializeToString()
    if hasattr(obj, 'Proto'):
        # Caffe2 wrapper objects (e.g. core.Net) expose a Proto() accessor.
        return obj.Proto().SerializeToString()
    raise ValueError("Unexpected argument to StringifyProto of type " +
                     type(obj).__name__)
130 
131 
def ResetWorkspace(root_folder=None):
    """Clear the active workspace, optionally re-rooting it.

    With root_folder=None the workspace keeps its current root folder;
    otherwise the folder is created if missing and used as the new root.
    """
    if root_folder is None:
        # Keep the current root folder while clearing blob contents.
        return C.reset_workspace(C.root_folder())
    if not os.path.exists(root_folder):
        os.makedirs(root_folder)
    return C.reset_workspace(root_folder)
140 
141 
def CreateNet(net, overwrite=False, input_blobs=None):
    """Create `net` in the current workspace.

    Each name in input_blobs is pre-created so the net's external inputs
    exist before construction. Failures are routed through
    CallWithExceptionIntercept so the failing op's original python
    traceback is reported.
    """
    blobs_to_create = [] if input_blobs is None else input_blobs
    for blob_name in blobs_to_create:
        C.create_blob(blob_name)
    return CallWithExceptionIntercept(
        C.create_net,
        C.Workspace.current._last_failed_op_net_position,
        GetNetName(net),
        StringifyProto(net), overwrite,
    )
153 
154 
def Predictor(init_net, predict_net):
    """Build a C predictor from serialized init and predict nets."""
    init_str = StringifyProto(init_net)
    predict_str = StringifyProto(predict_net)
    return C.Predictor(init_str, predict_str)
157 
158 
def GetOperatorCost(operator, blobs):
    """Estimate the cost of running `operator` on the given blobs."""
    op_str = StringifyProto(operator)
    return C.get_operator_cost(op_str, blobs)
161 
162 
def RunOperatorOnce(operator):
    """Run a single operator in the current workspace; returns success."""
    op_str = StringifyProto(operator)
    return C.run_operator_once(op_str)
165 
166 
def RunOperatorsOnce(operators):
    """Run operators in order; stop and return False on the first failure.

    Returns True when every operator (or an empty sequence) succeeds.
    """
    # all() short-circuits exactly like the original explicit loop did.
    return all(RunOperatorOnce(op) for op in operators)
173 
174 
def CallWithExceptionIntercept(func, op_id_fetcher, net_name, *args, **kwargs):
    """Invoke func(*args, **kwargs), annotating failures with op context.

    On any exception, prints the python-side traceback recorded for the
    failing operator of `net_name` (if one was captured in
    operator_tracebacks) and re-raises the original exception.
    """
    try:
        return func(*args, **kwargs)
    except Exception:
        op_id = op_id_fetcher()
        print('Original python traceback for operator {} in network `{}` in '
              'exception above (most recent call last):'.format(
                  op_id, net_name))
        recorded = operator_tracebacks.get(net_name, None)
        if recorded and op_id in recorded:
            for frame in reversed(recorded[op_id]):
                print('  File "{}", line {}, in {}'.format(
                    frame[0], frame[1], frame[2]))
        raise
190 
191 
def RunNetOnce(net):
    # Serializes `net` and executes it once in the current workspace.
    # Wrapped in CallWithExceptionIntercept so that a failing op's original
    # python creation traceback is printed before the exception propagates.
    return CallWithExceptionIntercept(
        C.run_net_once,
        C.Workspace.current._last_failed_op_net_position,
        GetNetName(net),
        StringifyProto(net),
    )
199 
200 
def RunNet(name, num_iter=1, allow_fail=False):
    """Run a previously created net.

    Inputs:
      name: the name of the net, or a reference to the net.
      num_iter: number of iterations to run
      allow_fail: if True, does not assert on net exec failure but returns False
    Returns:
      True or an exception.
    """
    op_id_fetcher = C.Workspace.current._last_failed_op_net_position
    net_name = GetNetName(name)
    return CallWithExceptionIntercept(
        C.run_net,
        op_id_fetcher,
        net_name,
        StringifyNetName(name), num_iter, allow_fail,
    )
217 
218 
def RunPlan(plan_or_step):
    """Execute a Plan; a bare ExecutionStep is first wrapped into a Plan."""
    # TODO(jiayq): refactor core.py/workspace.py to avoid circular deps
    import caffe2.python.core as core
    if isinstance(plan_or_step, core.ExecutionStep):
        plan_or_step = core.Plan(plan_or_step)
    serialized = StringifyProto(plan_or_step)
    return C.run_plan(serialized)
225 
226 
def InferShapesAndTypes(nets, blob_dimensions=None):
    """Infers the shapes and types for the specified nets.

    Inputs:
      nets: the list of nets
      blob_dimensions (optional): a dictionary of blobs and their dimensions.
          If not specified, the workspace blobs are used.
    Returns:
      A tuple of (shapes, types) dictionaries keyed by blob name.
    """
    net_protos = [StringifyProto(n.Proto()) for n in nets]
    if blob_dimensions is None:
        blobdesc_prototxt = C.infer_shapes_and_types_from_workspace(net_protos)
    else:
        blobdesc_prototxt = C.infer_shapes_and_types_from_map(
            net_protos, blob_dimensions
        )
    tensor_shapes = caffe2_pb2.TensorShapes()
    tensor_shapes.ParseFromString(blobdesc_prototxt)
    # Blobs whose shape could not be inferred are omitted from the result.
    known = [ts for ts in tensor_shapes.shapes if not ts.unknown_shape]
    shapes = {ts.name: list(ts.dims) for ts in known}
    types = {ts.name: ts.data_type for ts in known}
    return (shapes, types)
254 
255 
def _StringifyName(name, expected_type):
    """Coerce `name` to str; accepts a string or an object whose type name
    matches expected_type (checked by name to avoid importing core)."""
    if not isinstance(name, basestring):
        assert type(name).__name__ == expected_type, \
            "Expected a string or %s" % expected_type
        return str(name)
    return name
262 
263 
def StringifyBlobName(name):
    # Accepts a plain string or a BlobReference; returns the blob name string.
    return _StringifyName(name, "BlobReference")
266 
267 
def StringifyNetName(name):
    # Accepts a plain string or a core.Net; returns the net name string.
    return _StringifyName(name, "Net")
270 
271 
def GetNetName(net):
    # Resolve a net's name from a string, a core.Net, or a NetDef proto.
    if isinstance(net, basestring):
        return net
    if type(net).__name__ == "Net":
        # Matched by type name, presumably to avoid a circular import of
        # core (see the TODO in RunPlan) — verify before changing.
        return net.Name()
    if isinstance(net, caffe2_pb2.NetDef):
        return net.name
    raise Exception("Not a Net object: {}".format(str(net)))
280 
281 
def FeedBlob(name, arr, device_option=None):
    """Feeds a blob into the workspace.

    Inputs:
      name: the name of the blob.
      arr: either a TensorProto object or a numpy array object to be fed into
          the workspace.
      device_option (optional): the device option to feed the data with.
    Returns:
      True or False, stating whether the feed is successful.
    """
    if type(arr) is caffe2_pb2.TensorProto:
        arr = utils.Caffe2TensorToNumpyArray(arr)
    if type(arr) is np.ndarray and arr.dtype.kind in 'SU':
        # Plain NumPy strings are weird, let's use objects instead.
        # Fix: np.object was a deprecated alias removed in NumPy 1.24; the
        # builtin `object` is exactly equivalent.
        arr = arr.astype(object)

    if device_option is None:
        # Fall back to the ambient device scope, if any.
        device_option = scope.CurrentDeviceScope()

    if device_option and device_option.device_type == caffe2_pb2.CUDA:
        if arr.dtype == np.dtype('float64'):
            logger.warning(
                "CUDA operators do not support 64-bit doubles, " +
                "please use arr.astype(np.float32) or np.int32 for ints." +
                " Blob: {}".format(name) +
                " type: {}".format(str(arr.dtype))
            )

    name = StringifyBlobName(name)
    if device_option is not None:
        return C.feed_blob(name, arr, StringifyProto(device_option))
    else:
        return C.feed_blob(name, arr)
316 
317 
def FetchBlobs(names):
    """Fetch several blobs from the workspace.

    Inputs:
      names: list of names of blobs - strings or BlobReferences
    Returns:
      list of fetched blobs, in the same order as `names`
    """
    return list(map(FetchBlob, names))
327 
328 
def FetchBlob(name):
    """Fetch a single blob from the workspace.

    Inputs:
      name: the name of the blob - a string or a BlobReference
    Returns:
      Fetched blob (numpy array or string) if successful
    """
    blob_name = StringifyBlobName(name)
    return C.fetch_blob(blob_name)
338 
339 
def ApplyTransform(transform_key, net):
    """Apply a registered Transform to a NetDef and return the result.

    Inputs:
      transform_key: the name of the transform, as it is stored in the registry
      net: a NetDef protobuf object
    Returns:
      Transformed NetDef protobuf object.
    """
    key_bytes = str(transform_key).encode('utf-8')
    transformed_str = C.apply_transform(key_bytes, net.SerializeToString())
    result = caffe2_pb2.NetDef()
    result.ParseFromString(transformed_str)
    return result
357 
358 
def ApplyTransformIfFaster(transform_key, net, init_net, **kwargs):
    """Apply a Transform to a NetDef protobuf object, and returns the new
    transformed NetDef, only if it runs faster than the original.

    The runs are performed on the current active workspace (gWorkspace).
    You should initialize that workspace before making a call to this function.

    Inputs:
      transform_key: the name of the transform, as it is stored in the registry
      net: a NetDef protobuf object
      init_net: The net to initialize the workspace.
      warmup_runs (optional):
          Determines how many times the net is run before testing.
          Will be 5 by default.
      main_runs (optional):
          Determines how many times the net is run during testing.
          Will be 10 by default.
      improvement_threshold (optional):
          Determines the factor which the new net needs to be faster
          in order to replace the old. Will be 1.01 by default.

    Returns:
      Either a Transformed NetDef protobuf object, or the original netdef.
    """
    # Idiom fix: dict.get with a default replaces the membership-test-plus-
    # index pattern; behavior is identical.
    warmup_runs = kwargs.get('warmup_runs', 5)
    main_runs = kwargs.get('main_runs', 10)
    improvement_threshold = kwargs.get('improvement_threshold', 1.01)

    transformed_net = caffe2_pb2.NetDef()
    transformed_str = C.apply_transform_if_faster(
        str(transform_key).encode('utf-8'),
        net.SerializeToString(),
        init_net.SerializeToString(),
        warmup_runs,
        main_runs,
        float(improvement_threshold),
    )
    transformed_net.ParseFromString(transformed_str)
    return transformed_net
400 
401 
def GetNameScope():
    """Return the current namescope string. To be used to fetch blobs."""
    return scope.CurrentNameScope()
405 
406 
class _BlobDict(object):
    """Provides python dict compatible way to do fetching and feeding."""

    def __getitem__(self, key):
        # d[name] fetches the blob's current value from the workspace.
        return FetchBlob(key)

    def __setitem__(self, key, value):
        # d[name] = value feeds `value` into the workspace blob.
        return FeedBlob(key, value)

    def __len__(self):
        return len(C.blobs())

    def __iter__(self):
        # Iterates over blob names, mirroring dict iteration over keys.
        return C.blobs().__iter__()

    def __contains__(self, item):
        return C.has_blob(item)
424 
425 
426 blobs = _BlobDict()
427 
428 
429 ################################################################################
430 # Utilities for immediate mode
431 #
432 # Caffe2's immediate mode implements the following behavior: between the two
433 # function calls StartImmediate() and StopImmediate(), for any operator that is
434 # called through CreateOperator(), we will also run that operator in a workspace
435 # that is specific to the immediate mode. The user is explicitly expected to
436 # make sure that these ops have proper inputs and outputs, i.e. one should not
437 # run an op where an external input is not created or fed.
438 #
439 # Users can use FeedImmediate() and FetchImmediate() to interact with blobs
440 # in the immediate workspace.
441 #
442 # Once StopImmediate() is called, all contents in the immediate workspace is
443 # freed up so one can continue using normal runs.
444 #
445 # The immediate mode is solely for debugging purposes and support will be very
446 # sparse.
447 ################################################################################
448 
# Immediate-mode module state: the on/off flag, the name of the dedicated
# workspace, and the temp folder backing it (set by StartImmediate and
# cleared by StopImmediate).
_immediate_mode = False
_immediate_workspace_name = "_CAFFE2_IMMEDIATE"
_immediate_root_folder = ''
452 
453 
def IsImmediate():
    """Return True while immediate mode is active."""
    return _immediate_mode
456 
457 
@contextlib.contextmanager
def WorkspaceGuard(workspace_name):
    """Context manager that temporarily switches to `workspace_name`.

    Fix: the switch-back is now in a finally block, so the previous
    workspace is restored even when the guarded body raises (the original
    left the process stuck in `workspace_name` on exceptions).
    """
    current = CurrentWorkspace()
    SwitchWorkspace(workspace_name, True)
    try:
        yield
    finally:
        SwitchWorkspace(current)
464 
465 
def StartImmediate(i_know=False):
    """Enter immediate mode, backed by a fresh temp-dir workspace.

    Pass i_know=True to suppress the caveat message printed below.
    """
    global _immediate_mode
    global _immediate_root_folder
    if IsImmediate():
        # already in immediate mode. We will kill the previous one
        # and start from fresh.
        StopImmediate()
    _immediate_mode = True
    with WorkspaceGuard(_immediate_workspace_name):
        _immediate_root_folder = tempfile.mkdtemp()
        ResetWorkspace(_immediate_root_folder)
    if i_know:
        # if the user doesn't want to see the warning message, sure...
        return
    print("""
    Enabling immediate mode in caffe2 python is an EXTREMELY EXPERIMENTAL
    feature and may very easily go wrong. This is because Caffe2 uses a
    declarative way of defining operators and models, which is essentially
    not meant to run things in an interactive way. Read the following carefully
    to make sure that you understand the caveats.

    (1) You need to make sure that the sequences of operators you create are
    actually runnable sequentially. For example, if you create an op that takes
    an input X, somewhere earlier you should have already created X.

    (2) Caffe2 immediate uses one single workspace, so if the set of operators
    you run are intended to be under different workspaces, they will not run.
    To create boundaries between such use cases, you can call FinishImmediate()
    and StartImmediate() manually to flush out everything no longer needed.

    (3) Underlying objects held by the immediate mode may interfere with your
    normal run. For example, if there is a leveldb that you opened in immediate
    mode and did not close, your main run will fail because leveldb does not
    support double opening. Immediate mode may also occupy a lot of memory esp.
    on GPUs. Call FinishImmediate() as soon as possible when you no longer
    need it.

    (4) Immediate is designed to be slow. Every immediate call implicitly
    creates a temp operator object, runs it, and destroys the operator. This
    slow-speed run is by design to discourage abuse. For most use cases other
    than debugging, do NOT turn on immediate mode.

    (5) If there is anything FATAL happening in the underlying C++ code, the
    immediate mode will immediately (pun intended) cause the runtime to crash.

    Thus you should use immediate mode with extra care. If you still would
    like to, have fun [https://xkcd.com/149/].
    """)
514 
515 
def StopImmediate():
    """Leave immediate mode and dispose of its scratch workspace/folder."""
    global _immediate_mode
    global _immediate_root_folder
    # No-op when immediate mode was never started.
    if not IsImmediate():
        return
    with WorkspaceGuard(_immediate_workspace_name):
        ResetWorkspace()
    # Remove the temp dir created by StartImmediate, then clear the state.
    shutil.rmtree(_immediate_root_folder)
    _immediate_root_folder = ''
    _immediate_mode = False
528 
529 
def ImmediateBlobs():
    # Blob listing, executed inside the immediate workspace.
    with WorkspaceGuard(_immediate_workspace_name):
        return Blobs()
533 
534 
def RunOperatorImmediate(op):
    # Run a single operator inside the immediate workspace.
    with WorkspaceGuard(_immediate_workspace_name):
        RunOperatorOnce(op)
538 
539 
def FetchImmediate(*args, **kwargs):
    # FetchBlob, executed inside the immediate workspace.
    with WorkspaceGuard(_immediate_workspace_name):
        return FetchBlob(*args, **kwargs)
543 
544 
def FeedImmediate(*args, **kwargs):
    # FeedBlob, executed inside the immediate workspace.
    with WorkspaceGuard(_immediate_workspace_name):
        return FeedBlob(*args, **kwargs)
548 
549 
550 # CWorkspace utilities
551 
def _Workspace_create_net_with_exception_intercept(ws, net, overwrite=False):
    # Method flavor of CreateNet for C.Workspace instances: creates the net
    # inside `ws` and reports the failing op's original python traceback.
    return CallWithExceptionIntercept(
        ws._create_net,
        ws._last_failed_op_net_position,
        GetNetName(net),
        StringifyProto(net), overwrite,
    )


# Install as a method on the C extension's Workspace class.
C.Workspace.create_net = _Workspace_create_net_with_exception_intercept
562 
563 
def _Workspace_run(ws, obj):
    # Polymorphic runner installed as C.Workspace.run: accepts a PlanDef,
    # NetDef, or OperatorDef (or any object exposing Proto()) and dispatches
    # to the matching C++ entry point.
    if hasattr(obj, 'Proto'):
        obj = obj.Proto()
    if isinstance(obj, caffe2_pb2.PlanDef):
        return ws._run_plan(obj.SerializeToString())
    if isinstance(obj, caffe2_pb2.NetDef):
        # Nets go through the intercept wrapper so op tracebacks are shown.
        return CallWithExceptionIntercept(
            ws._run_net,
            ws._last_failed_op_net_position,
            GetNetName(obj),
            obj.SerializeToString(),
        )
    if isinstance(obj, caffe2_pb2.OperatorDef):
        return ws._run_operator(obj.SerializeToString())
    raise ValueError(
        "Don't know how to do Workspace.run() on {}".format(type(obj)))


C.Workspace.run = _Workspace_run
584 
585 
def _Blob_feed(blob, arg, device_option=None):
    # Feed `arg` into an existing C.Blob; device_option, if given, is
    # serialized (proto or wrapper) before the underlying call.
    if device_option is not None:
        device_option = StringifyProto(device_option)
    return blob._feed(arg, device_option)


# Install as a method on the C extension's Blob class.
C.Blob.feed = _Blob_feed