3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
13 @context.define_context(allow_default=
True)
# TagContext: scope driven way to provide tags to the layers.
def __init__(self, tags=None):
    """Create a tag context.

    Args:
        tags: optional iterable of initial tags. It is copied into a new
            list, so later add_tags/remove_tags calls never mutate the
            caller's sequence. ``None`` or an empty sequence yields an
            empty list.
    """
    # Kept as a list so insertion order is preserved: remove_tags checks
    # that the tags being removed are the most recently added ones (the
    # tail of this list). Copying avoids aliasing the caller's list, which
    # add_tags would otherwise extend in place.
    self.tags = list(tags) if tags else []
def add_tags(self, tags):
    """Append ``tags``, in order, to the end of this context's tag list.

    The list is extended in place, so the newest tags always sit at the
    tail; remove_tags asserts that the tags it removes are exactly that
    tail.
    """
    existing = self.tags
    existing.extend(tags)
26 def remove_tags(self, tags):
27 assert self.
tags[-len(tags):] == tags
# Tag name constants (referenced elsewhere in this file as ``Tags.<NAME>``).
# The EXCLUDE_FROM_* tags presumably keep a layer out of the named phase.
EXCLUDE_FROM_TRAIN = 'exclude_from_train'
EXCLUDE_FROM_EVAL = 'exclude_from_eval'
EXCLUDE_FROM_PREDICTION = 'exclude_from_prediction'
EXCLUDE_FROM_ACCUMULATE_PRED = 'exclude_from_accumulate_pred'
PREPROCESSING = 'preprocessing'
HANDLE_AS_SPARSE_LAYER = 'handle_as_sparse_layer'
GRADIENT_FROM_PS = 'gradient_from_ps'
PREFER_GPU = 'prefer_gpu'

# Indicates a layer contains a sparse shardable parameter. The parameter
# should be sharded and operators on those parameters should be done on
# distributed parameter servers.
SPARSE_SHARDED = 'sparse_sharded'

# Indicates a layer contains sparse parameters among others, and that the
# parameters should not be sharded (i.e. should be placed together on a
# node).
SPARSE_DONT_SHARD = 'sparse_dont_shard'

# Used to manually indicate a component for an operator. Parameters for
# all operators with the same component should be colocated on the same
# NOTE(review): the end of this sentence is missing from the source
# (presumably "node"/"parameter server") — confirm.
# This value is a prefix (see DT_TAGS below); presumably the component
# name is appended after the colon.
COMPONENT = 'component:'

# Valid tag prefixes for the distributed training framework.
DT_TAGS = (SPARSE_SHARDED, SPARSE_DONT_SHARD, COMPONENT)
72 PREDICTION_SCHEMA =
'prediction_schema' 74 def __init__(self, tags):
75 if not isinstance(tags, list):
80 TagContext.current().add_tags(self.
tags)
def __exit__(self, type, value, traceback):
    """Leave the tag scope by removing this object's tags from the
    current TagContext.

    Nothing is returned (i.e. None), so an exception raised inside the
    ``with`` body is not suppressed.
    """
    active_context = TagContext.current()
    active_context.remove_tags(self.tags)
86 def __call__(self, func):
88 def wrapper(*args, **kwargs):
90 return func(*args, **kwargs)
# Phase bundles: each list holds the EXCLUDE_FROM_* tags for every phase
# other than the one in its name (accumulate-pred is excluded in all three).
# Every bundle is its own list object.
Tags.TRAIN_ONLY = [
    Tags.EXCLUDE_FROM_PREDICTION,
    Tags.EXCLUDE_FROM_EVAL,
    Tags.EXCLUDE_FROM_ACCUMULATE_PRED,
]
Tags.EVAL_ONLY = [
    Tags.EXCLUDE_FROM_PREDICTION,
    Tags.EXCLUDE_FROM_TRAIN,
    Tags.EXCLUDE_FROM_ACCUMULATE_PRED,
]
Tags.PREDICTION_ONLY = [
    Tags.EXCLUDE_FROM_TRAIN,
    Tags.EXCLUDE_FROM_EVAL,
    Tags.EXCLUDE_FROM_ACCUMULATE_PRED,
]