Caffe2 - Python API
A deep learning, cross-platform ML framework
scope.py
1 ## @package scope
2 # Module caffe2.python.scope
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 import contextlib
9 import threading
10 from past.builtins import basestring
11 
12 from caffe2.proto import caffe2_pb2
13 
14 
# The name scope and device scope when creating a new operator.

# Separator that NameScope() appends to every non-empty prefix, so nested
# scopes concatenate into strings such as 'outer/inner/'.
_NAMESCOPE_SEPARATOR = '/'

# Thread-local holder for the scopes. Its `namescope` and `devicescope`
# attributes are created on first use by CurrentNameScope() and
# CurrentDeviceScope() respectively.
_threadlocal_scope = threading.local()
19 
20 
def CurrentNameScope():
    """Return the name scope string of the calling thread.

    The thread-local `namescope` attribute is created as '' the first time
    it is requested on a thread.
    """
    state = _threadlocal_scope
    try:
        return state.namescope
    except AttributeError:
        # First access on this thread: start from the empty scope.
        state.namescope = ''
        return state.namescope
26 
27 
def CurrentDeviceScope():
    """Return the device scope of the calling thread.

    The thread-local `devicescope` attribute is created as None the first
    time it is requested on a thread.
    """
    state = _threadlocal_scope
    try:
        return state.devicescope
    except AttributeError:
        # First access on this thread: no device scope is active yet.
        state.devicescope = None
        return state.devicescope
33 
34 
@contextlib.contextmanager
def NameScope(prefix, reset=False):
    """Context manager that enters a name scope for the calling thread.

    Args:
        prefix: a string, or None. A non-empty prefix has
            _NAMESCOPE_SEPARATOR appended before it is applied; None or ''
            contributes nothing.
        reset: if true, the new name scope is the prefix alone; otherwise
            the prefix is appended to the name scope active on entry.

    The name scope that was active on entry is put back on exit.

    Raises:
        AssertionError: if `prefix` is neither a string nor None, or if on
            exit the current name scope no longer ends with the applied
            prefix.
    """
    assert prefix is None or isinstance(prefix, basestring), \
        "NameScope takes in a string as its argument."

    enclosing = CurrentNameScope()

    suffix = ''
    if prefix:
        suffix = prefix + _NAMESCOPE_SEPARATOR

    _threadlocal_scope.namescope = suffix if reset else enclosing + suffix

    try:
        yield
    finally:
        # Check that nobody replaced the scope behind our back before
        # handing the enclosing one back.
        assert _threadlocal_scope.namescope.endswith(suffix), \
            "The namescope variable is changed from outside NameScope() calls."
        _threadlocal_scope.namescope = enclosing
53 
54 
@contextlib.contextmanager
def DeviceScope(scope, node_name=None):
    """Context manager that enters a device scope for the calling thread.

    Args:
        scope: a caffe2_pb2.DeviceOption to copy, or a falsy value when
            only `node_name` is supplied.
        node_name: optional node name; when given it overrides the
            node_name of the copied option.

    A fresh DeviceOption is built for the scope. If it ends up without a
    node_name and the enclosing device scope has one, that one is
    inherited. The device scope active on entry is put back on exit.

    Raises:
        AssertionError: if `scope` is truthy but not a DeviceOption, if both
            arguments are falsy, or if on exit the current device scope no
            longer equals the one installed here.
    """
    device = caffe2_pb2.DeviceOption()
    if not scope:
        assert node_name, "At least one argument should be non-null in DeviceScope"
    else:
        assert isinstance(scope, caffe2_pb2.DeviceOption), \
            "DeviceScope takes in a caffe2_pb2.DeviceOption as its argument."
        device.CopyFrom(scope)

    # An explicitly given node_name wins over the copied one.
    if node_name:
        device.node_name = node_name

    enclosing = CurrentDeviceScope()
    # A nested scope without its own node_name takes the enclosing one's.
    if enclosing and enclosing.HasField('node_name'):
        if not device.HasField('node_name'):
            device.node_name = enclosing.node_name

    _threadlocal_scope.devicescope = device
    try:
        yield
    finally:
        assert _threadlocal_scope.devicescope == device, \
            "The device scope is changed from outside DeviceScope() calls."
        _threadlocal_scope.devicescope = enclosing
81 
82 
@contextlib.contextmanager
def EmptyDeviceScope():
    """
    Allow users to 'disable' the device scope behaviour (so it can be
    controlled at a NetDef::DeviceOption level, not overridden at
    OperatorDef::DeviceOption level).

    This sets the CurrentDeviceScope() to None, so that the field is
    not set in CreateOperator(...), etc.

    The device scope that was active on entry is restored on exit, and
    an exception raised inside the `with` body propagates to the caller.
    """
    old_scope = CurrentDeviceScope()
    _threadlocal_scope.devicescope = None
    try:
        yield
    finally:
        # Restore only. Deliberately no `return` here: returning from a
        # `finally` block discards an in-flight exception, which would make
        # this context manager silently suppress errors from its body.
        _threadlocal_scope.devicescope = old_scope