Caffe2 - C++ API
A deep learning, cross-platform ML framework
roi_pool_op.h
1 #ifndef ROI_POOL_OP_H_
2 #define ROI_POOL_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/logging.h"
6 #include "caffe2/core/operator.h"
7 #include "caffe2/utils/math.h"
8 
9 namespace caffe2 {
10 
11 template <typename T, class Context>
12 class RoIPoolOp final : public Operator<Context> {
13  public:
14  RoIPoolOp(const OperatorDef& operator_def, Workspace* ws)
15  : Operator<Context>(operator_def, ws),
16  is_test_(OperatorBase::GetSingleArgument<int>(OpSchema::Arg_IsTest, 0)),
17  order_(StringToStorageOrder(
18  OperatorBase::GetSingleArgument<string>("order", "NCHW"))),
19  pooled_height_(OperatorBase::GetSingleArgument<int>("pooled_h", 1)),
20  pooled_width_(OperatorBase::GetSingleArgument<int>("pooled_w", 1)),
21  spatial_scale_(
22  OperatorBase::GetSingleArgument<float>("spatial_scale", 1.)) {
23  CAFFE_ENFORCE(
24  (is_test_ && OutputSize() == 1) || (!is_test_ && OutputSize() == 2),
25  "Output size mismatch.");
26  CAFFE_ENFORCE_GT(spatial_scale_, 0);
27  CAFFE_ENFORCE_GT(pooled_height_, 0);
28  CAFFE_ENFORCE_GT(pooled_width_, 0);
29  CAFFE_ENFORCE_EQ(
30  order_, StorageOrder::NCHW, "Only NCHW order is supported right now.");
31  }
32  USE_OPERATOR_CONTEXT_FUNCTIONS;
33 
34  bool RunOnDevice() override;
35 
36  protected:
37  bool is_test_;
38  StorageOrder order_;
39  int pooled_height_;
40  int pooled_width_;
41  float spatial_scale_;
42 };
43 
44 template <typename T, class Context>
45 class RoIPoolGradientOp final : public Operator<Context> {
46  public:
47  RoIPoolGradientOp(const OperatorDef& def, Workspace* ws)
48  : Operator<Context>(def, ws),
49  spatial_scale_(
50  OperatorBase::GetSingleArgument<float>("spatial_scale", 1.)),
51  pooled_height_(OperatorBase::GetSingleArgument<int>("pooled_h", 1)),
52  pooled_width_(OperatorBase::GetSingleArgument<int>("pooled_w", 1)),
53  order_(StringToStorageOrder(
54  OperatorBase::GetSingleArgument<string>("order", "NCHW"))) {
55  CAFFE_ENFORCE_GT(spatial_scale_, 0);
56  CAFFE_ENFORCE_GT(pooled_height_, 0);
57  CAFFE_ENFORCE_GT(pooled_width_, 0);
58  CAFFE_ENFORCE_EQ(
59  order_, StorageOrder::NCHW, "Only NCHW order is supported right now.");
60  }
61  USE_OPERATOR_CONTEXT_FUNCTIONS;
62 
63  bool RunOnDevice() override {
64  CAFFE_NOT_IMPLEMENTED;
65  }
66 
67  protected:
68  float spatial_scale_;
69  int pooled_height_;
70  int pooled_width_;
71  StorageOrder order_;
72 };
73 
74 } // namespace caffe2
75 
76 #endif // ROI_POOL_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...