1 #include "caffe2/operators/lengths_top_k_op.h" 5 template <
typename T,
class Context>
// LengthsTopKOp::RunOnDevice
//
// For each of the N segments whose int lengths are read from `Y`, selects the
// k_ largest values of the corresponding consecutive run of `X` and writes:
//   - TOPK_VALUES_OUT:  shape {N, k_}, the selected values;
//   - TOPK_INDICES_OUT: shape {N, k_}, each value's position *within its
//     segment* (the inner loop index j, not a global offset into X).
// Rows of segments shorter than k_ are padded with value 0 and index -1.
//
// NOTE(review): the declarations of X, Y, N, next_index and p_queue are not
// visible in this excerpt; X is presumably the DATA input and Y the LENGTHS
// input, and next_index presumably starts at 0 - confirm against the full file.
6 bool LengthsTopKOp<T, Context>::RunOnDevice() {
10 const T* X_data = X.template data<T>();
11 const int* input_len = Y.template data<int>();
12 auto* output_topk_values = Output(TOPK_VALUES_OUT);
13 auto* output_topk_indices = Output(TOPK_INDICES_OUT);
// Allocate N * k_ elements per output, then view each as an {N, k_} matrix.
15 output_topk_values->Resize(N * k_);
16 output_topk_indices->Resize(N * k_);
17 std::vector<int> output_dims = std::vector<int>({N, k_});
18 output_topk_values->Reshape(output_dims);
19 output_topk_indices->Reshape(output_dims);
20 T* output_topk_values_data = output_topk_values->template mutable_data<T>();
21 int* output_topk_indices_data =
22 output_topk_indices->template mutable_data<int>();
// Candidate ordering: cmp(lhs, rhs) is true when lhs is the "better"
// candidate (larger value, or equal value with a smaller in-segment index).
// Used as a std::priority_queue comparator, this keeps the *weakest* kept
// candidate at top(), so it can be compared against and evicted cheaply.
// NOTE(review): the priority_queue declaration that presumably passes this
// comparator is only partly visible in this excerpt.
24 auto cmp = [](std::pair<T, TIndex>& lhs, std::pair<T, TIndex>& rhs) {
25 return lhs.first > rhs.first ||
26 (lhs.first == rhs.first && lhs.second < rhs.second);
// One pass per segment: stream its input_len[i] values through a bounded heap
// of (value, in-segment index) pairs, then emit row i of both outputs.
31 for (TIndex i = 0; i < N; ++i) {
36 std::vector<std::pair<T, TIndex>>,
42 for (TIndex j = 0; j < input_len[i]; ++j) {
43 const auto value = X_data[next_index++];
// Insert while fewer than k_ candidates are held, or when the value strictly
// beats the weakest kept one. Strict '>' means an equal value never displaces
// an earlier position, which is consistent with cmp's index tie-break. When the
// heap then exceeds k_ entries, the weakest is presumably popped (that body is
// not visible here).
// NOTE(review): p_queue.size() is unsigned; if k_ is a signed int these
// comparisons mix signedness - k_ is presumably validated non-negative elsewhere.
44 if (p_queue.size() < k_ || value > p_queue.top().first) {
45 p_queue.push(std::make_pair(value, j));
47 if (p_queue.size() > k_) {
// Drain the heap into row i. top() is the weakest remaining candidate, so it
// goes to the last unfilled slot (positions last_index-1 down to 0). This
// leaves the row ordered best-first, assuming each write is followed by a
// pop() (not visible in this excerpt).
// NOTE(review): last_index narrows the heap's size_t size to int.
52 int last_index = p_queue.size();
53 for (TIndex j = 0; j < k_; ++j) {
54 if (p_queue.size() > 0) {
55 auto& pqElem = p_queue.top();
56 output_topk_values_data[i * k_ + last_index - j - 1] = pqElem.first;
57 output_topk_indices_data[i * k_ + last_index - j - 1] = pqElem.second;
// Heap exhausted (segment shorter than k_): pad the remaining slots
// j >= last_index with value 0 and index -1.
60 output_topk_values_data[i * k_ + j] = 0;
61 output_topk_indices_data[i * k_ + j] = -1;
69 template <
typename T,
class Context>
// LengthsTopKGradientOp::RunOnDevice
//
// Backward pass of LengthsTopK: scatters the incoming gradient of the top-k
// values back to the input positions they were selected from.
//   Input(LENGTH_IN):   int lengths, one per segment (N = its element count);
//   Input(INDICES_IN):  int in-segment indices, ndim >= 2, N * k_ elements;
//   Input(DER_TOPK_IN): gradient w.r.t. the top-k values, N * k_ elements;
//   Output(DER_X_OUT):  1-D tensor of sum(lengths) elements. It is zero
//                       everywhere except at the selected positions.
70 bool LengthsTopKGradientOp<T, Context>::RunOnDevice() {
71 auto& input_len = Input(LENGTH_IN);
72 int N = input_len.size();
73 auto& input_indices = Input(INDICES_IN);
// Shape validation: indices must be at least 2-D, and both indices and the
// incoming gradient must hold exactly N * k_ elements. The opening lines of the
// two size checks are outside this excerpt; presumably CAFFE_ENFORCE_EQ.
74 CAFFE_ENFORCE_GE(input_indices.ndim(), 2,
"input dim must be >= 2");
76 input_indices.size(), N * k_,
"input_indices shape is not correct");
77 auto& input_topk = Input(DER_TOPK_IN);
79 input_topk.size(), N * k_,
"input_topk shape is not correct");
80 auto* X_out = Output(DER_X_OUT);
82 const int* input_len_data = input_len.template data<int>();
83 const int* input_indices_data = input_indices.template data<int>();
84 const T* input_topk_data = input_topk.template data<T>();
// The gradient covers every DATA element, so its length is the sum of all
// segment lengths. num_indices is declared outside this excerpt, presumably
// initialised to 0.
87 for (
int i = 0; i < N; i++) {
88 num_indices += input_len_data[i];
// Allocate the 1-D gradient and zero it, so unselected positions receive 0.
90 X_out->Resize(num_indices);
91 std::vector<int> output_dims = std::vector<int>({num_indices});
92 X_out->Reshape(output_dims);
93 T* X_out_data = X_out->template mutable_data<T>();
94 math::Set<T, Context>(num_indices, 0.0, X_out_data, &context_);
// Scatter: for segment i, only the first min(len_i, k_) entries of row i are
// real selections; any later entries are the forward op's -1 padding and are
// skipped. index_offset (declared outside this excerpt, presumably 0) is the
// start of segment i in the flat output.
// NOTE(review): the check below only enforces cur_index < num_indices. A
// negative index, or one >= len_i that lands in a neighbouring segment, is not
// rejected.
97 for (
int i = 0; i < N; i++) {
98 for (
int j = 0; j < std::min(input_len_data[i], k_); j++) {
99 int cur_index = index_offset + input_indices_data[i * k_ + j];
101 cur_index, num_indices,
"cur_index should be less than num_indices");
// Plain assignment, not accumulation. This assumes the indices within a segment
// are distinct, which holds for LengthsTopK's output since it emits each
// in-segment position at most once.
102 X_out_data[cur_index] = input_topk_data[i * k_ + j];
// Advance to the start of the next segment in the flat output.
104 index_offset += input_len_data[i];
// Register the float CPU kernels for the forward op and its gradient. The
// gradient op's registered name sits on a line outside this excerpt; presumably
// it is LengthsTopKGradient, matching the schema and gradient maker below.
110 REGISTER_CPU_OPERATOR(LengthsTopK, LengthsTopKOp<float, CPUContext>);
111 REGISTER_CPU_OPERATOR(
113 LengthsTopKGradientOp<float, CPUContext>);
// Schemas for LengthsTopK and LengthsTopKGradient.
//
// LengthsTopK documents:
//   - input 1, LENGTHS (int32, rank 1);
//   - a values output of shape (SIZE(lengths), k);
//   - an indices output;
//   - an argument (presumably k) for the number of top values per segment.
// Short segments are padded with value 0 and index -1.
//
// LengthsTopKGradient declares 3 inputs and 1 output. This matches the three
// Input(...) reads (LENGTH_IN, INDICES_IN, DER_TOPK_IN) and the single
// Output(DER_X_OUT) in LengthsTopKGradientOp::RunOnDevice.
//
// NOTE(review): in the values-output description, the adjacent literals
// "...for each segment, with" and "shape=(SIZE(lengths), k)" are concatenated
// with no separating space, so the rendered doc reads "withshape=...". A
// trailing space on the first literal would fix it.
114 OPERATOR_SCHEMA(LengthsTopK)
118 Apply TopK to each segment of the input tensor, where segments are defined by 119 their LENGTHS, and concatenate them in an output tensor of 120 shape=(SIZE(LENGTHs), k). In case there's less than k values in a segment, 121 the output value will be padded by 0, and the corresponding output indices will 127 "Tensor of rank 1. First dimension must be equal to the sum of " 129 .Input(1,
"LENGTHS",
"Tensor of int32 lengths of rank 1")
133 "Output top k elements for each segment, with" 134 "shape=(SIZE(lengths), k)")
138 "Output indices in DATA corresponding to value in TopKValue")
141 "the number of top values to return for each segment, if the number " 142 "of values is smaller than k, the values would be padded with 0 and " 143 "indices would be padded with -1.");
144 OPERATOR_SCHEMA(LengthsTopKGradient).NumInputs(3).NumOutputs(1);
// Gradient maker for LengthsTopK. It emits a single LengthsTopKGradient op:
//   - inputs:  I(1), the forward op's second input (LENGTHS);
//              O(1), the forward op's second output (presumably the top-k
//              indices);
//              GO(0), the gradient of the first output (the top-k values);
//   - output:  GI(0), the gradient for the forward op's first input (DATA).
// Only DATA receives a gradient; no gradient is produced for LENGTHS.
148 class GetLengthsTopKGradient :
public GradientMakerBase {
149 using GradientMakerBase::GradientMakerBase;
150 vector<OperatorDef> GetGradientDefs()
override {
// NOTE(review): this input order must line up with the gradient op's
// LENGTH_IN, INDICES_IN and DER_TOPK_IN positions. Those enum values are not
// visible here - confirm in lengths_top_k_op.h.
151 return SingleGradientDef(
152 "LengthsTopKGradient",
154 vector<string>{I(1), O(1), GO(0)},
155 vector<string>{GI(0)});
// Associate the gradient maker above with the LengthsTopK operator.
161 REGISTER_GRADIENT(LengthsTopK, GetLengthsTopKGradient);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...