4 #include "caffe2/core/operator.h" 5 #include "caffe2/core/tensor.h" 6 #include "caffe2/operators/map_ops.h" 11 template <
class Context>
// Operator, templated on the device Context, that collects rows of a DATA
// batch into a RESERVOIR output capped at numToCollect_ rows (see
// RunOnDevice). Declared final, so it is not meant to be subclassed.
12 class ReservoirSamplingOp final :
public Operator<Context> {
14 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor: forwards the operator definition and workspace to the
// Operator<Context> base and reads the integer argument "num_to_collect"
// (default -1), presumably into numToCollect_ -- the member initializer name
// is not visible on these lines; TODO confirm.
// NOTE(review): operator_def is received by value (const OperatorDef); a
// const reference would avoid a copy -- confirm against the base-class
// convention before changing.
15 ReservoirSamplingOp(
const OperatorDef operator_def, Workspace* ws)
16 : Operator<Context>(operator_def, ws),
18 OperatorBase::GetSingleArgument<int>(
"num_to_collect", -1)) {
// Fail at construction time unless the capacity is strictly positive; with
// the -1 default this also rejects a missing argument if numToCollect_ is
// indeed initialized from it.
19 CAFFE_ENFORCE(numToCollect_ > 0);
// Merges the rows of the DATA input into the RESERVOIR output, keeping at most
// numToCollect_ rows, and updates the NUM_VISITED counter plus the optional
// object-id bookkeeping outputs. The MUTEX input is held for the whole call.
22 bool RunOnDevice()
override {
// Serialize concurrent runs that share the same mutex blob; the guard is
// released when it goes out of scope.
23 auto& mutex = OperatorBase::Input<std::unique_ptr<std::mutex>>(MUTEX);
24 std::lock_guard<std::mutex> guard(*mutex);
26 auto* output = Output(RESERVOIR);
27 const auto& input = Input(DATA);
// DATA needs at least one dimension; dims[0] is read as the entry count below.
29 CAFFE_ENFORCE_GE(input.ndim(), 1);
// The reservoir counts as initialized only when it is non-empty AND the first
// slot of its raw buffer, viewed as a shared_ptr<vector<TensorCPU>>, is
// non-null.
// NOTE(review): the cast assumes a specific element type for the buffer --
// confirm it is valid for every dtype this operator is used with.
31 bool output_initialized = output->size() > 0 &&
32 (
static_cast<std::shared_ptr<std::vector<TensorCPU>
>*>(
33 output->raw_mutable_data(input.meta()))[0] !=
nullptr);
// An already-initialized reservoir must agree with DATA on rank and on every
// dimension except the leading one.
35 if (output_initialized) {
36 CAFFE_ENFORCE_EQ(output->ndim(), input.ndim());
37 for (
size_t i = 1; i < input.ndim(); ++i) {
38 CAFFE_ENFORCE_EQ(output->dim(i), input.dim(i));
// dims starts as the input shape; num_entries is the size of the leading
// dimension of this call's input.
42 auto dims = input.dims();
43 auto num_entries = dims[0];
// Reserve room for a full reservoir (leading dimension = numToCollect_).
45 dims[0] = numToCollect_;
// raw_mutable_data(input.meta()) is invoked before Reserve -- presumably so
// the output carries the input's element type when capacity is reserved;
// TODO confirm. Keep this ordering.
48 output->raw_mutable_data(input.meta());
49 output->Reserve(dims, &context_);
// POS_TO_OBJECT is optional: nullptr is used when that output slot is absent.
52 OutputSize() > POS_TO_OBJECT ? Output(POS_TO_OBJECT) : nullptr;
54 pos_to_object->Reserve(std::vector<TIndex>{numToCollect_}, &context_);
// Empty batch: if the reservoir has never been initialized, it is populated
// by copying the (empty) input.
57 if (num_entries == 0) {
58 if (!output_initialized) {
60 output->CopyFrom(input, &context_);
// Optional OBJECT_ID input: must be 1-D with one int64 id per entry. The
// distinct ids of this batch are gathered into unique_object_ids.
65 const int64_t* object_id_data =
nullptr;
66 std::set<int64_t> unique_object_ids;
67 if (InputSize() > OBJECT_ID) {
68 const auto& object_id = Input(OBJECT_ID);
69 CAFFE_ENFORCE_EQ(object_id.ndim(), 1);
70 CAFFE_ENFORCE_EQ(object_id.size(), num_entries);
71 object_id_data = object_id.template data<int64_t>();
72 unique_object_ids.insert(
73 object_id_data, object_id_data + object_id.size());
// Work out the new leading dimension: current rows (0 if uninitialized) plus
// the new entries to copy, never exceeding numToCollect_.
76 const auto num_new_entries = countNewEntries(unique_object_ids);
77 auto num_to_copy = std::min<int32_t>(num_new_entries, numToCollect_);
78 auto output_batch_size = output_initialized ? output->dim(0) : 0;
79 dims[0] = std::min<size_t>(numToCollect_, output_batch_size + num_to_copy);
// Resizing is only needed while the reservoir is still below capacity.
80 if (output_batch_size < numToCollect_) {
83 pos_to_object->Resize(dims[0]);
// Byte-addressed view of the reservoir buffer; indexed below as
// output_data + pos * block_bytesize.
87 static_cast<char*
>(output->raw_mutable_data(input.meta()));
88 auto* pos_to_object_data = pos_to_object
89 ? pos_to_object->template mutable_data<int64_t>()
// block_size: elements in one entry (product of dims from axis 1 on);
// block_bytesize: the same entry measured in bytes.
92 auto block_size = input.size_from_dim(1);
93 auto block_bytesize = block_size * input.itemsize();
94 const auto* input_data =
static_cast<const char*
>(input.raw_data());
// NUM_VISITED must hold exactly one int64 value and it must be non-negative.
// The branch taken for an uninitialized reservoir is not shown on these
// lines -- presumably it resets the counter; TODO confirm.
96 auto* num_visited_tensor = Output(NUM_VISITED);
97 CAFFE_ENFORCE_EQ(1, num_visited_tensor->size());
98 auto* num_visited = num_visited_tensor->template mutable_data<int64_t>();
99 if (!output_initialized) {
102 CAFFE_ENFORCE_GE(*num_visited, 0);
// Snapshot used by the final consistency check at the end of the function.
104 const auto start_num_visited = *num_visited;
// OBJECT_TO_POS_MAP is an optional output as well.
106 auto* object_to_pos_map = OutputSize() > OBJECT_TO_POS_MAP
107 ? OperatorBase::Output<MapType64To32>(OBJECT_TO_POS_MAP)
// Eligible ids = ids in this batch that the map does not contain yet.
110 std::set<int64_t> eligible_object_ids;
111 if (object_to_pos_map) {
112 for (
auto oid : unique_object_ids) {
113 if (!object_to_pos_map->count(oid)) {
114 eligible_object_ids.insert(oid);
// Walk the batch entry by entry.
// NOTE(review): i is an int while num_entries takes its type from dims[0];
// confirm the comparison cannot overflow for very large batches.
119 for (
int i = 0; i < num_entries; ++i) {
// With deduplication active, entries whose id is not eligible take this
// branch (body not visible here; presumably they are skipped).
120 if (object_id_data && object_to_pos_map &&
121 !eligible_object_ids.count(object_id_data[i])) {
// Erasing the id means a later duplicate inside the same batch fails the
// eligibility test above.
125 if (object_id_data) {
126 eligible_object_ids.erase(object_id_data[i]);
// Reservoir not yet full (body not visible here; presumably the entry is
// appended at the next free position -- TODO confirm).
129 if (*num_visited < numToCollect_) {
// Otherwise draw a position from the closed range [0, *num_visited] using the
// context's random generator.
133 auto& gen = context_.RandGenerator();
135 std::uniform_int_distribution<int64_t> uniformDist(0, *num_visited);
136 pos = uniformDist(gen);
// A draw at or beyond capacity gets special handling (body not visible here;
// presumably the entry is discarded).
137 if (pos >= numToCollect_) {
145 CAFFE_ENFORCE_GE(*num_visited, numToCollect_);
// Copy entry i of the input into reservoir slot pos; both offsets are in bytes.
148 context_.template CopyItems<Context, Context>(
151 input_data + i * block_bytesize,
152 output_data + pos * block_bytesize);
// Bookkeeping for deduplication: record the new id at pos, and replace the
// evicted id's map entry with new_oid -> pos.
154 if (object_id_data && pos_to_object_data && object_to_pos_map) {
155 auto old_oid = pos_to_object_data[pos];
156 auto new_oid = object_id_data[i];
157 pos_to_object_data[pos] = new_oid;
158 object_to_pos_map->erase(old_oid);
159 object_to_pos_map->emplace(new_oid, pos);
// Sanity check: the counter must have advanced by exactly the number of new
// entries computed up front.
166 CAFFE_ENFORCE_EQ(*num_visited, start_num_visited + num_new_entries);
// Input tag for the bookkeeping map that countNewEntries reads through
// OperatorBase::Input<MapType64To32>.
180 OBJECT_TO_POS_MAP_IN,
// Names for the output slots. The order matches the schema's .Output(0..3)
// declarations: RESERVOIR, NUM_VISITED, OBJECT_TO_POS_MAP, POS_TO_OBJECT.
182 OUTPUT_TAGS(RESERVOIR, NUM_VISITED, OBJECT_TO_POS_MAP, POS_TO_OBJECT);
// Returns how many entries of the current batch count as new.
// With object ids supplied, this is the number of ids in unique_object_ids
// that are absent from the OBJECT_TO_POS_MAP_IN input map.
// NOTE(review): the result is an int32_t; std::count_if yields a wider
// signed difference type, so very large id sets would be narrowed.
184 int32_t countNewEntries(
const std::set<int64_t>& unique_object_ids) {
185 const auto& input = Input(DATA);
// No OBJECT_ID input wired up: an early-exit branch is taken (its body is not
// visible here; presumably it returns the batch size of DATA -- TODO confirm).
186 if (InputSize() <= OBJECT_ID) {
189 const auto& object_to_pos_map =
190 OperatorBase::Input<MapType64To32>(OBJECT_TO_POS_MAP_IN);
// Count the ids the map has not seen yet.
191 return std::count_if(
192 unique_object_ids.begin(),
193 unique_object_ids.end(),
194 [&object_to_pos_map](int64_t oid) {
195 return !object_to_pos_map.count(oid);
// Registers the CPU implementation under the operator name "ReservoirSampling".
200 REGISTER_CPU_OPERATOR(ReservoirSampling, ReservoirSamplingOp<CPUContext>);
// Schema for ReservoirSampling: arity constraints, in-place requirements and
// the user-facing documentation strings.
// NOTE(review): the doc strings below contain the misspelling "Auxillary"
// (should be "Auxiliary"), and the description "The number of random samples
// to append for each positive samples" does not obviously describe
// `num_to_collect` -- verify and fix in a behavior-changing edit; runtime
// strings are intentionally left untouched here.
202 OPERATOR_SCHEMA(ReservoirSampling)
// Integer-division check tying the input count to the output count
// (for example 4 inputs with 2 outputs, or 7 inputs with 4 outputs).
205 .NumInputsOutputs([](
int in,
int out) {
return in / 3 == out / 2; })
// Inputs 0, 1, 5 and 6 must be the same blobs as outputs 0, 1, 2 and 3.
206 .EnforceInplace({{0, 0}, {1, 1}, {5, 2}, {6, 3}})
208 Collect `DATA` tensor into `RESERVOIR` of size `num_to_collect`. `DATA` is 209 assumed to be a batch. 211 In case where 'objects' may be repeated in data and you only want at most one 212 instance of each 'object' in the reservoir, `OBJECT_ID` can be given for 213 deduplication. If `OBJECT_ID` is given, then you also need to supply additional 214 book-keeping tensors. See input blob documentation for details. 216 This operator is thread-safe. 220 "The number of random samples to append for each positive samples")
224 "The reservoir; should be initialized to empty tensor")
228 "Number of examples seen so far; should be initialized to 0")
232 "Tensor to collect from. The first dimension is assumed to be batch " 233 "size. If the object to be collected is represented by multiple " 234 "tensors, use `PackRecords` to pack them into single tensor.")
235 .Input(3,
"MUTEX",
"Mutex to prevent data race")
239 "(Optional, int64) If provided, used for deduplicating object in the " 243 "OBJECT_TO_POS_MAP_IN",
244 "(Optional) Auxillary bookkeeping map. This should be created from " 245 " `CreateMap` with keys of type int64 and values of type int32")
249 "(Optional) Tensor of type int64 used for bookkeeping in deduplication")
250 .Output(0,
"RESERVOIR",
"Same as the input")
251 .Output(1,
"NUM_VISITED",
"Same as the input")
252 .Output(2,
"OBJECT_TO_POS_MAP",
"(Optional) Same as the input")
253 .Output(3,
"POS_TO_OBJECT",
"(Optional) Same as the input");
// Declares that no gradient is defined for this operator.
255 SHOULD_NOT_DO_GRADIENT(ReservoirSampling);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...