3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
14 Collect samples from input record w/ reservoir sampling. If you have complex 15 data, use PackRecords to pack it before using this layer. 17 This layer is not thread safe. 20 def __init__(self, model, input_record, num_to_collect,
21 name=
'reservoir_sampling', **kwargs):
22 super(ReservoirSampling, self).__init__(
23 model, name, input_record, **kwargs)
24 assert num_to_collect > 0
28 param_name=
'reservoir',
30 initializer=(
'ConstantFill',),
31 optimizer=model.NoOptim,
34 param_name=
'num_visited',
36 initializer=(
'ConstantFill', {
38 'dtype': core.DataType.INT64,
40 optimizer=model.NoOptim,
45 initializer=(
'CreateMutex',),
46 optimizer=model.NoOptim,
51 if 'object_id' in input_record:
53 param_name=
'object_to_pos',
54 initializer=(
'CreateMap', {
55 'key_dtype': core.DataType.INT64,
56 'valued_dtype': core.DataType.INT32,
58 optimizer=model.NoOptim,
61 param_name=
'pos_to_object',
63 initializer=(
'ConstantFill', {
65 'dtype': core.DataType.INT64,
67 optimizer=model.NoOptim,
69 self.extra_input_blobs.append(input_record.object_id())
70 self.extra_input_blobs.extend([object_to_pos, pos_to_object])
71 self.extra_output_blobs.extend([object_to_pos, pos_to_object])
76 schema.from_blob_list(input_record.data, [self.
reservoir])
82 def add_ops(self, net):
83 net.ReservoirSampling(
def create_param(self, param_name, shape, initializer, optimizer, ps_param=None, regularizer=None)