1 #include "caffe2/operators/flexible_top_k.h" 3 #include "caffe2/proto/caffe2.pb.h" 12 const std::pair<T, TIndex>& lhs,
13 const std::pair<T, TIndex>& rhs) {
15 lhs.first > rhs.first ||
16 (lhs.first == rhs.first && lhs.second < rhs.second));
22 template <
typename T,
class Context>
23 bool FlexibleTopKOp<T, Context>::RunOnDevice() {
24 auto& input = Input(0);
26 auto* values = Output(0);
27 auto* indices = Output(1);
29 const T* input_data = input.template data<T>();
30 const TIndex* k_data = k.template data<TIndex>();
33 CAFFE_ENFORCE_GT(input.ndim(), 0);
34 vector<TIndex> input_dims = input.dims();
35 vector<TIndex> linear_shape = {
36 size_to_dim_(input_dims.size() - 1, input_dims), input_dims.back()};
40 "first n-1 dims of input data and K does not match.");
42 TIndex output_size = 0;
43 for (TIndex i = 0; i < linear_shape[0]; ++i) {
45 linear_shape[1] >= k_data[i],
46 "k should not be greater than last dim, error at index ",
52 "k should be greater than 0, error at index ",
56 output_size += k_data[i];
58 values->Resize(output_size);
59 indices->Resize(output_size);
60 T* values_data = values->template mutable_data<T>();
61 TIndex* indices_data = indices->template mutable_data<TIndex>();
63 TIndex output_offset = 0;
65 for (TIndex i = 0; i < linear_shape[0]; ++i) {
70 std::vector<std::pair<T, TIndex>>,
74 TIndex k_ = k_data[i];
75 for (TIndex j = 0; j < linear_shape[1]; ++j) {
76 const T value = input_data[i * linear_shape[1] + j];
77 if (PQ.size() < k_ || value > PQ.top().first) {
78 PQ.push(std::make_pair(value, j));
84 for (TIndex j = 0; j < k_; ++j) {
85 auto& pqElem = PQ.top();
86 values_data[output_offset + k_ - j - 1] = pqElem.first;
87 indices_data[output_offset + k_ - j - 1] = pqElem.second;
96 template <
typename T,
class Context>
97 bool FlexibleTopKGradientOp<T, Context>::RunOnDevice() {
98 auto& original_input = Input(0);
100 auto& values = Input(2);
101 auto& indices = Input(3);
102 auto* output = Output(0);
104 const TIndex* k_data = k.template data<TIndex>();
105 const T* values_data = values.template data<T>();
106 const TIndex* indices_data = indices.template data<TIndex>();
109 CAFFE_ENFORCE_GT(original_input.ndim(), 0);
110 vector<TIndex> original_dims = original_input.dims();
111 output->Resize(original_dims);
112 T* output_data = output->template mutable_data<T>();
113 math::Set<T, Context>(
114 output->size(),
static_cast<T
>(0), output_data, &context_);
116 TIndex index_offset = 0;
117 for (TIndex i = 0; i < k.size(); ++i) {
119 TIndex output_offset = i * original_dims.back();
120 for (TIndex j = 0; j < k_data[i]; ++j) {
121 TIndex index = indices_data[index_offset + j];
122 T value = values_data[index_offset + j];
123 output_data[output_offset + index] = value;
125 index_offset += k_data[i];
131 REGISTER_CPU_OPERATOR(FlexibleTopK, FlexibleTopKOp<float, CPUContext>);
132 REGISTER_CPU_OPERATOR(
133 FlexibleTopKGradient,
134 FlexibleTopKGradientOp<float, CPUContext>);
136 OPERATOR_SCHEMA(FlexibleTopK)
140 Given two tensors: X and K, 141 retrieve the top K[..., 1] elements from X on the last dimension. 142 X is an input tensor of shape [a_1, a_2, ..., a_n, r]. 143 K is an input tensor of shape [a_1, a_2, ..., a_n, 1], 144 where for each element, r >= K[..., 1] > 0 146 -Flatten values tensor of shape [ \sum_i K[i, 1] ] which contains the values of 147 the top K[..., 1] elements along the last dimension 148 -Flatten indices tensor of shape [ \sum_i K[i, 1] ] which contains the indices 149 of the top K[..., 1] elements, flatten indices from the input tensor). 150 These two outputs should be used with the input K, so that we know which indices 153 Given two equivalent values, this operator uses the indices along the last dim- 154 ension as a tiebreaker. That is, the element with the lower index will appear 157 .Input(0, "X",
"Tensor of shape [a_1, a_2, ..., a_n, r]")
158 .Input(1,
"K",
"Tensor of shape [a_1, a_2, ..., a_n, 1]")
162 "Tensor of shape [ \\sum_i K[i, 1] ] containing" 163 " top K[..., 1] values from the input tensor")
167 "Tensor of shape [ \\sum_i K[i, 1] ] containing the indices " 168 "into the flatten input");
170 OPERATOR_SCHEMA(FlexibleTopKGradient).NumInputs(4).NumOutputs(1);
172 class GetFlexibleTopKGradient :
public GradientMakerBase {
173 using GradientMakerBase::GradientMakerBase;
174 vector<OperatorDef> GetGradientDefs()
override {
175 return SingleGradientDef(
176 "FlexibleTopKGradient",
178 vector<string>{I(0), I(1), GO(0), O(1)},
179 vector<string>{GI(0)});
183 REGISTER_GRADIENT(FlexibleTopK, GetFlexibleTopKGradient);
// NOTE(review): extraction residue from an unrelated document, commented out
// so it cannot break compilation — original text preserved below:
// "A global dictionary that holds information about what Caffe2 modules have
// been loaded in the current ..."