Caffe2 - C++ API
A deep learning, cross platform ML framework
local_response_normalization_op.h
1 #ifndef CAFFE2_OPERATORS_LOCAL_RESPONSE_NORMALIZATION_OP_H_
2 #define CAFFE2_OPERATORS_LOCAL_RESPONSE_NORMALIZATION_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/logging.h"
6 #include "caffe2/core/operator.h"
7 #include "caffe2/utils/math.h"
8 
9 namespace caffe2 {
10 
11 template <typename T, class Context>
12 class LRNOpBase : public Operator<Context> {
13  public:
14  USE_OPERATOR_CONTEXT_FUNCTIONS;
15  LRNOpBase(const OperatorDef& operator_def, Workspace* ws)
16  : Operator<Context>(operator_def, ws),
17  size_(OperatorBase::GetSingleArgument<int>("size", 0)),
18  alpha_(OperatorBase::GetSingleArgument<float>("alpha", 0)),
19  beta_(OperatorBase::GetSingleArgument<float>("beta", 0)),
20  bias_(OperatorBase::GetSingleArgument<float>("bias", 1)),
21  order_(StringToStorageOrder(
22  OperatorBase::GetSingleArgument<string>("order", "NCHW"))),
23  pre_pad_((size_ - 1) / 2) {
24  DCHECK_GT(size_, 0);
25  DCHECK_EQ(size_ % 2, 1);
26  DCHECK_GT(alpha_, 0);
27  DCHECK_GT(beta_, 0);
28  }
29 
30  bool RunOnDevice() override {
31  switch (order_) {
32  case StorageOrder::NHWC:
33  return RunOnDeviceWithOrderNHWC();
34  case StorageOrder::NCHW:
35  return RunOnDeviceWithOrderNCHW();
36  default:
37  LOG(FATAL) << "Unknown storage order: " << order_;
38  }
39  // To suppress old compiler warnings
40  return true;
41  }
42 
43  virtual bool RunOnDeviceWithOrderNCHW() = 0;
44  virtual bool RunOnDeviceWithOrderNHWC() = 0;
45 
46  protected:
47  const int size_;
48  const float alpha_;
49  const float beta_;
50  const float bias_;
51  const StorageOrder order_;
52  const int pre_pad_;
53  // Input: X; Output: Y, scale.
54 };
55 
56 template <typename T, class Context>
57 class LRNOp final : public LRNOpBase<T, Context> {
58  public:
59  USE_OPERATOR_CONTEXT_FUNCTIONS;
60  LRNOp(const OperatorDef& operator_def, Workspace* ws)
61  : LRNOpBase<T, Context>(operator_def, ws) {}
62 
63  bool RunOnDeviceWithOrderNCHW() override;
64  bool RunOnDeviceWithOrderNHWC() override;
65 
66  protected:
67  // Input: X; Output: Y, scale.
68  OUTPUT_TAGS(OUTPUT, SCALE);
69  Tensor<Context>* scale_ = nullptr;
70  Tensor<Context> local_scale_tensor_;
71 };
72 
73 template <typename T, class Context>
74 class LRNGradientOp final : public LRNOpBase<T, Context> {
75  public:
76  USE_OPERATOR_CONTEXT_FUNCTIONS;
77  LRNGradientOp(const OperatorDef& operator_def, Workspace* ws)
78  : LRNOpBase<T, Context>(operator_def, ws) {}
79 
80  bool RunOnDeviceWithOrderNCHW() override;
81  bool RunOnDeviceWithOrderNHWC() override;
82 
83  protected:
84  // Input: X, Y, scale, dY; Output: dX
85  INPUT_TAGS(INPUT, OUTPUT, SCALE, OUTPUT_GRAD);
86  Tensor<Context>* scale_ = nullptr;
87  Tensor<Context> local_scale_tensor_;
88 };
89 
90 } // namespace caffe2
91 
92 #endif // CAFFE2_OPERATORS_LOCAL_RESPONSE_NORMALIZATION_OP_H_
Tensor is the basic class in Caffe2 that stores a contiguous memory with its shape information...
Definition: tensor.h:93
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...