// Caffe2 - C++ API
// A deep learning, cross platform ML framework
// box_with_nms_limit_op.cc
1 #include "box_with_nms_limit_op.h"
2 #include "caffe2/utils/eigen_utils.h"
3 #include "generate_proposals_op_util_nms.h"
4 
5 #ifdef CAFFE2_USE_MKL
6 #include "caffe2/mkl/operators/operator_fallback_mkl.h"
7 #endif // CAFFE2_USE_MKL
8 
9 namespace caffe2 {
10 
11 namespace {
12 
13 template <class Derived, class Func>
14 vector<int> filter_with_indices(
15  const Eigen::ArrayBase<Derived>& array,
16  const vector<int>& indices,
17  const Func& func) {
18  vector<int> ret;
19  for (auto& cur : indices) {
20  if (func(array[cur])) {
21  ret.push_back(cur);
22  }
23  }
24  return ret;
25 }
26 
27 } // namespace
28 
// Applies per-class NMS (hard NMS or SoftNMS) to the input detections,
// optionally caps the total number of detections per image
// (`detections_per_im_`), and emits the surviving scores, boxes, and class
// ids. Handles a batch of images via an optional `batch_splits` input.
template <>
bool BoxWithNMSLimitOp<CPUContext>::RunOnDevice() {
  const auto& tscores = Input(0);
  const auto& tboxes = Input(1);
  auto* out_scores = Output(0);
  auto* out_boxes = Output(1);
  auto* out_classes = Output(2);

  // tscores: (num_boxes, num_classes), 0 for background
  // A 4-D input is accepted only if the trailing two dims are singleton
  // (i.e. effectively 2-D, e.g. NCHW with H = W = 1).
  if (tscores.ndim() == 4) {
    CAFFE_ENFORCE_EQ(tscores.dim(2), 1, tscores.dim(2));
    CAFFE_ENFORCE_EQ(tscores.dim(3), 1, tscores.dim(3));
  } else {
    CAFFE_ENFORCE_EQ(tscores.ndim(), 2, tscores.ndim());
  }
  CAFFE_ENFORCE(tscores.template IsType<float>(), tscores.meta().name());
  // tboxes: (num_boxes, num_classes * 4)
  if (tboxes.ndim() == 4) {
    CAFFE_ENFORCE_EQ(tboxes.dim(2), 1, tboxes.dim(2));
    CAFFE_ENFORCE_EQ(tboxes.dim(3), 1, tboxes.dim(3));
  } else {
    CAFFE_ENFORCE_EQ(tboxes.ndim(), 2, tboxes.ndim());
  }
  CAFFE_ENFORCE(tboxes.template IsType<float>(), tboxes.meta().name());

  int N = tscores.dim(0);
  int num_classes = tscores.dim(1);

  // Scores and boxes must describe the same set of RoIs, 4 coords per class.
  CAFFE_ENFORCE_EQ(N, tboxes.dim(0));
  CAFFE_ENFORCE_EQ(num_classes * 4, tboxes.dim(1));

  // Default: treat the whole input as a single image's worth of boxes.
  int batch_size = 1;
  vector<float> batch_splits_default(1, tscores.dim(0));
  const float* batch_splits_data = batch_splits_default.data();
  if (InputSize() > 2) {
    // tscores and tboxes have items from multiple images in a batch. Get the
    // corresponding batch splits from input.
    const auto& tbatch_splits = Input(2);
    CAFFE_ENFORCE_EQ(tbatch_splits.ndim(), 1);
    batch_size = tbatch_splits.dim(0);
    batch_splits_data = tbatch_splits.data<float>();
  }
  Eigen::Map<const EArrXf> batch_splits(batch_splits_data, batch_size);
  // The per-image counts must account for every input box exactly once.
  CAFFE_ENFORCE_EQ(batch_splits.sum(), N);

  // Outputs start empty and are grown with Extend() per batch item below.
  out_scores->Resize(0);
  out_boxes->Resize(0, 4);
  out_classes->Resize(0);

  // Optional bookkeeping outputs: kept indices and per-class keep counts.
  TensorCPU* out_keeps = nullptr;
  TensorCPU* out_keeps_size = nullptr;
  if (OutputSize() > 4) {
    out_keeps = Output(4);
    out_keeps_size = Output(5);
    out_keeps->Resize(0);
    out_keeps_size->Resize(batch_size, num_classes);
  }

  vector<int> total_keep_per_batch(batch_size);
  // `offset` is the running row index of the current image's first box
  // within the flat (N, ...) input tensors.
  int offset = 0;
  for (int b = 0; b < batch_splits.size(); ++b) {
    int num_boxes = batch_splits(b);
    // Views over this image's slice of the inputs (no copies).
    Eigen::Map<const ERArrXXf> scores(
        tscores.data<float>() + offset * tscores.dim(1),
        num_boxes,
        tscores.dim(1));
    Eigen::Map<const ERArrXXf> boxes(
        tboxes.data<float>() + offset * tboxes.dim(1),
        num_boxes,
        tboxes.dim(1));

    // To store updated scores if SoftNMS is used
    ERArrXXf soft_nms_scores(num_boxes, tscores.dim(1));
    // keeps[j] holds the per-image row indices kept for class j.
    vector<vector<int>> keeps(num_classes);

    // Perform nms to each class
    // skip j = 0, because it's the background class
    int total_keep_count = 0;
    for (int j = 1; j < num_classes; j++) {
      auto cur_scores = scores.col(j);
      // Pre-filter candidates by the score threshold before NMS.
      auto inds = utils::GetArrayIndices(cur_scores > score_thres_);
      // Boxes for class j live in columns [j*4, j*4+4).
      auto cur_boxes = boxes.block(0, j * 4, boxes.rows(), 4);

      if (soft_nms_enabled_) {
        // SoftNMS rescales scores instead of discarding overlaps outright;
        // the updated scores are written into soft_nms_scores.col(j).
        auto cur_soft_nms_scores = soft_nms_scores.col(j);
        keeps[j] = utils::soft_nms_cpu(
            &cur_soft_nms_scores,
            cur_boxes,
            cur_scores,
            inds,
            soft_nms_sigma_,
            nms_thres_,
            soft_nms_min_score_thres_,
            soft_nms_method_);
      } else {
        // Hard NMS expects candidates sorted by descending score.
        std::sort(
            inds.data(),
            inds.data() + inds.size(),
            [&cur_scores](int lhs, int rhs) {
              return cur_scores(lhs) > cur_scores(rhs);
            });
        keeps[j] = utils::nms_cpu(cur_boxes, cur_scores, inds, nms_thres_);
      }
      total_keep_count += keeps[j].size();
    }

    if (soft_nms_enabled_) {
      // Re-map scores to the updated SoftNMS scores
      // (placement new is the supported way to re-point an Eigen::Map,
      // since Map has no assignment that changes the mapped storage).
      new (&scores) Eigen::Map<const ERArrXXf>(
          soft_nms_scores.data(),
          soft_nms_scores.rows(),
          soft_nms_scores.cols());
    }

    // Limit to max_per_image detections *over all classes*
    if (detections_per_im_ > 0 && total_keep_count > detections_per_im_) {
      // merge all scores together and sort (ascending)
      auto get_all_scores_sorted = [&scores, &keeps, total_keep_count]() {
        EArrXf ret(total_keep_count);

        int ret_idx = 0;
        for (int i = 1; i < keeps.size(); i++) {
          auto& cur_keep = keeps[i];
          auto cur_scores = scores.col(i);
          auto cur_ret = ret.segment(ret_idx, cur_keep.size());
          utils::GetSubArray(cur_scores, utils::AsEArrXt(keeps[i]), &cur_ret);
          ret_idx += cur_keep.size();
        }

        std::sort(ret.data(), ret.data() + ret.size());

        return ret;
      };

      // Compute image thres based on all classes
      // (the detections_per_im_-th largest score, via the ascending sort).
      auto all_scores_sorted = get_all_scores_sorted();
      DCHECK_GT(all_scores_sorted.size(), detections_per_im_);
      auto image_thresh =
          all_scores_sorted[all_scores_sorted.size() - detections_per_im_];

      total_keep_count = 0;
      // filter results with image_thresh
      for (int j = 1; j < num_classes; j++) {
        auto& cur_keep = keeps[j];
        auto cur_scores = scores.col(j);
        keeps[j] = filter_with_indices(
            cur_scores, cur_keep, [&image_thresh](float sc) {
              return sc >= image_thresh;
            });
        total_keep_count += keeps[j].size();
      }
    }
    total_keep_per_batch[b] = total_keep_count;

    // Write results
    // Append this image's detections to the growing outputs. The second
    // Extend argument (50) looks like a growth-headroom hint — confirm
    // against Tensor::Extend's contract.
    int cur_start_idx = out_scores->dim(0);
    out_scores->Extend(total_keep_count, 50, &context_);
    out_boxes->Extend(total_keep_count, 50, &context_);
    out_classes->Extend(total_keep_count, 50, &context_);

    // Copy kept scores/boxes class by class; cur_out_idx walks the newly
    // extended region starting at cur_start_idx.
    int cur_out_idx = 0;
    for (int j = 1; j < num_classes; j++) {
      auto cur_scores = scores.col(j);
      auto cur_boxes = boxes.block(0, j * 4, boxes.rows(), 4);
      auto& cur_keep = keeps[j];
      Eigen::Map<EArrXf> cur_out_scores(
          out_scores->mutable_data<float>() + cur_start_idx + cur_out_idx,
          cur_keep.size());
      Eigen::Map<ERArrXXf> cur_out_boxes(
          out_boxes->mutable_data<float>() + (cur_start_idx + cur_out_idx) * 4,
          cur_keep.size(),
          4);
      Eigen::Map<EArrXf> cur_out_classes(
          out_classes->mutable_data<float>() + cur_start_idx + cur_out_idx,
          cur_keep.size());

      utils::GetSubArray(
          cur_scores, utils::AsEArrXt(cur_keep), &cur_out_scores);
      utils::GetSubArrayRows(
          cur_boxes, utils::AsEArrXt(cur_keep), &cur_out_boxes);
      // Class ids are emitted as floats to match the output tensor dtype.
      for (int k = 0; k < cur_keep.size(); k++) {
        cur_out_classes[k] = static_cast<float>(j);
      }

      cur_out_idx += cur_keep.size();
    }

    if (out_keeps) {
      // Flatten the kept per-image indices (grouped by class, including the
      // always-empty background class 0) and record per-class counts.
      out_keeps->Extend(total_keep_count, 50, &context_);

      Eigen::Map<EArrXi> out_keeps_arr(
          out_keeps->mutable_data<int>() + cur_start_idx, total_keep_count);
      Eigen::Map<EArrXi> cur_out_keeps_size(
          out_keeps_size->mutable_data<int>() + b * num_classes, num_classes);

      cur_out_idx = 0;
      for (int j = 0; j < num_classes; j++) {
        out_keeps_arr.segment(cur_out_idx, keeps[j].size()) =
            utils::AsEArrXt(keeps[j]);
        cur_out_keeps_size[j] = keeps[j].size();
        cur_out_idx += keeps[j].size();
      }
    }

    offset += num_boxes;
  }

  if (OutputSize() > 3) {
    // Output per-image counts of surviving detections (as float, matching
    // the float dtype of the batch_splits input).
    auto* batch_splits_out = Output(3);
    batch_splits_out->Resize(batch_size);
    Eigen::Map<EArrXf> batch_splits_out_map(
        batch_splits_out->mutable_data<float>(), batch_size);
    batch_splits_out_map =
        Eigen::Map<const EArrXi>(total_keep_per_batch.data(), batch_size)
            .cast<float>();
  }

  return true;
}
248 
249 namespace {
250 
251 REGISTER_CPU_OPERATOR(BoxWithNMSLimit, BoxWithNMSLimitOp<CPUContext>);
252 
253 #ifdef CAFFE2_HAS_MKL_DNN
254 REGISTER_MKL_OPERATOR(
255  BoxWithNMSLimit,
256  mkl::MKLFallbackOp<BoxWithNMSLimitOp<CPUContext>>);
257 #endif // CAFFE2_HAS_MKL_DNN
258 
259 OPERATOR_SCHEMA(BoxWithNMSLimit)
260  .NumInputs(2, 3)
261  .NumOutputs(3, 6)
262  .SetDoc(R"DOC(
263 Apply NMS to each class (except background) and limit the number of
264 returned boxes.
265 )DOC")
266  .Arg("score_thresh", "(float) TEST.SCORE_THRESH")
267  .Arg("nms", "(float) TEST.NMS")
268  .Arg("detections_per_im", "(int) TEST.DEECTIONS_PER_IM")
269  .Arg("soft_nms_enabled", "(bool) TEST.SOFT_NMS.ENABLED")
270  .Arg("soft_nms_method", "(string) TEST.SOFT_NMS.METHOD")
271  .Arg("soft_nms_sigma", "(float) TEST.SOFT_NMS.SIGMA")
272  .Arg(
273  "soft_nms_min_score_thres",
274  "(float) Lower bound on updated scores to discard boxes")
275  .Input(0, "scores", "Scores, size (count, num_classes)")
276  .Input(
277  1,
278  "boxes",
279  "Bounding box for each class, size (count, num_classes * 4)")
280  .Input(
281  2,
282  "batch_splits",
283  "Tensor of shape (batch_size) with each element denoting the number "
284  "of RoIs/boxes belonging to the corresponding image in batch. "
285  "Sum should add up to total count of scores/boxes.")
286  .Output(0, "scores", "Filtered scores, size (n)")
287  .Output(1, "boxes", "Filtered boxes, size (n, 4)")
288  .Output(2, "classes", "Class id for each filtered score/box, size (n)")
289  .Output(
290  3,
291  "batch_splits",
292  "Output batch splits for scores/boxes after applying NMS")
293  .Output(4, "keeps", "Optional filtered indices, size (n)")
294  .Output(
295  5,
296  "keeps_size",
297  "Optional number of filtered indices per class, size (num_classes)");
298 
299 SHOULD_NOT_DO_GRADIENT(BoxWithNMSLimit);
300 
301 } // namespace
302 } // namespace caffe2
// (Truncated trailing text, apparently from a different source file:
// "A global dictionary that holds information about what Caffe2 modules
// have been loaded in the current ...")