Caffe2 - Python API
A deep learning, cross platform ML framework
tags.py
1 ## @package tags
2 # Module caffe2.python.layers.tags
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 import six
9 
10 from caffe2.python import context
11 
12 
@context.define_context(allow_default=True)
class TagContext(object):
    """
    Scope driven way to provide tags to the layers.

    Keeps an ordered list of the currently active tags. Tags are appended
    with `add_tags` and have to be taken off again with `remove_tags` in
    last-in, first-out order.

    NOTE(review): `current()`, which `Tags` calls on this class, is not
    defined here; presumably it is supplied by `context.define_context` -
    confirm against caffe2.python.context.
    """

    def __init__(self, tags=None):
        """
        Args:
            tags: optional initial sequence of tags. It is copied, so later
                calls to `add_tags` never modify the caller's own list.
        """
        # Tags is expected to be list to keep order of adding/removing things.
        # Copying matters because add_tags() extends the list in place; without
        # the copy a shared list passed in by the caller would grow silently.
        self.tags = list(tags) if tags else []

    def add_tags(self, tags):
        """Append every tag in `tags` to the end of the active tag list."""
        self.tags.extend(tags)

    def remove_tags(self, tags):
        """
        Remove `tags` from the end of the active tag list.

        Args:
            tags: list of tags that must be equal to the current tail of the
                active tag list. An empty list is a no-op.

        Raises:
            AssertionError: if `tags` does not match the tail of the list.
        """
        count = len(tags)
        if count == 0:
            # -0 == 0, so `self.tags[-0:]` would be the whole list and
            # `self.tags[:-0]` would be empty: without this guard removing
            # "nothing" either trips the assert or drops every active tag.
            return
        assert self.tags[-count:] == tags, (
            'tags to remove do not match the tail of the active tag list')
        self.tags = self.tags[:-count]
29 
30 
class Tags(object):
    """
    Collection of well-known tag strings that can also be used as a context
    manager or as a function decorator.

    While the `with` block (or the decorated function) is running, the tags
    given to the constructor are added to `TagContext.current()`; they are
    removed again on exit.
    """
    # TODO(amalevich): Tags might need to live in their own contexts, add this
    # split later
    EXCLUDE_FROM_TRAIN = 'exclude_from_train'
    EXCLUDE_FROM_EVAL = 'exclude_from_eval'
    EXCLUDE_FROM_PREDICTION = 'exclude_from_prediction'
    EXCLUDE_FROM_ACCUMULATE_PRED = 'exclude_from_accumulate_pred'
    PREPROCESSING = 'preprocessing'
    HANDLE_AS_SPARSE_LAYER = 'handle_as_sparse_layer'
    GRADIENT_FROM_PS = 'gradient_from_ps'
    PREFER_GPU = 'prefer_gpu'
    CPU_ONLY = 'cpu_only'

    # The following three tags are hints to **distributed training framework**.

    # Indicates a layer contains a sparse shardable parameter. The parameter
    # should be sharded and operators on those parameters should be done on
    # distributed parameter servers.
    SPARSE_SHARDED = 'sparse_sharded'

    # Indicates a layer contains sparse parameters among others, and that the
    # parameters should not be sharded (i.e. should be placed together on a
    # node).
    SPARSE_DONT_SHARD = 'sparse_dont_shard'

    # Used to manually indicate a component for an operator. Parameters for
    # all operators with the same component should be colocated on the same
    # parameter server.
    COMPONENT = 'component:'

    # Valid tag prefixes for distributed training framework.
    DT_TAGS = (SPARSE_SHARDED, SPARSE_DONT_SHARD, COMPONENT)

    # In certain cases we want to have different schema for training and
    # prediction, as an example in prediction we might need to have only
    # subset of ids present in the original schema. This tag is one of the
    # ways to mark operators that will be removed from prediction and should
    # override schema for predictors.
    PREDICTION_SCHEMA = 'prediction_schema'

    def __init__(self, tags):
        """
        Args:
            tags: a list of tags, or a single tag. Anything that is not a
                list is wrapped into a one-element list; a list is stored
                as-is (not copied).
        """
        self.tags = tags if isinstance(tags, list) else [tags]

    def __enter__(self):
        """Add this object's tags to the current TagContext; return self."""
        active_context = TagContext.current()
        active_context.add_tags(self.tags)
        return self

    def __exit__(self, type, value, traceback):
        """Remove this object's tags from the current TagContext.

        The exception details are ignored and nothing is returned, so an
        exception raised inside the block keeps propagating.
        """
        active_context = TagContext.current()
        active_context.remove_tags(self.tags)

    def __call__(self, func):
        """Decorator form: return `func` wrapped so it runs under these tags."""
        @six.wraps(func)
        def tagged_func(*args, **kwargs):
            # Entering `self` adds the tags for the duration of the call only.
            with self:
                return func(*args, **kwargs)

        return tagged_func
92 
93 
# Convenience tag groups attached to Tags after the class is defined. Each
# *_ONLY list holds the EXCLUDE_FROM_* tags of every other phase, always
# including EXCLUDE_FROM_ACCUMULATE_PRED.
Tags.TRAIN_ONLY = [
    Tags.EXCLUDE_FROM_PREDICTION,
    Tags.EXCLUDE_FROM_EVAL,
    Tags.EXCLUDE_FROM_ACCUMULATE_PRED,
]
Tags.EVAL_ONLY = [
    Tags.EXCLUDE_FROM_PREDICTION,
    Tags.EXCLUDE_FROM_TRAIN,
    Tags.EXCLUDE_FROM_ACCUMULATE_PRED,
]
Tags.PREDICTION_ONLY = [
    Tags.EXCLUDE_FROM_TRAIN,
    Tags.EXCLUDE_FROM_EVAL,
    Tags.EXCLUDE_FROM_ACCUMULATE_PRED,
]