Caffe2 - C++ API
A deep learning, cross platform ML framework
reservoir_sampling.cc
1 #include <memory>
2 #include <string>
3 #include <vector>
4 #include "caffe2/core/operator.h"
5 #include "caffe2/core/tensor.h"
6 #include "caffe2/operators/map_ops.h"
7 
8 namespace caffe2 {
9 namespace {
10 
11 template <class Context>
12 class ReservoirSamplingOp final : public Operator<Context> {
13  public:
14  USE_OPERATOR_CONTEXT_FUNCTIONS;
15  ReservoirSamplingOp(const OperatorDef operator_def, Workspace* ws)
16  : Operator<Context>(operator_def, ws),
17  numToCollect_(
18  OperatorBase::GetSingleArgument<int>("num_to_collect", -1)) {
19  CAFFE_ENFORCE(numToCollect_ > 0);
20  }
21 
22  bool RunOnDevice() override {
23  auto& mutex = OperatorBase::Input<std::unique_ptr<std::mutex>>(MUTEX);
24  std::lock_guard<std::mutex> guard(*mutex);
25 
26  auto* output = Output(RESERVOIR);
27  const auto& input = Input(DATA);
28 
29  CAFFE_ENFORCE_GE(input.ndim(), 1);
30 
31  bool output_initialized = output->size() > 0 &&
32  (static_cast<std::shared_ptr<std::vector<TensorCPU>>*>(
33  output->raw_mutable_data(input.meta()))[0] != nullptr);
34 
35  if (output_initialized) {
36  CAFFE_ENFORCE_EQ(output->ndim(), input.ndim());
37  for (size_t i = 1; i < input.ndim(); ++i) {
38  CAFFE_ENFORCE_EQ(output->dim(i), input.dim(i));
39  }
40  }
41 
42  auto dims = input.dims();
43  auto num_entries = dims[0];
44 
45  dims[0] = numToCollect_;
46  // IMPORTANT: Force the output to have the right type before reserving,
47  // so that the output gets the right capacity
48  output->raw_mutable_data(input.meta());
49  output->Reserve(dims, &context_);
50 
51  auto* pos_to_object =
52  OutputSize() > POS_TO_OBJECT ? Output(POS_TO_OBJECT) : nullptr;
53  if (pos_to_object) {
54  pos_to_object->Reserve(std::vector<TIndex>{numToCollect_}, &context_);
55  }
56 
57  if (num_entries == 0) {
58  if (!output_initialized) {
59  // Get both shape and meta
60  output->CopyFrom(input, &context_);
61  }
62  return true;
63  }
64 
65  const int64_t* object_id_data = nullptr;
66  std::set<int64_t> unique_object_ids;
67  if (InputSize() > OBJECT_ID) {
68  const auto& object_id = Input(OBJECT_ID);
69  CAFFE_ENFORCE_EQ(object_id.ndim(), 1);
70  CAFFE_ENFORCE_EQ(object_id.size(), num_entries);
71  object_id_data = object_id.template data<int64_t>();
72  unique_object_ids.insert(
73  object_id_data, object_id_data + object_id.size());
74  }
75 
76  const auto num_new_entries = countNewEntries(unique_object_ids);
77  auto num_to_copy = std::min<int32_t>(num_new_entries, numToCollect_);
78  auto output_batch_size = output_initialized ? output->dim(0) : 0;
79  dims[0] = std::min<size_t>(numToCollect_, output_batch_size + num_to_copy);
80  if (output_batch_size < numToCollect_) {
81  output->Resize(dims);
82  if (pos_to_object) {
83  pos_to_object->Resize(dims[0]);
84  }
85  }
86  auto* output_data =
87  static_cast<char*>(output->raw_mutable_data(input.meta()));
88  auto* pos_to_object_data = pos_to_object
89  ? pos_to_object->template mutable_data<int64_t>()
90  : nullptr;
91 
92  auto block_size = input.size_from_dim(1);
93  auto block_bytesize = block_size * input.itemsize();
94  const auto* input_data = static_cast<const char*>(input.raw_data());
95 
96  auto* num_visited_tensor = Output(NUM_VISITED);
97  CAFFE_ENFORCE_EQ(1, num_visited_tensor->size());
98  auto* num_visited = num_visited_tensor->template mutable_data<int64_t>();
99  if (!output_initialized) {
100  *num_visited = 0;
101  }
102  CAFFE_ENFORCE_GE(*num_visited, 0);
103 
104  const auto start_num_visited = *num_visited;
105 
106  auto* object_to_pos_map = OutputSize() > OBJECT_TO_POS_MAP
107  ? OperatorBase::Output<MapType64To32>(OBJECT_TO_POS_MAP)
108  : nullptr;
109 
110  std::set<int64_t> eligible_object_ids;
111  if (object_to_pos_map) {
112  for (auto oid : unique_object_ids) {
113  if (!object_to_pos_map->count(oid)) {
114  eligible_object_ids.insert(oid);
115  }
116  }
117  }
118 
119  for (int i = 0; i < num_entries; ++i) {
120  if (object_id_data && object_to_pos_map &&
121  !eligible_object_ids.count(object_id_data[i])) {
122  // Already in the pool or processed
123  continue;
124  }
125  if (object_id_data) {
126  eligible_object_ids.erase(object_id_data[i]);
127  }
128  int64_t pos = -1;
129  if (*num_visited < numToCollect_) {
130  // append
131  pos = *num_visited;
132  } else {
133  auto& gen = context_.RandGenerator();
134  // uniform between [0, num_visited]
135  std::uniform_int_distribution<int64_t> uniformDist(0, *num_visited);
136  pos = uniformDist(gen);
137  if (pos >= numToCollect_) {
138  // discard
139  pos = -1;
140  }
141  }
142 
143  if (pos < 0) {
144  // discard
145  CAFFE_ENFORCE_GE(*num_visited, numToCollect_);
146  } else {
147  // replace
148  context_.template CopyItems<Context, Context>(
149  input.meta(),
150  block_size,
151  input_data + i * block_bytesize,
152  output_data + pos * block_bytesize);
153 
154  if (object_id_data && pos_to_object_data && object_to_pos_map) {
155  auto old_oid = pos_to_object_data[pos];
156  auto new_oid = object_id_data[i];
157  pos_to_object_data[pos] = new_oid;
158  object_to_pos_map->erase(old_oid);
159  object_to_pos_map->emplace(new_oid, pos);
160  }
161  }
162 
163  ++(*num_visited);
164  }
165  // Sanity check
166  CAFFE_ENFORCE_EQ(*num_visited, start_num_visited + num_new_entries);
167  return true;
168  }
169 
170  private:
171  // number of tensors to collect
172  int numToCollect_;
173 
174  INPUT_TAGS(
175  RESERVOIR_IN,
176  NUM_VISITED_IN,
177  DATA,
178  MUTEX,
179  OBJECT_ID,
180  OBJECT_TO_POS_MAP_IN,
181  POS_TO_OBJECT_IN);
182  OUTPUT_TAGS(RESERVOIR, NUM_VISITED, OBJECT_TO_POS_MAP, POS_TO_OBJECT);
183 
184  int32_t countNewEntries(const std::set<int64_t>& unique_object_ids) {
185  const auto& input = Input(DATA);
186  if (InputSize() <= OBJECT_ID) {
187  return input.dim(0);
188  }
189  const auto& object_to_pos_map =
190  OperatorBase::Input<MapType64To32>(OBJECT_TO_POS_MAP_IN);
191  return std::count_if(
192  unique_object_ids.begin(),
193  unique_object_ids.end(),
194  [&object_to_pos_map](int64_t oid) {
195  return !object_to_pos_map.count(oid);
196  });
197  }
198 };
199 
200 REGISTER_CPU_OPERATOR(ReservoirSampling, ReservoirSamplingOp<CPUContext>);
201 
202 OPERATOR_SCHEMA(ReservoirSampling)
203  .NumInputs({4, 7})
204  .NumOutputs({2, 4})
205  .NumInputsOutputs([](int in, int out) { return in / 3 == out / 2; })
206  .EnforceInplace({{0, 0}, {1, 1}, {5, 2}, {6, 3}})
207  .SetDoc(R"DOC(
208 Collect `DATA` tensor into `RESERVOIR` of size `num_to_collect`. `DATA` is
209 assumed to be a batch.
210 
211 In case where 'objects' may be repeated in data and you only want at most one
212 instance of each 'object' in the reservoir, `OBJECT_ID` can be given for
213 deduplication. If `OBJECT_ID` is given, then you also need to supply additional
214 book-keeping tensors. See input blob documentation for details.
215 
216 This operator is thread-safe.
217 )DOC")
218  .Arg(
219  "num_to_collect",
220  "The number of random samples to append for each positive samples")
221  .Input(
222  0,
223  "RESERVOIR",
224  "The reservoir; should be initialized to empty tensor")
225  .Input(
226  1,
227  "NUM_VISITED",
228  "Number of examples seen so far; should be initialized to 0")
229  .Input(
230  2,
231  "DATA",
232  "Tensor to collect from. The first dimension is assumed to be batch "
233  "size. If the object to be collected is represented by multiple "
234  "tensors, use `PackRecords` to pack them into single tensor.")
235  .Input(3, "MUTEX", "Mutex to prevent data race")
236  .Input(
237  4,
238  "OBJECT_ID",
239  "(Optional, int64) If provided, used for deduplicating object in the "
240  "reservoir")
241  .Input(
242  5,
243  "OBJECT_TO_POS_MAP_IN",
244  "(Optional) Auxillary bookkeeping map. This should be created from "
245  " `CreateMap` with keys of type int64 and values of type int32")
246  .Input(
247  6,
248  "POS_TO_OBJECT_IN",
249  "(Optional) Tensor of type int64 used for bookkeeping in deduplication")
250  .Output(0, "RESERVOIR", "Same as the input")
251  .Output(1, "NUM_VISITED", "Same as the input")
252  .Output(2, "OBJECT_TO_POS_MAP", "(Optional) Same as the input")
253  .Output(3, "POS_TO_OBJECT", "(Optional) Same as the input");
254 
255 SHOULD_NOT_DO_GRADIENT(ReservoirSampling);
256 } // namespace
257 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...