Caffe2 - C++ API
A deep learning, cross platform ML framework
lengths_top_k_op.cc
1 #include "caffe2/operators/lengths_top_k_op.h"
2 
3 namespace caffe2 {
4 
5 template <typename T, class Context>
6 bool LengthsTopKOp<T, Context>::RunOnDevice() {
7  auto& X = Input(X_IN);
8  auto& Y = Input(Y_IN);
9  int N = Y.dim32(0);
10  const T* X_data = X.template data<T>();
11  const int* input_len = Y.template data<int>();
12  auto* output_topk_values = Output(TOPK_VALUES_OUT);
13  auto* output_topk_indices = Output(TOPK_INDICES_OUT);
14 
15  output_topk_values->Resize(N * k_);
16  output_topk_indices->Resize(N * k_);
17  std::vector<int> output_dims = std::vector<int>({N, k_});
18  output_topk_values->Reshape(output_dims);
19  output_topk_indices->Reshape(output_dims);
20  T* output_topk_values_data = output_topk_values->template mutable_data<T>();
21  int* output_topk_indices_data =
22  output_topk_indices->template mutable_data<int>();
23 
24  auto cmp = [](std::pair<T, TIndex>& lhs, std::pair<T, TIndex>& rhs) {
25  return lhs.first > rhs.first ||
26  (lhs.first == rhs.first && lhs.second < rhs.second);
27  };
28 
29  // Sort preserving indices
30  int next_index = 0;
31  for (TIndex i = 0; i < N; ++i) {
32  // Build a min-heap, the heap element is pair of (value, idx)
33  // the top of the heap is the smallest value
34  std::priority_queue<
35  std::pair<T, TIndex>,
36  std::vector<std::pair<T, TIndex>>,
37  decltype(cmp)>
38  p_queue(cmp);
39 
40  // Maintain the size of heap to be less or equal to k_, so the
41  // heap will hold the k_ largest values
42  for (TIndex j = 0; j < input_len[i]; ++j) {
43  const auto value = X_data[next_index++];
44  if (p_queue.size() < k_ || value > p_queue.top().first) {
45  p_queue.push(std::make_pair(value, j));
46  }
47  if (p_queue.size() > k_) {
48  p_queue.pop();
49  }
50  }
51 
52  int last_index = p_queue.size();
53  for (TIndex j = 0; j < k_; ++j) {
54  if (p_queue.size() > 0) {
55  auto& pqElem = p_queue.top();
56  output_topk_values_data[i * k_ + last_index - j - 1] = pqElem.first;
57  output_topk_indices_data[i * k_ + last_index - j - 1] = pqElem.second;
58  p_queue.pop();
59  } else {
60  output_topk_values_data[i * k_ + j] = 0;
61  output_topk_indices_data[i * k_ + j] = -1;
62  }
63  }
64  }
65 
66  return true;
67 }
68 
69 template <typename T, class Context>
70 bool LengthsTopKGradientOp<T, Context>::RunOnDevice() {
71  auto& input_len = Input(LENGTH_IN);
72  int N = input_len.size();
73  auto& input_indices = Input(INDICES_IN);
74  CAFFE_ENFORCE_GE(input_indices.ndim(), 2, "input dim must be >= 2");
75  CAFFE_ENFORCE_EQ(
76  input_indices.size(), N * k_, "input_indices shape is not correct");
77  auto& input_topk = Input(DER_TOPK_IN);
78  CAFFE_ENFORCE_EQ(
79  input_topk.size(), N * k_, "input_topk shape is not correct");
80  auto* X_out = Output(DER_X_OUT);
81 
82  const int* input_len_data = input_len.template data<int>();
83  const int* input_indices_data = input_indices.template data<int>();
84  const T* input_topk_data = input_topk.template data<T>();
85 
86  int num_indices = 0;
87  for (int i = 0; i < N; i++) {
88  num_indices += input_len_data[i];
89  }
90  X_out->Resize(num_indices);
91  std::vector<int> output_dims = std::vector<int>({num_indices});
92  X_out->Reshape(output_dims);
93  T* X_out_data = X_out->template mutable_data<T>();
94  math::Set<T, Context>(num_indices, 0.0, X_out_data, &context_);
95 
96  int index_offset = 0;
97  for (int i = 0; i < N; i++) {
98  for (int j = 0; j < std::min(input_len_data[i], k_); j++) {
99  int cur_index = index_offset + input_indices_data[i * k_ + j];
100  CAFFE_ENFORCE_LT(
101  cur_index, num_indices, "cur_index should be less than num_indices");
102  X_out_data[cur_index] = input_topk_data[i * k_ + j];
103  }
104  index_offset += input_len_data[i];
105  }
106 
107  return true;
108 }
109 
110 REGISTER_CPU_OPERATOR(LengthsTopK, LengthsTopKOp<float, CPUContext>);
111 REGISTER_CPU_OPERATOR(
112  LengthsTopKGradient,
113  LengthsTopKGradientOp<float, CPUContext>);
114 OPERATOR_SCHEMA(LengthsTopK)
115  .NumInputs(2)
116  .NumOutputs(2)
117  .SetDoc(R"DOC(
118 Apply TopK to each segment of the input tensor, where segments are defined by
119 their LENGTHS, and concatenate them in an output tensor of
120 shape=(SIZE(LENGTHs), k). In case there's less than k values in a segment,
121 the output value will be padded by 0, and the corresponding output indices will
122 be padded by -1.
123 )DOC")
124  .Input(
125  0,
126  "DATA",
127  "Tensor of rank 1. First dimension must be equal to the sum of "
128  "lengths")
129  .Input(1, "LENGTHS", "Tensor of int32 lengths of rank 1")
130  .Output(
131  0,
132  "TopKValue",
133  "Output top k elements for each segment, with"
134  "shape=(SIZE(lengths), k)")
135  .Output(
136  1,
137  "TopKIndices",
138  "Output indices in DATA corresponding to value in TopKValue")
139  .Arg(
140  "k",
141  "the number of top values to return for each segment, if the number "
142  "of values is smaller than k, the values would be padded with 0 and "
143  "indices would be padded with -1.");
144 OPERATOR_SCHEMA(LengthsTopKGradient).NumInputs(3).NumOutputs(1);
145 
146 namespace {
147 
148 class GetLengthsTopKGradient : public GradientMakerBase {
149  using GradientMakerBase::GradientMakerBase;
150  vector<OperatorDef> GetGradientDefs() override {
151  return SingleGradientDef(
152  "LengthsTopKGradient",
153  "",
154  vector<string>{I(1), O(1), GO(0)},
155  vector<string>{GI(0)});
156  }
157 };
158 
159 } // namespace
160 
161 REGISTER_GRADIENT(LengthsTopK, GetLengthsTopKGradient);
162 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...