Caffe2 - C++ API
A deep learning, cross platform ML framework
generate_proposals_op.h
1 #ifndef CAFFE2_OPERATORS_GENERATE_PROPOSALS_OP_H_
2 #define CAFFE2_OPERATORS_GENERATE_PROPOSALS_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 #include "caffe2/utils/eigen_utils.h"
7 #include "caffe2/utils/math.h"
8 
9 namespace caffe2 {
10 
11 namespace utils {
12 
13 // A sub tensor view
14 template <class T>
16  public:
17  ConstTensorView(const T* data, const std::vector<int>& dims)
18  : data_(data), dims_(dims) {}
19 
20  int ndim() const {
21  return dims_.size();
22  }
23  const std::vector<int>& dims() const {
24  return dims_;
25  }
26  int dim(int i) const {
27  DCHECK_LE(i, dims_.size());
28  return dims_[i];
29  }
30  const T* data() const {
31  return data_;
32  }
33  size_t size() const {
34  return std::accumulate(
35  dims_.begin(), dims_.end(), 1, std::multiplies<size_t>());
36  }
37 
38  private:
39  const T* data_ = nullptr;
40  std::vector<int> dims_;
41 };
42 
43 // Generate a list of bounding box shapes for each pixel based on predefined
44 // bounding box shapes 'anchors'.
45 // anchors: predefined anchors, size(A, 4)
46 // Return: all_anchors_vec: (H * W, A * 4)
47 // Need to reshape to (H * W * A, 4) to match the format in python
48 ERMatXf ComputeAllAnchors(
49  const TensorCPU& anchors,
50  int height,
51  int width,
52  float feat_stride);
53 
54 } // namespace utils
55 
56 // C++ implementation of GenerateProposalsOp
57 // Generate bounding box proposals for Faster RCNN. The propoasls are generated
58 // for a list of images based on image score 'score', bounding box
59 // regression result 'deltas' as well as predefined bounding box shapes
60 // 'anchors'. Greedy non-maximum suppression is applied to generate the
61 // final bounding boxes.
62 // Reference: detectron/lib/ops/generate_proposals.py
63 template <class Context>
64 class GenerateProposalsOp final : public Operator<Context> {
65  public:
66  USE_OPERATOR_CONTEXT_FUNCTIONS;
67  GenerateProposalsOp(const OperatorDef& operator_def, Workspace* ws)
68  : Operator<Context>(operator_def, ws),
69  spatial_scale_(
70  OperatorBase::GetSingleArgument<float>("spatial_scale", 1.0 / 16)),
71  feat_stride_(1.0 / spatial_scale_),
72  rpn_pre_nms_topN_(
73  OperatorBase::GetSingleArgument<int>("pre_nms_topN", 6000)),
74  rpn_post_nms_topN_(
75  OperatorBase::GetSingleArgument<int>("post_nms_topN", 300)),
76  rpn_nms_thresh_(
77  OperatorBase::GetSingleArgument<float>("nms_thresh", 0.7f)),
78  rpn_min_size_(OperatorBase::GetSingleArgument<float>("min_size", 16)),
79  correct_transform_coords_(OperatorBase::GetSingleArgument<bool>(
80  "correct_transform_coords",
81  false)) {}
82 
84 
85  bool RunOnDevice() override;
86 
87  // Generate bounding box proposals for a given image
88  // im_info: [height, width, im_scale]
89  // all_anchors: (H * W * A, 4)
90  // bbox_deltas_tensor: (4 * A, H, W)
91  // scores_tensor: (A, H, W)
92  // out_boxes: (n, 5)
93  // out_probs: n
94  void ProposalsForOneImage(
95  const Eigen::Array3f& im_info,
96  const Eigen::Map<const ERMatXf>& all_anchors,
97  const utils::ConstTensorView<float>& bbox_deltas_tensor,
98  const utils::ConstTensorView<float>& scores_tensor,
99  ERArrXXf* out_boxes,
100  EArrXf* out_probs) const;
101 
102  protected:
103  // spatial_scale_ must be declared before feat_stride_
104  float spatial_scale_{1.0};
105  float feat_stride_{1.0};
106 
107  // RPN_PRE_NMS_TOP_N
108  int rpn_pre_nms_topN_{6000};
109  // RPN_POST_NMS_TOP_N
110  int rpn_post_nms_topN_{300};
111  // RPN_NMS_THRESH
112  float rpn_nms_thresh_{0.7};
113  // RPN_MIN_SIZE
114  float rpn_min_size_{16};
115  // Correct bounding box transform coordates, see bbox_transform() in boxes.py
116  // Set to true to match the detectron code, set to false for backward
117  // compatibility
118  bool correct_transform_coords_{false};
119 };
120 
121 } // namespace caffe2
122 
123 #endif // CAFFE2_OPERATORS_GENERATE_PROPOSALS_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...