Caffe2 - C++ API
A deep-learning, cross-platform ML framework
last_n_window_collector.cc
1 #include <memory>
2 #include <string>
3 #include <vector>
4 #include "caffe2/core/operator.h"
5 #include "caffe2/core/tensor.h"
6 
7 namespace caffe2 {
8 namespace {
9 
// Maintains a rolling buffer of the last `num_to_collect` rows seen across
// successive calls. DATA rows are written into LAST_N as a circular buffer;
// NEXT stores the cursor (index of the slot to overwrite next) so state
// persists between invocations via the in-place input/output pairs.
// Thread-safe only when the optional MUTEX input is provided.
template <class Context>
class LastNWindowCollectorOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  LastNWindowCollectorOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        numToCollect_(
            OperatorBase::GetSingleArgument<int>("num_to_collect", -1)) {
    // -1 default forces callers to pass an explicit positive window size.
    CAFFE_ENFORCE_GT(numToCollect_, 0);
  }

  bool RunOnDevice() override {
    // If the optional MUTEX input was supplied, serialize the whole
    // collect() pass so concurrent workers can share one buffer.
    if (InputSize() > MUTEX) {
      auto& mutex = OperatorBase::Input<std::unique_ptr<std::mutex>>(MUTEX);
      std::lock_guard<std::mutex> guard(*mutex);
      return collect();
    } else {
      return collect();
    }
  }

 private:
  // Window size N; immutable after construction.
  const int32_t numToCollect_;

  // Appends the rows of DATA into the circular LAST_N buffer, updates the
  // NEXT cursor and (optionally) the NUM_VISITED counter. Returns true.
  bool collect() {
    auto* output = Output(LAST_N);
    const auto& input = Input(DATA);

    CAFFE_ENFORCE_GE(input.ndim(), 1);
    // The buffer counts as initialized only if it has elements AND, for
    // pointer-like element types, the first slot is non-null. Note that
    // raw_mutable_data(input.meta()) also stamps the element meta onto the
    // output tensor as a side effect.
    bool output_initialized = output->size() > 0 &&
        (static_cast<std::shared_ptr<std::vector<TensorCPU>>*>(
             output->raw_mutable_data(input.meta()))[0] != nullptr);
    if (output_initialized) {
      // All trailing dimensions (everything but the row count) must match
      // what the buffer already holds.
      CAFFE_ENFORCE_EQ(output->ndim(), input.ndim());
      for (size_t i = 1; i < input.ndim(); ++i) {
        CAFFE_ENFORCE_EQ(output->dim(i), input.dim(i));
      }
    }

    auto dims = input.dims();
    auto num_entries = dims[0];

    if (OutputSize() > NUM_VISITED) {
      auto* num_visited_tensor = Output(NUM_VISITED);
      CAFFE_ENFORCE_EQ(1, num_visited_tensor->size());
      auto* num_visited = num_visited_tensor->template mutable_data<int64_t>();
      // Reset the visit counter the first time the buffer is used.
      if (!output_initialized) {
        *num_visited = 0;
      }
      CAFFE_ENFORCE_GE(*num_visited, 0);
      // Count every incoming row, even those that will be overwritten.
      *num_visited += num_entries;
    }

    // Pre-reserve full capacity for N rows so later Resize calls never
    // reallocate (which would invalidate output_data).
    dims[0] = numToCollect_;
    output->Reserve(dims, &context_);

    if (num_entries == 0) {
      if (!output_initialized) {
        // Get both shape and meta
        output->CopyFrom(input, &context_);
      }
      return true;
    }

    auto num_to_copy = std::min<int32_t>(num_entries, numToCollect_);
    auto output_batch_size = output_initialized ? output->dim(0) : 0;
    // Grow the visible row count until the buffer reaches capacity N.
    dims[0] = std::min<size_t>(numToCollect_, output_batch_size + num_to_copy);
    if (output_batch_size < numToCollect_) {
      output->Resize(dims);
    }
    auto* output_data =
        static_cast<char*>(output->raw_mutable_data(input.meta()));

    auto* next = Output(NEXT);
    // NEXT is a scalar (0-dim) cursor, carried across calls in-place.
    CAFFE_ENFORCE_EQ(0, next->ndim());
    auto* next_data = next->template mutable_data<int32_t>();
    if (!output_initialized) {
      *next_data = 0;
    }
    CAFFE_ENFORCE_LT(*next_data, output->dim(0));

    // Row geometry: elements per row and bytes per row.
    auto block_size = input.size_from_dim(1);
    auto block_bytesize = block_size * input.itemsize();
    const auto* input_data = static_cast<const char*>(input.raw_data());

    if (num_entries > numToCollect_) {
      // just copy the last N rows
      context_.template CopyItems<Context, Context>(
          input.meta(),
          num_to_copy * block_size,
          input_data + (num_entries - numToCollect_) * block_bytesize,
          output_data);
      *next_data = 0;
      return true;
    }
    // Circular write: first fill from the cursor to the end of the buffer,
    // then wrap around to the front with whatever remains.
    auto start = *next_data;
    auto first_chunk_size =
        std::min<size_t>(num_to_copy + start, numToCollect_) - start;
    context_.template CopyItems<Context, Context>(
        input.meta(),
        first_chunk_size * block_size,
        input_data,
        output_data + start * block_bytesize);

    // Second (possibly empty) chunk: the wrapped-around remainder.
    context_.template CopyItems<Context, Context>(
        input.meta(),
        (num_to_copy - first_chunk_size) * block_size,
        input_data + first_chunk_size * block_bytesize,
        output_data);

    *next_data = (start + num_to_copy) % numToCollect_;

    return true;
  }

  INPUT_TAGS(LAST_N_IN, NEXT_IN, DATA, MUTEX, NUM_VISITED_IN);
  OUTPUT_TAGS(LAST_N, NEXT, NUM_VISITED);
};
128 
129 REGISTER_CPU_OPERATOR(LastNWindowCollector, LastNWindowCollectorOp<CPUContext>);
130 
131 OPERATOR_SCHEMA(LastNWindowCollector)
132  .NumInputs({3, 4, 5})
133  .NumOutputs(2, 3)
134  .EnforceInplace({{0, 0}, {1, 1}, {4, 2}})
135  .SetDoc(R"DOC(
136 Collect the last N rows from input data. The purpose is to keep track of data
137 accross batches, so for example suppose the LastNWindowCollector is called
138 successively with the following input data
139 
140  [1, 2, 3, 4]
141  [5, 6, 7]
142  [8, 9, 10, 11]
143 
144 And the number of items is set to 6, then the output after the 3rd call
145 will contain the following elements:
146 
147  [6, 7, 8, 9, 10, 11]
148 
149 No guarantee is made on the ordering of elements in input. So a valid value for
150 output could have been
151 
152  [11, 10, 9, 8, 7, 6]
153 
154 Also, this method works for any order tensor, treating the first dimension as
155 input rows and keeping the last N rows seen as input. So for instance:
156 
157  [[1, 2], [2, 3], [3, 4], [4, 5]]
158  [[5, 6], [6, 7], [7, 8]]
159  [[8, 9], [9, 10], [10, 11], [11, 12]]
160 
161 A possible output would be
162 
163  [[6, 7], [7, 8], [8, 9], [9, 10], [10, 11], [11, 12]]
164 
165 This is not thread safe unless a mutex is given.
166 )DOC")
167  .Arg(
168  "num_to_collect",
169  "The number of random samples to append for each positive samples")
170  .Input(
171  0,
172  "last-N buffer",
173  "The buffer for last-N record. Should be initialized to empty tensor")
174  .Input(
175  1,
176  "next cursor",
177  "The cursor pointing to the next position that should be replaced. "
178  "Should be initialized to 0.")
179  .Input(2, "DATA", "tensor to collect from")
180  .Input(3, "MUTEX", "(optional) mutex to use to make this thread-safe")
181  .Input(4, "NUM_VISITED", "")
182  .Output(0, "last-N buffer", "Data stored in sessions")
183  .Output(1, "next cursor", "Updated input cursor")
184  .Output(2, "NUM_VISITED", "number of records seen so far");
185 SHOULD_NOT_DO_GRADIENT(LastNWindowCollector);
186 } // namespace
187 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...