Caffe2 - C++ API
A deep learning, cross-platform ML framework
max_pool_with_index.h
1 #ifndef CAFFE2_OPERATORS_MAX_POOL_WITH_INDEX_H_
2 #define CAFFE2_OPERATORS_MAX_POOL_WITH_INDEX_H_
3 
4 #include <cfloat>
5 #include "caffe2/core/context.h"
6 #include "caffe2/core/context_gpu.h"
7 #include "caffe2/core/logging.h"
8 #include "caffe2/core/operator.h"
9 #include "caffe2/operators/conv_pool_op_base.h"
10 #include "caffe2/operators/pool_op.h"
11 #include "caffe2/utils/math.h"
12 
13 namespace caffe2 {
14 
15 class MaxPoolWithIndexOp final : public ConvPoolOpBase<CUDAContext> {
16  public:
17  USE_CONV_POOL_BASE_FUNCTIONS(CUDAContext);
18  MaxPoolWithIndexOp(const OperatorDef& operator_def, Workspace* ws)
19  : ConvPoolOpBase<CUDAContext>(operator_def, ws) {}
20  ~MaxPoolWithIndexOp() {}
21 
22  template <typename T>
23  bool DoRunWithType();
24 
25  bool RunOnDevice() override;
26 
27  // Input: X
28  // Output: Y, mask
29 };
30 
31 class MaxPoolWithIndexGradientOp final : public ConvPoolOpBase<CUDAContext> {
32  public:
33  USE_CONV_POOL_BASE_FUNCTIONS(CUDAContext);
34  MaxPoolWithIndexGradientOp(const OperatorDef& operator_def, Workspace* ws)
35  : ConvPoolOpBase<CUDAContext>(operator_def, ws) {}
37 
38  template <typename T>
39  bool DoRunWithType();
40 
41  bool RunOnDevice() override;
42 
43  // Input: X, dY, mask
44  // Output: dX
45 };
46 
47 }; // namespace caffe2
48 
49 #endif // CAFFE2_OPERATORS_MAX_POOL_WITH_INDEX_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...