// NOTE(review): the leading numbers on these lines are line numbers from the
// original file that were fused into the text; the gaps between them
// (1 -> 9 -> 10 -> 19) show that lines are missing from this extract.
1 #include "caffe2/operators/top_k.h" 9 #include "caffe2/proto/caffe2.pb.h" 10 #include "caffe2/utils/math.h" 19 const std::pair<T, TIndex>& lhs,
20 const std::pair<T, TIndex>& rhs)
const {
// Comparator over (value, index) pairs: yields true when lhs holds the larger
// value, or when both values are equal and lhs holds the smaller index.
21 return lhs.
first > rhs.first ||
22 (lhs.first == rhs.first && lhs.second < rhs.second);
31 const TIndex src_offset,
32 const TIndex dst_offset,
36 TIndex* flatten_indices) {
// Fragment of the top-k selection routine: candidate (value, index) pairs are
// kept in `pq`, ordered by ValueComp<T>, and the selected pairs are written
// into values / indices (and flatten_indices when it is non-null).
// Reading begins at input + src_offset.
37 const T* src_ptr = input + src_offset;
38 std::vector<std::pair<T, TIndex>> heap_data;
// Seed the candidate list with indices 0..k-1.
// NOTE(review): only the read of *src_ptr is visible; the pointer advance is
// presumably on a line missing from this extract.
40 for (TIndex i = 0; i < k; ++i) {
41 heap_data.emplace_back(*src_ptr, i);
// `pq` is constructed from a ValueComp<T> instance and takes heap_data by move.
// NOTE(review): presumably a std::priority_queue; the line naming its type is
// not visible here.
46 std::vector<std::pair<T, TIndex>>,
48 pq(ValueComp<T>(), std::move(heap_data));
// For indices k..n-1 a candidate is inserted only when it is strictly greater
// than the value currently at pq.top().
// NOTE(review): the matching removal of the old top is presumably on a
// missing line — confirm.
49 for (TIndex i = k; i < n; ++i) {
50 if (pq.top().first < *src_ptr) {
52 pq.emplace(*src_ptr, i);
// Writing starts at the last of the k output slots (slots are `stride` apart).
// NOTE(review): this suggests the outputs are filled back to front; the loop
// header and the position update are not visible in this extract.
56 TIndex dst_pos = dst_offset + (k - 1) * stride;
58 const auto& item = pq.top();
59 values[dst_pos] = item.first;
60 indices[dst_pos] = item.second;
// flatten_indices is optional; when supplied it receives the element's offset
// relative to `input`, i.e. src_offset + index * stride.
61 if (flatten_indices !=
nullptr) {
62 flatten_indices[dst_pos] = src_offset + item.second * stride;
72 const TIndex* indices,
74 const TIndex src_offset,
75 const TIndex dst_offset,
// Fragment of the gradient scatter routine: each of the k entries read at
// src_pos is written to gradient[dst_offset + indices[src_pos] * stride].
// NOTE(review): the update of src_pos between iterations is not visible in
// this extract — presumably it advances on a missing line.
78 TIndex src_pos = src_offset;
79 for (
int i = 0; i < k; ++i) {
80 gradient[dst_offset + indices[src_pos] * stride] = values[src_pos];
// Forward pass of TopKOp: reads Input(0), resizes the value and index outputs
// so that dimension axis_ has extent k_, then iterates over
// prev_size x next_size slices of the input. Several lines of this function
// (including its tail) are missing from this extract.
87 template <
typename T,
class Context>
88 bool TopKOp<T, Context>::RunOnDevice() {
89 const auto& input = Input(0);
90 auto* values = Output(0);
91 auto* indices = Output(1);
// The third output is optional: flatten_indices stays null unless more than
// two outputs were requested.
92 auto* flatten_indices = OutputSize() > 2 ? Output(2) : nullptr;
94 const std::vector<TIndex>& input_dims = input.dims();
// axis_ is set to the last dimension here.
// NOTE(review): any condition guarding this assignment is on a line that is
// not visible in this extract.
96 axis_ = input_dims.size() - 1;
// axis_ must lie in [0, input_dims.size()).
98 CAFFE_ENFORCE_GE(axis_, 0);
99 CAFFE_ENFORCE_LT(axis_, input_dims.size());
103 "k argument should not be greater than the axis dim.");
// Outputs take the input's shape, except that dimension axis_ becomes k_.
105 std::vector<TIndex> output_dims = input_dims;
106 output_dims[axis_] = k_;
107 values->Resize(output_dims);
108 indices->Resize(output_dims);
// When present, flatten_indices is resized with the single extent
// indices->size().
109 if (flatten_indices !=
nullptr) {
110 flatten_indices->Resize(indices->size());
112 const T* input_data = input.template data<T>();
113 T* values_data = values->template mutable_data<T>();
114 TIndex* indices_data = indices->template mutable_data<TIndex>();
115 TIndex* flatten_indices_data = flatten_indices ==
nullptr 117 : flatten_indices->template mutable_data<TIndex>();
// prev_size and next_size are products of input dimensions (std::multiplies).
// next_size starts at the dimension after axis_; prev_size presumably covers
// the dimensions before axis_ — its begin iterator and both initial values
// are on lines not visible here.
119 const TIndex prev_size = std::accumulate(
121 input_dims.cbegin() + axis_,
123 std::multiplies<TIndex>());
124 const TIndex next_size = std::accumulate(
125 input_dims.cbegin() + axis_ + 1,
128 std::multiplies<TIndex>());
// Each outer iteration advances the source by one full
// input_dims[axis_] * next_size block and the destination by k_ * next_size.
129 const TIndex src_offset_stride = input_dims[axis_] * next_size;
130 const TIndex dst_offset_stride = k_ * next_size;
131 TIndex src_offset = 0;
132 TIndex dst_offset = 0;
133 for (TIndex i = 0; i < prev_size; ++i) {
134 for (TIndex j = 0; j < next_size; ++j) {
144 flatten_indices_data);
146 src_offset += src_offset_stride;
147 dst_offset += dst_offset_stride;
// Backward pass of TopKGradientOp: Input(0) is `values`, Input(1) is
// `indices`, Input(2) is the original input. Output(0) is resized to the
// original input's dims and zero-filled before the prev_size x next_size
// iteration. Several lines of this function (including its tail) are missing
// from this extract.
152 template <
typename T,
class Context>
153 bool TopKGradientOp<T, Context>::RunOnDevice() {
154 const auto& values = Input(0);
155 const auto& indices = Input(1);
156 const auto& original_input = Input(2);
157 auto* output = Output(0);
158 const std::vector<TIndex>& values_dims = values.dims();
159 const std::vector<TIndex>& origin_dims = original_input.dims();
// values and the original input must have the same number of dimensions.
160 CAFFE_ENFORCE_EQ(values_dims.size(), origin_dims.size());
161 output->Resize(origin_dims);
162 const T* values_data = values.template data<T>();
163 const TIndex* indices_data = indices.template data<TIndex>();
164 T* output_data = output->template mutable_data<T>();
// axis_ is set to the last dimension here.
// NOTE(review): any condition guarding this assignment is on a line that is
// not visible in this extract.
166 axis_ = values_dims.size() - 1;
// k is the extent of `values` along axis_.
// NOTE(review): stored in an int although the dims are TIndex; this narrows
// if TIndex is wider than int — confirm against TIndex's definition.
168 const int k = values_dims[axis_];
// Zero the whole output before any gradient entries are written.
169 math::Set<T, Context>(output->size(), T(0), output_data, &context_);
// prev_size: product of values_dims over [0, axis_); next_size: product
// starting at axis_ + 1. The initial values (and next_size's end iterator)
// are on lines not visible here.
170 const TIndex prev_size = std::accumulate(
171 values_dims.cbegin(),
172 values_dims.cbegin() + axis_,
174 std::multiplies<TIndex>());
175 const TIndex next_size = std::accumulate(
176 values_dims.cbegin() + axis_ + 1,
179 std::multiplies<TIndex>());
// Each outer iteration advances the source (values/indices) by k * next_size
// and the destination (output) by origin_dims[axis_] * next_size.
180 const TIndex src_offset_stride = k * next_size;
181 const TIndex dst_offset_stride = origin_dims[axis_] * next_size;
182 TIndex src_offset = 0;
183 TIndex dst_offset = 0;
184 for (TIndex i = 0; i < prev_size; ++i) {
185 for (TIndex j = 0; j < next_size; ++j) {
195 src_offset += src_offset_stride;
196 dst_offset += dst_offset_stride;
// Register the float / CPUContext instantiations under the operator names
// TopK and TopKGradient.
201 REGISTER_CPU_OPERATOR(TopK, TopKOp<float, CPUContext>);
202 REGISTER_CPU_OPERATOR(TopKGradient, TopKGradientOp<float, CPUContext>);
// Schema for TopK. The shape-inference lambda derives the output shapes from
// in[0] and the "k" argument; the documentation strings follow it.
204 OPERATOR_SCHEMA(TopK)
207 .TensorInferenceFunction([](
const OperatorDef& def,
208 const vector<TensorShape>& in) {
// Both outputs start out as copies of the input shape.
209 vector<TensorShape> out = {in[0], in[0]};
210 ArgumentHelper helper(def);
// "k" falls back to -1 when the argument is absent.
211 auto k = helper.GetSingleArgument(
"k", -1);
212 auto dims_size = in[0].dims_size();
// The last dimension of both outputs is replaced by k.
213 out[0].set_dims(dims_size - 1, k);
214 out[1].set_dims(dims_size - 1, k);
// NOTE(review): the index output is declared INT32 here while the operator
// writes TIndex values — confirm this matches TIndex's width.
215 out[1].set_data_type(TensorProto_DataType_INT32);
// Optional third output: a shape with one added dimension whose extent is the
// product of in[0]'s dims excluding the last, multiplied by a factor that is
// on a line not visible here (presumably k — confirm).
216 if (def.output_size() > 2) {
217 TensorShape flatten_indices_shape;
218 flatten_indices_shape.set_data_type(TensorProto_DataType_INT32);
219 flatten_indices_shape.add_dims(
221 in[0].dims().begin(),
222 in[0].dims().end() - 1,
224 std::multiplies<long>()) *
226 out.push_back(flatten_indices_shape);
231 Retrieve the top-K elements for the last dimension. Given an input tensor of 232 shape [a_1, a_2, ..., a_n, r] and integer argument k, return two outputs: 233 -Value tensor of shape [a_1, a_2, ..., a_n, k] which contains the values of 234 the top k elements along the last dimension 235 -Index tensor of shape [a_1, a_2, ..., a_n, k] which contains the indices 236 of the top k elements (original indices from the input tensor). 238 Given two equivalent values, this operator uses the indices along the last dim- 239 ension as a tiebreaker. That is, the element with the lower index will appear 242 .Input(0, "X",
"Tensor of shape [a_1, a_2, ..., a_n, r]")
246 "Tensor of shape [a_1, a_2, ..., a_n, k] containing" 247 " top K values from the input tensor")
251 "Tensor of shape [a_1, a_2, ..., a_n, k] containing" 252 " the corresponding input tensor indices for the top K values.")
256 "Tensor of shape [a_1 * a_2 * ... * a_n * k] containing the indices " 257 "into the flatten input")
258 .Arg(
"k",
"Number of top elements to retrieve");
// TopKGradient is declared with exactly three inputs and one output, matching
// the Input(0..2) / Output(0) accesses in TopKGradientOp::RunOnDevice.
260 OPERATOR_SCHEMA(TopKGradient).NumInputs(3).NumOutputs(1);
263 using GradientMakerBase::GradientMakerBase;
// Builds a single gradient operator definition with three inputs
// (GO(0), O(1), I(0)) and one output (GI(0)).
// NOTE(review): presumably these are the gradient of output 0, forward
// output 1 (indices) and forward input 0, lining up with values / indices /
// original_input in TopKGradientOp::RunOnDevice — confirm against
// GradientMakerBase. The operator type and name arguments are on lines not
// visible in this extract.
264 vector<OperatorDef> GetGradientDefs()
override {
265 return SingleGradientDef(
268 vector<string>{GO(0), O(1), I(0)},
269 vector<string>{GI(0)});
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...