Caffe2 - C++ API
A deep learning, cross platform ML framework
momentum_sgd_op.h
1 #pragma once
2 
3 #include "caffe2/core/operator.h"
4 
5 namespace caffe2 {
6 
/**
 * One momentum-SGD step over N contiguous float elements.
 *
 * Only lr[0] is read as the learning rate. For each element k:
 *   - classic momentum (nesterov == false):
 *       nm[k] = ng[k] = lr * g[k] + momentum * m[k]
 *   - Nesterov momentum (nesterov == true):
 *       nm[k] = momentum * m[k] + lr * g[k]
 *       ng[k] = (1 + momentum) * nm[k] - momentum * m[k]
 * If `param` is non-null, param[k] -= ng[k] afterwards.
 *
 * @param N        number of elements to process
 * @param g        input gradient
 * @param m        input momentum
 * @param ng       output adjusted gradient
 * @param nm       output momentum
 * @param lr       learning-rate buffer; only element 0 is used
 * @param momentum momentum coefficient
 * @param nesterov selects the Nesterov variant
 * @param param    parameters to update in place, or nullptr to skip
 * The trailing context pointer is unused by this implementation.
 */
template <typename Context>
void momentum_sgd_update(
    const int N,
    const float* g,
    const float* m,
    float* ng,
    float* nm,
    const float* lr,
    const float momentum,
    const bool nesterov,
    float* param,
    Context* /*context*/) {
  const float rate = lr[0];
  int k = 0;
  while (k < N) {
    // Snapshot the old momentum before nm[k] is written (nm may alias m).
    const float prev = m[k];
    if (nesterov) {
      const float updated = momentum * prev + rate * g[k];
      nm[k] = updated;
      ng[k] = (1 + momentum) * updated - momentum * prev;
    } else {
      const float step = rate * g[k] + momentum * prev;
      nm[k] = step;
      ng[k] = step;
    }
    if (param != nullptr) {
      param[k] -= ng[k];
    }
    ++k;
  }
}
37 
38 template <typename T, class Context>
39 class MomentumSGDOp final : public Operator<Context> {
40  public:
41  USE_OPERATOR_CONTEXT_FUNCTIONS;
42  MomentumSGDOp(const OperatorDef& operator_def, Workspace* ws)
43  : Operator<Context>(operator_def, ws),
44  momentum_(OperatorBase::GetSingleArgument<T>("momentum", 0.0)),
45  nesterov_(OperatorBase::GetSingleArgument<int>("nesterov", 0)) {}
46 
47  bool RunOnDevice() override {
48  // Iter live on the CPU
49  CAFFE_ENFORCE(OperatorBase::InputIsType<Tensor<Context>>(GRAD));
50  CAFFE_ENFORCE(OperatorBase::InputIsType<Tensor<Context>>(MOMENTUM));
51  CAFFE_ENFORCE(Input(LR).size() == 1);
52  CAFFE_ENFORCE(Input(GRAD).size() == Input(MOMENTUM).size());
53  Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
54  Output(OUTPUT_MOMENTUM)->ResizeLike(Input(MOMENTUM));
55 
56  momentum_sgd_update<Context>(
57  Input(GRAD).size(),
58  Input(GRAD).template data<T>(),
59  Input(MOMENTUM).template data<T>(),
60  Output(OUTPUT_GRAD)->template mutable_data<T>(),
61  Output(OUTPUT_MOMENTUM)->template mutable_data<T>(),
62  Input(LR).template data<T>(),
63  momentum_,
64  nesterov_,
65  NULL,
66  &context_);
67  return true;
68  }
69 
70  protected:
71  T momentum_{0.9};
72  bool nesterov_;
73  INPUT_TAGS(GRAD, MOMENTUM, LR);
74  OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MOMENTUM);
75 };
76 
77 template <typename T, class Context>
78 class MomentumSGDUpdateOp final : public Operator<Context> {
79  public:
80  USE_OPERATOR_CONTEXT_FUNCTIONS;
81  MomentumSGDUpdateOp(const OperatorDef& operator_def, Workspace* ws)
82  : Operator<Context>(operator_def, ws),
83  momentum_(OperatorBase::GetSingleArgument<T>("momentum", 0.0)),
84  nesterov_(OperatorBase::GetSingleArgument<int>("nesterov", 0)) {}
85 
86  bool RunOnDevice() override {
87  // Iter live on the CPU
88  CAFFE_ENFORCE(OperatorBase::InputIsType<Tensor<Context>>(GRAD));
89  CAFFE_ENFORCE(OperatorBase::InputIsType<Tensor<Context>>(MOMENTUM));
90  CAFFE_ENFORCE_EQ(Input(LR).size(), 1);
91  CAFFE_ENFORCE_EQ(Input(GRAD).size(), Input(MOMENTUM).size());
92  Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
93  Output(OUTPUT_MOMENTUM)->ResizeLike(Input(MOMENTUM));
94 
95  momentum_sgd_update<Context>(
96  Input(GRAD).size(),
97  Input(GRAD).template data<T>(),
98  Input(MOMENTUM).template data<T>(),
99  Output(OUTPUT_GRAD)->template mutable_data<T>(),
100  Output(OUTPUT_MOMENTUM)->template mutable_data<T>(),
101  Input(LR).template data<T>(),
102  momentum_,
103  nesterov_,
104  Output(OUTPUT_PARAM)->template mutable_data<T>(),
105  &context_);
106  return true;
107  }
108 
109  protected:
110  T momentum_{0.9};
111  bool nesterov_;
112  INPUT_TAGS(GRAD, MOMENTUM, LR, PARAM);
113  OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MOMENTUM, OUTPUT_PARAM);
114 };
115 
116 template <typename T, class Context>
117 class SparseMomentumSGDUpdateOp final : public Operator<Context> {
118  public:
119  USE_OPERATOR_CONTEXT_FUNCTIONS;
120  SparseMomentumSGDUpdateOp(const OperatorDef& operator_def, Workspace* ws)
121  : Operator<Context>(operator_def, ws),
122  momentum_(OperatorBase::GetSingleArgument<T>("momentum", 0.0)),
123  nesterov_(OperatorBase::GetSingleArgument<int>("nesterov", 0)) {}
124 
125  bool RunOnDevice() override {
126  // Resize [potentially] out-of-place blobs
127  Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
128 
129  // Enforce shapes
130  CAFFE_ENFORCE_EQ(Input(LR).size(), 1);
131  CAFFE_ENFORCE_EQ(Input(PARAM).size(), Input(MOMENTUM).size());
132  CAFFE_ENFORCE_EQ(Input(PARAM).size_from_dim(1),
133  Input(GRAD).size_from_dim(Input(INDICES).ndim()));
134 
136  this, Input(INDICES));
137  }
138 
139  template <typename SIndex>
140  bool DoRunWithType() {
141  auto block_size = Input(PARAM).size() / Input(PARAM).dim(0);
142  auto n = Input(GRAD).size() / block_size;
143 
144  const auto* gradIn = Input(GRAD).template data<T>();
145  const auto* momentumIn = Input(MOMENTUM).template data<T>();
146  const auto* lr = Input(LR).template data<T>();
147  const auto* paramIn = Input(PARAM).template data<T>();
148  const auto* indices = Input(INDICES).template data<SIndex>();
149 
150  auto* gradOut = Output(OUTPUT_GRAD)->template mutable_data<T>();
151  auto* momentumOut = Output(OUTPUT_MOMENTUM)->template mutable_data<T>();
152  auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<T>();
153 
154  for (auto i = 0; i < n; ++i) {
155  auto idx = indices[i];
156  auto offsetI = i * block_size;
157  auto offsetIdx = idx * block_size;
158 
159  CAFFE_ENFORCE(offsetIdx + block_size <= Input(PARAM).size());
160  CAFFE_ENFORCE(offsetI + block_size <= Input(GRAD).size());
161 
162  momentum_sgd_update<Context>(
163  block_size,
164  gradIn + offsetI,
165  momentumIn + offsetIdx,
166  gradOut + offsetI,
167  momentumOut + offsetIdx,
168  lr,
169  momentum_,
170  nesterov_,
171  paramOut + offsetIdx,
172  &context_);
173  }
174  return true;
175  }
176 
177  protected:
178  T momentum_;
179  bool nesterov_;
180  INPUT_TAGS(GRAD, MOMENTUM, LR, PARAM, INDICES);
181  OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MOMENTUM, OUTPUT_PARAM);
182 };
183 }
Tensor is the basic class in Caffe2 that stores a contiguous memory with its shape information...
Definition: tensor.h:93
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...