// NOTE(review): this is an extracted listing — the leading numbers on many
// lines are the original file's line numbers, and several original lines are
// missing, so the comments below describe only what is visible here.
4 #include "caffe2/core/operator.h" 5 #include "caffe2/core/tensor.h" 10 template <
class Context>
// LastNWindowCollectorOp keeps the most recent N rows (slices along dimension
// 0) seen across successive DATA inputs. Rows are written into the LAST_N
// output as a ring buffer; the NEXT output is the slot index the next row will
// overwrite. An optional NUM_VISITED output counts every row fed to the op.
11 class LastNWindowCollectorOp :
public Operator<Context> {
13 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Reads the "num_to_collect" argument (default -1) as the window size N and
// rejects non-positive values, so the argument is effectively required. (The
// member-initializer token itself is on a line not shown here.)
14 LastNWindowCollectorOp(
const OperatorDef& operator_def, Workspace* ws)
15 : Operator<Context>(operator_def, ws),
17 OperatorBase::GetSingleArgument<int>(
"num_to_collect", -1)) {
18 CAFFE_ENFORCE_GT(numToCollect_, 0);
// Entry point. When a MUTEX input (index 3) is supplied, it is read as a
// std::unique_ptr<std::mutex> and locked with a lock_guard for the rest of the
// enclosing if-block. The work done under the lock, and the unlocked path, are
// on lines not shown (presumably both run the collection logic below — confirm).
21 bool RunOnDevice()
override {
22 if (InputSize() > MUTEX) {
23 auto& mutex = OperatorBase::Input<std::unique_ptr<std::mutex>>(MUTEX);
24 std::lock_guard<std::mutex> guard(*mutex);
// Window size N: the maximum number of rows kept in LAST_N.
32 const int32_t numToCollect_;
// ---- Collection logic (the enclosing function header is not visible) ----
35 auto* output = Output(LAST_N);
36 const auto& input = Input(DATA);
// DATA must have at least one dimension; dimension 0 indexes rows.
38 CAFFE_ENFORCE_GE(input.ndim(), 1);
// NOTE(review): "initialized" is decided by reinterpreting the first bytes of
// the LAST_N buffer (whose element type is input.meta(), e.g. float) as a
// std::shared_ptr<std::vector<TensorCPU>> and comparing it with nullptr. That
// type pun is unrelated to the buffer's real element type. It reads
// sizeof(std::shared_ptr<...>) bytes even when the buffer holds fewer. It also
// reports "not initialized" whenever those leading bytes happen to be zero
// (e.g. a first row of 0.0 values), which sends the op down the
// !output_initialized paths below. Calling raw_mutable_data(input.meta()) here
// may also reallocate the buffer if its stored type differs from DATA's —
// verify. Consider relying on output->size() > 0 alone — TODO confirm intent.
39 bool output_initialized = output->size() > 0 &&
40 (
static_cast<std::shared_ptr<std::vector<TensorCPU>
>*>(
41 output->raw_mutable_data(input.meta()))[0] !=
nullptr);
// Once initialized, the buffer must have DATA's rank and the same trailing
// dimensions; only dimension 0 (the row count) may differ.
42 if (output_initialized) {
43 CAFFE_ENFORCE_EQ(output->ndim(), input.ndim());
44 for (
size_t i = 1; i < input.ndim(); ++i) {
45 CAFFE_ENFORCE_EQ(output->dim(i), input.dim(i));
// num_entries = number of rows in this DATA batch.
49 auto dims = input.dims();
50 auto num_entries = dims[0];
// Optional NUM_VISITED output: a single int64 counter of all rows seen.
// When the buffer is not yet initialized, the lines after the `if` below (not
// visible here) presumably reset the counter — confirm.
// NOTE(review): the counter is incremented before the NEXT checks further
// down; if one of those enforcements throws, the count has already advanced.
52 if (OutputSize() > NUM_VISITED) {
53 auto* num_visited_tensor = Output(NUM_VISITED);
54 CAFFE_ENFORCE_EQ(1, num_visited_tensor->size());
55 auto* num_visited = num_visited_tensor->template mutable_data<int64_t>();
56 if (!output_initialized) {
59 CAFFE_ENFORCE_GE(*num_visited, 0);
60 *num_visited += num_entries;
// Reserve storage for a full window of N rows (trailing dims as in DATA).
63 dims[0] = numToCollect_;
64 output->Reserve(dims, &context_);
// Empty batch: nothing to append. If the buffer is still uninitialized, DATA
// (which has zero rows) is copied into it, presumably so LAST_N picks up
// DATA's type and trailing shape; the rest of this branch is not visible.
66 if (num_entries == 0) {
67 if (!output_initialized) {
69 output->CopyFrom(input, &context_);
// Rows to write this call: at most N. The buffer's new row count is
// min(N, current rows + rows to write), so it grows until it holds N rows.
74 auto num_to_copy = std::min<int32_t>(num_entries, numToCollect_);
75 auto output_batch_size = output_initialized ? output->dim(0) : 0;
76 dims[0] = std::min<size_t>(numToCollect_, output_batch_size + num_to_copy);
// While the buffer is still filling it is presumably resized to `dims` (the
// call is on a line not shown); output_data is the raw byte pointer to LAST_N.
77 if (output_batch_size < numToCollect_) {
81 static_cast<char*
>(output->raw_mutable_data(input.meta()));
// NEXT must be a scalar int32 cursor. The hidden lines after the `if` below
// presumably reset it for an uninitialized buffer; it must index an existing row.
83 auto* next = Output(NEXT);
84 CAFFE_ENFORCE_EQ(0, next->ndim());
85 auto* next_data = next->template mutable_data<int32_t>();
86 if (!output_initialized) {
89 CAFFE_ENFORCE_LT(*next_data, output->dim(0));
// Size of one row: elements in dims 1.. and the same in bytes; plus a raw
// byte pointer to DATA.
91 auto block_size = input.size_from_dim(1);
92 auto block_bytesize = block_size * input.itemsize();
93 const auto* input_data =
static_cast<const char*
>(input.raw_data());
// Batch larger than the window: only its last N rows are copied, starting at
// row (num_entries - N). The destination and the rest of this branch are on
// lines not shown.
95 if (num_entries > numToCollect_) {
97 context_.template CopyItems<Context, Context>(
99 num_to_copy * block_size,
100 input_data + (num_entries - numToCollect_) * block_bytesize,
// Ring-buffer write starting at the cursor. The first chunk fills slots
// [start, min(start + num_to_copy, N)).
105 auto start = *next_data;
106 auto first_chunk_size =
107 std::min<size_t>(num_to_copy + start, numToCollect_) - start;
108 context_.template CopyItems<Context, Context>(
110 first_chunk_size * block_size,
112 output_data + start * block_bytesize);
// Any remaining rows wrap around. They are copied from DATA at offset
// first_chunk_size, presumably to the front of the buffer (the destination
// argument is not visible).
114 context_.template CopyItems<Context, Context>(
116 (num_to_copy - first_chunk_size) * block_size,
117 input_data + first_chunk_size * block_bytesize,
// Advance the cursor past the written rows, wrapping modulo N.
120 *next_data = (start + num_to_copy) % numToCollect_;
// Positional indices. Inputs: LAST_N_IN=0, NEXT_IN=1, DATA=2, MUTEX=3,
// NUM_VISITED_IN=4. Outputs: LAST_N=0, NEXT=1, NUM_VISITED=2.
125 INPUT_TAGS(LAST_N_IN, NEXT_IN, DATA, MUTEX, NUM_VISITED_IN);
126 OUTPUT_TAGS(LAST_N, NEXT, NUM_VISITED);
// Bind the operator name "LastNWindowCollector" to its CPU instantiation,
// LastNWindowCollectorOp<CPUContext>.
129 REGISTER_CPU_OPERATOR(LastNWindowCollector, LastNWindowCollectorOp<CPUContext>);
// Schema for LastNWindowCollector.
// NumInputs: LAST_N, NEXT and DATA are required; MUTEX and NUM_VISITED are
// optional. Inputs are positional, so supplying NUM_VISITED (index 4) also
// requires supplying MUTEX (index 3).
// EnforceInplace pairs input 0 with output 0 (LAST_N), input 1 with output 1
// (NEXT) and input 4 with output 2 (NUM_VISITED). The op therefore updates
// its state blobs in place.
131 OPERATOR_SCHEMA(LastNWindowCollector)
132 .NumInputs({3, 4, 5})
134 .EnforceInplace({{0, 0}, {1, 1}, {4, 2}})
136 Collect the last N rows from input data. The purpose is to keep track of data 137 accross batches, so for example suppose the LastNWindowCollector is called 138 successively with the following input data 144 And the number of items is set to 6, then the output after the 3rd call 145 will contain the following elements: 149 No guarantee is made on the ordering of elements in input. So a valid value for 150 output could have been 154 Also, this method works for any order tensor, treating the first dimension as 155 input rows and keeping the last N rows seen as input. So for instance: 157 [[1, 2], [2, 3], [3, 4], [4, 5]] 158 [[5, 6], [6, 7], [7, 8]] 159 [[8, 9], [9, 10], [10, 11], [11, 12]] 161 A possible output would be 163 [[6, 7], [7, 8], [8, 9], [9, 10], [10, 11], [11, 12]] 165 This is not thread safe unless a mutex is given. 169 "The number of random samples to append for each positive samples")
// NOTE(review): the argument description above ("The number of random samples
// to append for each positive samples") does not describe this operator. The
// argument (presumably "num_to_collect"; its name is on a line not shown) is
// the window size, i.e. how many rows are kept. The doc text above also
// misspells "across" as "accross".
173 "The buffer for last-N record. Should be initialized to empty tensor")
177 "The cursor pointing to the next position that should be replaced. " 178 "Should be initialized to 0.")
179 .Input(2,
"DATA",
"tensor to collect from")
180 .Input(3,
"MUTEX",
"(optional) mutex to use to make this thread-safe")
// NOTE(review): NUM_VISITED has an empty description. It is the optional
// running count of rows seen, updated in place as output 2.
181 .Input(4,
"NUM_VISITED",
"")
182 .Output(0,
"last-N buffer",
"Data stored in sessions")
183 .Output(1,
"next cursor",
"Updated input cursor")
184 .Output(2,
"NUM_VISITED",
"number of records seen so far");
// No gradient is defined for this operator: it only maintains the LAST_N,
// NEXT and NUM_VISITED state blobs.
185 SHOULD_NOT_DO_GRADIENT(LastNWindowCollector);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...