Caffe2 - Python API
A deep learning, cross-platform ML framework
reservoir_sampling.py
1 ## @package reservoir_sampling
2 # Module caffe2.python.layers.reservoir_sampling
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from caffe2.python import core, schema
9 from caffe2.python.layers.layers import ModelLayer
10 
11 
class ReservoirSampling(ModelLayer):
    """Collect samples from an input record with reservoir sampling.

    Samples of `input_record.data` are kept in a `reservoir` parameter blob,
    alongside an int64 `num_visited` blob (initialized to 0) and a mutex blob,
    all of which are handed to the ReservoirSampling operator. If you have
    complex data, use PackRecords to pack it before using this layer.

    If `input_record` has an `object_id` field, the object ids plus two extra
    state blobs (`object_to_pos`, `pos_to_object`) are also threaded through
    the op. NOTE(review): presumably these let the op avoid holding the same
    object twice -- confirm against the ReservoirSampling operator.

    This layer is not thread safe.
    """

    def __init__(self, model, input_record, num_to_collect,
                 name='reservoir_sampling', **kwargs):
        """Create the sampling state blobs and the layer's output schema.

        Args:
            model: model owning the layer; every state blob is registered
                with `model.NoOptim` as its optimizer.
            input_record: schema record with a `data` field to sample from
                and an optional `object_id` field.
            num_to_collect: number of samples to collect (the reservoir
                size); must be positive. Forwarded to the op unchanged.
            name: layer name.
            **kwargs: forwarded to `ModelLayer.__init__`.

        Raises:
            AssertionError: if `num_to_collect` is not positive.
        """
        super(ReservoirSampling, self).__init__(
            model, name, input_record, **kwargs)
        assert num_to_collect > 0, "num_to_collect must be positive"
        self.num_to_collect = num_to_collect

        # Sample storage: starts empty (shape [0]); the op writes it back.
        self.reservoir = self.create_param(
            param_name='reservoir',
            shape=[0],
            initializer=('ConstantFill',),
            optimizer=model.NoOptim,
        )
        # Scalar int64 counter starting at 0; an input and output of the op.
        self.num_visited_blob = self.create_param(
            param_name='num_visited',
            shape=[],
            initializer=('ConstantFill', {
                'value': 0,
                'dtype': core.DataType.INT64,
            }),
            optimizer=model.NoOptim,
        )
        # CreateMutex is not a fill op, so no shape is supplied.
        # NOTE(review): a mutex is passed to the op although the class
        # docstring says the layer is not thread safe -- confirm which holds.
        self.mutex = self.create_param(
            param_name='mutex',
            shape=None,
            initializer=('CreateMutex',),
            optimizer=model.NoOptim,
        )

        # Optional extra state, only used when the record carries object ids.
        self.extra_input_blobs = []
        self.extra_output_blobs = []
        if 'object_id' in input_record:
            # create_param requires `shape`; CreateMap builds a map rather
            # than a filled tensor, so pass None just like the mutex above.
            object_to_pos = self.create_param(
                param_name='object_to_pos',
                shape=None,
                initializer=('CreateMap', {
                    'key_dtype': core.DataType.INT64,
                    # CreateMap reads `value_dtype` (not `valued_dtype`).
                    'value_dtype': core.DataType.INT32,
                }),
                optimizer=model.NoOptim,
            )
            pos_to_object = self.create_param(
                param_name='pos_to_object',
                shape=[0],
                initializer=('ConstantFill', {
                    'value': 0,
                    'dtype': core.DataType.INT64,
                }),
                optimizer=model.NoOptim,
            )
            # The object ids are read-only inputs; both maps are inputs
            # that the op also writes back as outputs.
            self.extra_input_blobs.append(input_record.object_id())
            self.extra_input_blobs.extend([object_to_pos, pos_to_object])
            self.extra_output_blobs.extend([object_to_pos, pos_to_object])

        # The reservoir field is built from `input_record.data`'s schema and
        # bound to the reservoir blob.
        self.output_schema = schema.Struct(
            (
                'reservoir',
                schema.from_blob_list(input_record.data, [self.reservoir])
            ),
            ('num_visited', schema.Scalar(blob=self.num_visited_blob)),
            ('mutex', schema.Scalar(blob=self.mutex)),
        )

    def add_ops(self, net):
        """Append the ReservoirSampling op to `net`.

        The reservoir, num_visited and (when present) object-id map blobs are
        listed as both inputs and outputs, so the op updates them in place.
        """
        net.ReservoirSampling(
            [self.reservoir, self.num_visited_blob, self.input_record.data(),
             self.mutex] + self.extra_input_blobs,
            [self.reservoir, self.num_visited_blob] + self.extra_output_blobs,
            num_to_collect=self.num_to_collect,
        )
def create_param(self, param_name, shape, initializer, optimizer, ps_param=None, regularizer=None)
Definition: layers.py:331