Caffe2 - C++ API
A deep learning, cross-platform ML framework
pool_op.h
1 #ifndef CAFFE2_OPERATORS_POOL_OP_H_
2 #define CAFFE2_OPERATORS_POOL_OP_H_
3 
4 #include "caffe2/core/common_omp.h"
5 #include "caffe2/core/context.h"
6 #include "caffe2/core/logging.h"
7 #include "caffe2/core/operator.h"
8 #include "caffe2/operators/conv_pool_op_base.h"
9 #include "caffe2/utils/math.h"
10 
11 namespace caffe2 {
12 
13 template <typename T, class Context, typename PoolType>
14 class PoolOp final : public ConvPoolOpBase<Context> {
15  public:
16  USE_CONV_POOL_BASE_FUNCTIONS(Context);
17  PoolOp(const OperatorDef& operator_def, Workspace* ws)
18  : ConvPoolOpBase<Context>(operator_def, ws) {
19  for (int i = 0; i < kernel_.size(); ++i) {
20  CAFFE_ENFORCE(
21  dilation_[i] == 1, "Pooling op does not support dilation right now.");
22  }
23  if (!global_pooling_) {
24  for (int i = 0; i < kernel_.size(); ++i) {
25  CAFFE_ENFORCE(
26  pads_[i] < kernel_[i] && pads_[i + kernel_.size()] < kernel_[i],
27  "Pad should be smaller than kernel.");
28  }
29  }
30  }
31  ~PoolOp() {}
32 
33  bool RunOnDeviceWithOrderNCHW() override;
34  bool RunOnDeviceWithOrderNHWC() override;
35 
36  // Input: X
37  // Output: Y
38 };
39 
40 template <typename T, class Context, class PoolType>
41 class PoolGradientOp final : public ConvPoolOpBase<Context> {
42  public:
43  USE_CONV_POOL_BASE_FUNCTIONS(Context);
44  PoolGradientOp(const OperatorDef& operator_def, Workspace* ws)
45  : ConvPoolOpBase<Context>(operator_def, ws) {}
46  ~PoolGradientOp() {}
47 
48  bool RunOnDeviceWithOrderNCHW() override;
49  bool RunOnDeviceWithOrderNHWC() override;
50 
51  // Input: X, Y, dY
52  // Output: dX
53 };
54 
55 } // namespace caffe2
56 
57 #endif // CAFFE2_OPERATORS_POOL_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...