Caffe2 - C++ API
A deep learning, cross-platform ML framework
normalize_op.h
1 #ifndef CAFFE2_OPERATORS_NORMALIZE_OP_H_
2 #define CAFFE2_OPERATORS_NORMALIZE_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 #include "caffe2/utils/math.h"
7 
8 namespace caffe2 {
9 
10 template <typename T, class Context>
11 class NormalizeOp final : public Operator<Context> {
12  public:
13  USE_OPERATOR_CONTEXT_FUNCTIONS;
14  NormalizeOp(const OperatorDef& def, Workspace* ws)
15  : Operator<Context>(def, ws) {}
16 
17  bool RunOnDevice() override {
18  const auto& x = Input(0);
19  auto* y = Output(0);
20  const auto* xData = x.template data<T>();
21  y->ResizeLike(x);
22  auto* yData = y->template mutable_data<T>();
23 
24  const auto canonical_axis = x.canonical_axis_index(
25  OperatorBase::GetSingleArgument<int>("axis", -1));
26  const int m = x.dim32(canonical_axis);
27  const int n = x.size() / m;
28  const int sf = x.size_from_dim(canonical_axis + 1);
29  DoNormalize(xData, yData, m, n, sf);
30  return true;
31  }
32 
33  private:
34  void
35  DoNormalize(const T* xData, T* yData, const int m, const int n, const int sf);
36 };
37 
38 template <typename T, class Context>
39 class NormalizeGradientOp final : public Operator<Context> {
40  public:
41  USE_OPERATOR_CONTEXT_FUNCTIONS;
42  NormalizeGradientOp(const OperatorDef& def, Workspace* ws)
43  : Operator<Context>(def, ws) {}
44 
45  bool RunOnDevice() override {
46  const auto& x = Input(0);
47  const auto& gOut = Input(GRAD_OUT);
48  auto* gIn = Output(GRAD_IN);
49  gIn->ResizeLike(gOut);
50 
51  const auto* xData = x.template data<T>();
52  const auto* gOutData = gOut.template data<T>();
53  auto* gInData = gIn->template mutable_data<T>();
54 
55  const auto canonical_axis = x.canonical_axis_index(
56  OperatorBase::GetSingleArgument<int>("axis", -1));
57  const int m = x.dim32(canonical_axis);
58  const int n = x.size() / m;
59  const int sf = x.size_from_dim(canonical_axis + 1);
60  DoNormalize(xData, gOutData, gInData, m, n, sf);
61  return true;
62  }
63 
64  private:
65  void DoNormalize(
66  const T* xData,
67  const T* gOutData,
68  T* gInData,
69  const int m,
70  const int n,
71  const int sf);
72 
73  INPUT_TAGS(INPUT, GRAD_OUT);
74  OUTPUT_TAGS(GRAD_IN);
75 };
76 
77 } // namespace caffe2
78 
79 #endif // CAFFE2_OPERATORS_NORMALIZE_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...