Caffe2 - C++ API
A deep learning, cross platform ML framework
top_k.cc
#include "caffe2/operators/top_k.h"

#include <algorithm>
#include <functional>
#include <numeric>
#include <queue>
#include <utility>
#include <vector>

#include "caffe2/proto/caffe2.pb.h"
#include "caffe2/utils/math.h"
11 
12 namespace caffe2 {
13 
14 namespace {
15 
16 template <typename T>
17 struct ValueComp {
18  bool operator()(
19  const std::pair<T, TIndex>& lhs,
20  const std::pair<T, TIndex>& rhs) const {
21  return lhs.first > rhs.first ||
22  (lhs.first == rhs.first && lhs.second < rhs.second);
23  }
24 };
25 
// Writes the top-k values (and their within-slice indices) of one strided
// 1-D slice of `input` into `values`/`indices`, in descending value order
// with ties broken by ascending index.
//
// The slice consists of the n elements input[src_offset + i * stride],
// i in [0, n). Results go to the k slots values[dst_offset + j * stride],
// j in [0, k), largest value first. `flatten_indices`, when non-null, also
// receives each selected element's position in the flat `input` buffer.
// Assumes k <= n (the caller enforces this via CAFFE_ENFORCE_LE).
template <typename T>
void GetTopK(
    const T* input,
    const TIndex n,
    const TIndex k,
    const TIndex src_offset,
    const TIndex dst_offset,
    const TIndex stride,
    T* values,
    TIndex* indices,
    TIndex* flatten_indices) {
  const T* src_ptr = input + src_offset;
  std::vector<std::pair<T, TIndex>> heap_data;
  heap_data.reserve(k);
  // Seed the heap with the first k elements of the slice.
  for (TIndex i = 0; i < k; ++i) {
    heap_data.emplace_back(*src_ptr, i);
    src_ptr += stride;
  }
  // Min-heap on value: ValueComp keeps the smallest value on top (among
  // equal values, the one with the largest index is on top).
  std::priority_queue<
      std::pair<T, TIndex>,
      std::vector<std::pair<T, TIndex>>,
      ValueComp<T>>
      pq(ValueComp<T>(), std::move(heap_data));
  // Scan the remaining n - k elements, evicting the current minimum whenever
  // a strictly larger value appears. The strict '<' matters: an equal value
  // never replaces an earlier-indexed one, preserving the lowest-index
  // tie-break documented in the schema.
  for (TIndex i = k; i < n; ++i) {
    if (pq.top().first < *src_ptr) {
      pq.pop();
      pq.emplace(*src_ptr, i);
    }
    src_ptr += stride;
  }
  // Pop smallest-first, filling the destination backwards so the final
  // output reads in descending value order.
  TIndex dst_pos = dst_offset + (k - 1) * stride;
  while (!pq.empty()) {
    const auto& item = pq.top();
    values[dst_pos] = item.first;
    indices[dst_pos] = item.second;
    if (flatten_indices != nullptr) {
      flatten_indices[dst_pos] = src_offset + item.second * stride;
    }
    pq.pop();
    dst_pos -= stride;
  }
}
68 
69 template <typename T>
70 void SetTopKGradient(
71  const T* values,
72  const TIndex* indices,
73  const int k,
74  const TIndex src_offset,
75  const TIndex dst_offset,
76  const TIndex stride,
77  T* gradient) {
78  TIndex src_pos = src_offset;
79  for (int i = 0; i < k; ++i) {
80  gradient[dst_offset + indices[src_pos] * stride] = values[src_pos];
81  src_pos += stride;
82  }
83 }
84 
85 } // namespace
86 
87 template <typename T, class Context>
88 bool TopKOp<T, Context>::RunOnDevice() {
89  const auto& input = Input(0);
90  auto* values = Output(0);
91  auto* indices = Output(1);
92  auto* flatten_indices = OutputSize() > 2 ? Output(2) : nullptr;
93 
94  const std::vector<TIndex>& input_dims = input.dims();
95  if (axis_ == -1) {
96  axis_ = input_dims.size() - 1;
97  }
98  CAFFE_ENFORCE_GE(axis_, 0);
99  CAFFE_ENFORCE_LT(axis_, input_dims.size());
100  CAFFE_ENFORCE_LE(
101  k_,
102  input_dims[axis_],
103  "k argument should not be greater than the axis dim.");
104 
105  std::vector<TIndex> output_dims = input_dims;
106  output_dims[axis_] = k_;
107  values->Resize(output_dims);
108  indices->Resize(output_dims);
109  if (flatten_indices != nullptr) {
110  flatten_indices->Resize(indices->size());
111  }
112  const T* input_data = input.template data<T>();
113  T* values_data = values->template mutable_data<T>();
114  TIndex* indices_data = indices->template mutable_data<TIndex>();
115  TIndex* flatten_indices_data = flatten_indices == nullptr
116  ? nullptr
117  : flatten_indices->template mutable_data<TIndex>();
118 
119  const TIndex prev_size = std::accumulate(
120  input_dims.cbegin(),
121  input_dims.cbegin() + axis_,
122  TIndex(1),
123  std::multiplies<TIndex>());
124  const TIndex next_size = std::accumulate(
125  input_dims.cbegin() + axis_ + 1,
126  input_dims.cend(),
127  TIndex(1),
128  std::multiplies<TIndex>());
129  const TIndex src_offset_stride = input_dims[axis_] * next_size;
130  const TIndex dst_offset_stride = k_ * next_size;
131  TIndex src_offset = 0;
132  TIndex dst_offset = 0;
133  for (TIndex i = 0; i < prev_size; ++i) {
134  for (TIndex j = 0; j < next_size; ++j) {
135  GetTopK(
136  input_data,
137  input_dims[axis_],
138  k_,
139  src_offset + j,
140  dst_offset + j,
141  next_size,
142  values_data,
143  indices_data,
144  flatten_indices_data);
145  }
146  src_offset += src_offset_stride;
147  dst_offset += dst_offset_stride;
148  }
149  return true;
150 }
151 
152 template <typename T, class Context>
153 bool TopKGradientOp<T, Context>::RunOnDevice() {
154  const auto& values = Input(0);
155  const auto& indices = Input(1);
156  const auto& original_input = Input(2);
157  auto* output = Output(0);
158  const std::vector<TIndex>& values_dims = values.dims();
159  const std::vector<TIndex>& origin_dims = original_input.dims();
160  CAFFE_ENFORCE_EQ(values_dims.size(), origin_dims.size());
161  output->Resize(origin_dims);
162  const T* values_data = values.template data<T>();
163  const TIndex* indices_data = indices.template data<TIndex>();
164  T* output_data = output->template mutable_data<T>();
165  if (axis_ == -1) {
166  axis_ = values_dims.size() - 1;
167  }
168  const int k = values_dims[axis_];
169  math::Set<T, Context>(output->size(), T(0), output_data, &context_);
170  const TIndex prev_size = std::accumulate(
171  values_dims.cbegin(),
172  values_dims.cbegin() + axis_,
173  TIndex(1),
174  std::multiplies<TIndex>());
175  const TIndex next_size = std::accumulate(
176  values_dims.cbegin() + axis_ + 1,
177  values_dims.cend(),
178  TIndex(1),
179  std::multiplies<TIndex>());
180  const TIndex src_offset_stride = k * next_size;
181  const TIndex dst_offset_stride = origin_dims[axis_] * next_size;
182  TIndex src_offset = 0;
183  TIndex dst_offset = 0;
184  for (TIndex i = 0; i < prev_size; ++i) {
185  for (TIndex j = 0; j < next_size; ++j) {
186  SetTopKGradient(
187  values_data,
188  indices_data,
189  k,
190  src_offset + j,
191  dst_offset + j,
192  next_size,
193  output_data);
194  }
195  src_offset += src_offset_stride;
196  dst_offset += dst_offset_stride;
197  }
198  return true;
199 }
200 
// CPU kernels are registered for float input only.
REGISTER_CPU_OPERATOR(TopK, TopKOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(TopKGradient, TopKGradientOp<float, CPUContext>);
203 
OPERATOR_SCHEMA(TopK)
    .NumInputs(1)
    .NumOutputs(2, 3)
    .TensorInferenceFunction([](const OperatorDef& def,
                                const vector<TensorShape>& in) {
      // Values and Indices: input shape with the last dimension replaced
      // by k.
      // NOTE(review): this inference always shrinks the LAST dimension,
      // while the kernel reduces along an axis_ member — confirm whether an
      // "axis" argument is exposed to users; if so, inference disagrees
      // with it for axis != last.
      vector<TensorShape> out = {in[0], in[0]};
      ArgumentHelper helper(def);
      auto k = helper.GetSingleArgument("k", -1);
      auto dims_size = in[0].dims_size();
      out[0].set_dims(dims_size - 1, k);
      out[1].set_dims(dims_size - 1, k);
      // NOTE(review): the kernel writes Indices as TIndex (64-bit), but the
      // inferred dtype here is INT32 — verify this mismatch is intended.
      out[1].set_data_type(TensorProto_DataType_INT32);
      if (def.output_size() > 2) {
        // Optional third output: flat indices, one per selected element.
        TensorShape flatten_indices_shape;
        flatten_indices_shape.set_data_type(TensorProto_DataType_INT32);
        flatten_indices_shape.add_dims(
            std::accumulate(
                in[0].dims().begin(),
                in[0].dims().end() - 1,
                1,
                std::multiplies<long>()) *
            k);
        out.push_back(flatten_indices_shape);
      }
      return out;
    })
    .SetDoc(R"DOC(
Retrieve the top-K elements for the last dimension. Given an input tensor of
shape [a_1, a_2, ..., a_n, r] and integer argument k, return two outputs:
-Value tensor of shape [a_1, a_2, ..., a_n, k] which contains the values of
 the top k elements along the last dimension
-Index tensor of shape [a_1, a_2, ..., a_n, k] which contains the indices
 of the top k elements (original indices from the input tensor).

Given two equivalent values, this operator uses the indices along the last dim-
ension as a tiebreaker. That is, the element with the lower index will appear
first.
    )DOC")
    .Input(0, "X", "Tensor of shape [a_1, a_2, ..., a_n, r]")
    .Output(
        0,
        "Values",
        "Tensor of shape [a_1, a_2, ..., a_n, k] containing"
        " top K values from the input tensor")
    .Output(
        1,
        "Indices",
        "Tensor of shape [a_1, a_2, ..., a_n, k] containing"
        " the corresponding input tensor indices for the top K values.")
    .Output(
        2,
        "Flatten indices",
        "Tensor of shape [a_1 * a_2 * ... * a_n * k] containing the indices "
        "into the flatten input")
    .Arg("k", "Number of top elements to retrieve");
259 
260 OPERATOR_SCHEMA(TopKGradient).NumInputs(3).NumOutputs(1);
261 
263  using GradientMakerBase::GradientMakerBase;
264  vector<OperatorDef> GetGradientDefs() override {
265  return SingleGradientDef(
266  "TopKGradient",
267  "",
268  vector<string>{GO(0), O(1), I(0)},
269  vector<string>{GI(0)});
270  }
271 };
272 
273 REGISTER_GRADIENT(TopK, GetTopKGradient);
274 
275 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...