Caffe2 - C++ API
A deep learning, cross-platform ML framework
adam_op.h
1 #pragma once
2 
3 #include "caffe2/core/operator.h"
4 
5 namespace caffe2 {
6 
/**
 * Computes one Adam step for N elements and writes the step (not the new
 * weights) to `ng`.
 *
 * For each element k:
 *   nm[k] = m[k] * beta1 + g[k] * (1 - beta1)
 *   nv[k] = v[k] * beta2 + g[k]^2 * (1 - beta2)
 *   ng[k] = lr[0] * correction * nm[k] / (sqrt(nv[k]) + eps_hat)
 *
 * @param N          number of elements to process.
 * @param g          gradient values (N entries).
 * @param m          current first-moment estimates (N entries).
 * @param v          current second-moment estimates (N entries).
 * @param ng         output: computed step (N entries).
 * @param nm         output: updated first moments (N entries).
 * @param nv         output: updated second moments (N entries).
 * @param beta1      decay factor for the first moment.
 * @param beta2      decay factor for the second moment.
 * @param eps_hat    constant added to sqrt(nv[k]) in the denominator.
 * @param correction scalar multiplier applied to every step.
 * @param lr         learning rate; only lr[0] is read.
 * @param context    unused by this implementation.
 */
template <typename Context>
void adam_update(
    int N,
    const float* g,
    const float* m,
    const float* v,
    float* ng,
    float* nm,
    float* nv,
    float beta1,
    float beta2,
    float eps_hat,
    float correction,
    const float* lr,
    Context* /*context*/) {
  for (int k = 0; k < N; ++k) {
    const float grad = g[k];

    const float first_moment = m[k] * beta1 + grad * (1 - beta1);
    nm[k] = first_moment;

    const float second_moment = v[k] * beta2 + grad * grad * (1 - beta2);
    nv[k] = second_moment;

    const float denominator = std::sqrt(second_moment) + eps_hat;
    ng[k] = lr[0] * correction * first_moment / denominator;
  }
}
29 
/**
 * Applies one Adam step to N weights.
 *
 * For each element idx:
 *   nm[idx] = m[idx] * beta1 + g[idx] * (1 - beta1)
 *   nv[idx] = v[idx] * beta2 + g[idx]^2 * (1 - beta2)
 *   nw[idx] = w[idx] + lr[0] * correction * nm[idx] / (sqrt(nv[idx]) + eps_hat)
 *
 * The step is added to the weight, so a negative lr[0] moves against the
 * gradient. w[idx] is read right before nw[idx] is written, so nw may alias w
 * (and likewise nm/m, nv/v).
 *
 * @param N          number of elements to process.
 * @param w          current weights (N entries).
 * @param g          gradient values (N entries).
 * @param m          current first-moment estimates (N entries).
 * @param v          current second-moment estimates (N entries).
 * @param nw         output: updated weights (N entries).
 * @param nm         output: updated first moments (N entries).
 * @param nv         output: updated second moments (N entries).
 * @param beta1      decay factor for the first moment.
 * @param beta2      decay factor for the second moment.
 * @param eps_hat    constant added to sqrt(nv[idx]) in the denominator.
 * @param correction scalar multiplier applied to every step.
 * @param lr         learning rate; only lr[0] is read.
 * @param context    unused by this implementation.
 */
template <typename Context>
void adam_compute(
    int N,
    const float* w,
    const float* g,
    const float* m,
    const float* v,
    float* nw,
    float* nm,
    float* nv,
    float beta1,
    float beta2,
    float eps_hat,
    float correction,
    const float* lr,
    Context* /*context*/) {
  for (int idx = 0; idx < N; ++idx) {
    const float grad = g[idx];

    const float moment1 = m[idx] * beta1 + grad * (1 - beta1);
    nm[idx] = moment1;

    const float moment2 = v[idx] * beta2 + grad * grad * (1 - beta2);
    nv[idx] = moment2;

    const float step =
        lr[0] * correction * moment1 / (std::sqrt(moment2) + eps_hat);
    nw[idx] = w[idx] + step;
  }
}
54 
55 template <typename T, class Context>
56 class AdamOp final : public Operator<Context> {
57  public:
58  USE_OPERATOR_CONTEXT_FUNCTIONS;
59  AdamOp(const OperatorDef& operator_def, Workspace* ws)
60  : Operator<Context>(operator_def, ws),
61  beta1_(OperatorBase::GetSingleArgument<float>("beta1", 0.9f)),
62  beta2_(OperatorBase::GetSingleArgument<float>("beta2", 0.999f)),
63  epsilon_(OperatorBase::GetSingleArgument<float>("epsilon", 1e-5f)) {}
64  bool RunOnDevice() override {
65  // Iter live on the CPU
66  CAFFE_ENFORCE(OperatorBase::InputIsType<TensorCPU>(ITER));
67  CAFFE_ENFORCE(Input(LR).size() == 1);
68  CAFFE_ENFORCE(Input(GRAD).size() == Input(PARAM).size());
69  CAFFE_ENFORCE(Input(GRAD).size() == Input(MOMENT_1).size());
70  CAFFE_ENFORCE(Input(GRAD).size() == Input(MOMENT_2).size());
71  Output(OUTPUT_PARAM)->ResizeLike(Input(PARAM));
72  Output(OUTPUT_MOMENT_1)->ResizeLike(Input(MOMENT_1));
73  Output(OUTPUT_MOMENT_2)->ResizeLike(Input(MOMENT_2));
74 
75  const auto iter =
76  OperatorBase::Input<TensorCPU>(ITER).template data<int64_t>()[0];
77 
78  const auto t = iter + 1;
79  const auto correction =
80  std::sqrt(T(1.) - std::pow(beta2_, t)) / (T(1.) - std::pow(beta1_, t));
81  adam_compute<Context>(
82  Input(GRAD).size(),
83  Input(PARAM).template data<T>(),
84  Input(GRAD).template data<T>(),
85  Input(MOMENT_1).template data<T>(),
86  Input(MOMENT_2).template data<T>(),
87  Output(OUTPUT_PARAM)->template mutable_data<T>(),
88  Output(OUTPUT_MOMENT_1)->template mutable_data<T>(),
89  Output(OUTPUT_MOMENT_2)->template mutable_data<T>(),
90  beta1_,
91  beta2_,
92  epsilon_,
93  correction,
94  Input(LR).template data<T>(),
95  &context_);
96  return true;
97  }
98 
99  protected:
100  T beta1_{0.9};
101  T beta2_{0.999};
102  T epsilon_{1e-8};
103  INPUT_TAGS(PARAM, MOMENT_1, MOMENT_2, GRAD, LR, ITER);
104  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, OUTPUT_MOMENT_2);
105 };
106 
107 template <typename T, class Context>
108 class SparseAdamOp final : public Operator<Context> {
109  public:
110  USE_OPERATOR_CONTEXT_FUNCTIONS;
111  SparseAdamOp(const OperatorDef& operator_def, Workspace* ws)
112  : Operator<Context>(operator_def, ws),
113  beta1_(OperatorBase::GetSingleArgument<float>("beta1", 0.9f)),
114  beta2_(OperatorBase::GetSingleArgument<float>("beta2", 0.999f)),
115  epsilon_(OperatorBase::GetSingleArgument<float>("epsilon", 1e-5f)) {}
116 
117  bool RunOnDevice() override {
118  // Enforce shapes
119  CAFFE_ENFORCE_EQ(Input(PARAM).size(), Input(MOMENT_1).size());
120  CAFFE_ENFORCE_EQ(Input(PARAM).size(), Input(MOMENT_2).size());
121  CAFFE_ENFORCE_EQ(
122  Input(PARAM).size_from_dim(1),
123  Input(GRAD).size_from_dim(Input(INDICES).ndim()));
124  CAFFE_ENFORCE_EQ(Input(LR).size(), 1);
125 
127  this, Input(INDICES));
128  }
129 
130  template <typename SIndex>
131  bool DoRunWithType() {
132  const auto* lr = Input(LR).template data<T>();
133  const auto iter =
134  OperatorBase::Input<TensorCPU>(ITER).template data<int64_t>()[0];
135 
136  const auto t = iter + 1;
137  const auto correction =
138  std::sqrt(T(1.) - std::pow(beta2_, t)) / (T(1.) - std::pow(beta1_, t));
139 
140  auto block_size = Input(PARAM).size() / Input(PARAM).dim(0);
141  auto n = Input(GRAD).size() / block_size;
142 
143  const auto* paramIn = Input(PARAM).template data<T>();
144  const auto* indices = Input(INDICES).template data<SIndex>();
145  const auto* gradIn = Input(GRAD).template data<T>();
146  const auto* moment1In = Input(MOMENT_1).template data<T>();
147  const auto* moment2In = Input(MOMENT_2).template data<T>();
148  auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<T>();
149  auto* moment1Out = Output(OUTPUT_MOMENT_1)->template mutable_data<T>();
150  auto* moment2Out = Output(OUTPUT_MOMENT_2)->template mutable_data<T>();
151 
152  for (auto i = 0; i < n; ++i) {
153  auto idx = indices[i];
154 
155  if (block_size == 1) {
156  float gi = gradIn[i];
157  float mi = moment1Out[idx] =
158  moment1In[idx] * beta1_ + gi * (1 - beta1_);
159  float vi = moment2Out[idx] =
160  moment2In[idx] * beta2_ + gi * gi * (1 - beta2_);
161  paramOut[idx] =
162  paramIn[idx] + lr[0] * correction * mi / (std::sqrt(vi) + epsilon_);
163 
164  } else {
165  auto offsetI = i * block_size;
166  auto offsetIdx = idx * block_size;
167 
168 #ifndef NDEBUG
169  CAFFE_ENFORCE_GE(
170  Input(PARAM).size(),
171  block_size + offsetIdx,
172  this->debug_def().input(PARAM),
173  ", out of bound, idx:",
174  idx,
175  " for input i:",
176  i,
177  " and block size:",
178  block_size);
179  CAFFE_ENFORCE_GE(
180  Input(GRAD).size(),
181  block_size + offsetI,
182  this->debug_def().input(GRAD),
183  ", out of bound idx, idx:",
184  idx,
185  " for input i:",
186  i);
187 #endif
188 
189  adam_compute(
190  block_size,
191  paramIn + offsetIdx,
192  gradIn + offsetI,
193  moment1In + offsetIdx,
194  moment2In + offsetIdx,
195  paramOut + offsetIdx,
196  moment1Out + offsetIdx,
197  moment2Out + offsetIdx,
198  beta1_,
199  beta2_,
200  epsilon_,
201  correction,
202  lr,
203  &context_);
204  }
205  }
206  return true;
207  }
208 
209  protected:
210  T beta1_;
211  T beta2_;
212  T epsilon_;
213  INPUT_TAGS(PARAM, MOMENT_1, MOMENT_2, INDICES, GRAD, LR, ITER);
214  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, OUTPUT_MOMENT_2);
215 };
216 
217 template <typename T, class Context>
218 class RowWiseSparseAdamOp final : public Operator<Context> {
219  public:
220  USE_OPERATOR_CONTEXT_FUNCTIONS;
221  RowWiseSparseAdamOp(const OperatorDef& operator_def, Workspace* ws)
222  : Operator<Context>(operator_def, ws),
223  beta1_(OperatorBase::GetSingleArgument<float>("beta1", 0.9f)),
224  beta2_(OperatorBase::GetSingleArgument<float>("beta2", 0.999f)),
225  epsilon_(OperatorBase::GetSingleArgument<float>("epsilon", 1e-5f)) {}
226 
227  bool RunOnDevice() override {
228  // Enforce shapes
229  CAFFE_ENFORCE_EQ(Input(PARAM).size(), Input(MOMENT_1).size());
230  CAFFE_ENFORCE_EQ(Input(PARAM).dims()[0], Input(MOMENT_2).size());
231  CAFFE_ENFORCE_EQ(
232  Input(PARAM).size_from_dim(1),
233  Input(GRAD).size_from_dim(Input(INDICES).ndim()));
234  CAFFE_ENFORCE_EQ(Input(LR).size(), 1);
235 
237  this, Input(INDICES));
238  }
239 
240  template <typename SIndex>
241  bool DoRunWithType() {
242  const auto* lr = Input(LR).template data<T>();
243  const auto iter =
244  OperatorBase::Input<TensorCPU>(ITER).template data<int64_t>()[0];
245 
246  const auto t = iter + 1;
247  const auto correction =
248  std::sqrt(T(1.) - std::pow(beta2_, t)) / (T(1.) - std::pow(beta1_, t));
249 
250  auto block_size = Input(PARAM).size() / Input(PARAM).dim(0);
251  auto n = Input(GRAD).size() / block_size;
252 
253  const auto* paramIn = Input(PARAM).template data<T>();
254  const auto* indices = Input(INDICES).template data<SIndex>();
255  const auto* gradIn = Input(GRAD).template data<T>();
256  const auto* moment1In = Input(MOMENT_1).template data<T>();
257  const auto* moment2In = Input(MOMENT_2).template data<T>();
258  auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<T>();
259  auto* moment1Out = Output(OUTPUT_MOMENT_1)->template mutable_data<T>();
260  auto* moment2Out = Output(OUTPUT_MOMENT_2)->template mutable_data<T>();
261 
262  for (auto i = 0; i < n; ++i) {
263  auto idx = indices[i];
264 
265  if (block_size == 1) {
266  float gi = gradIn[i];
267  float mi = moment1Out[idx] =
268  moment1In[idx] * beta1_ + gi * (1 - beta1_);
269  float vi = moment2Out[idx] =
270  moment2In[idx] * beta2_ + gi * gi * (1 - beta2_);
271  paramOut[idx] =
272  paramIn[idx] + lr[0] * correction * mi / (std::sqrt(vi) + epsilon_);
273 
274  } else {
275  auto offsetI = i * block_size;
276  auto offsetIdx = idx * block_size;
277 
278 #ifndef NDEBUG
279  CAFFE_ENFORCE_GE(
280  Input(PARAM).size(),
281  block_size + offsetIdx,
282  this->debug_def().input(PARAM),
283  ", out of bound, idx:",
284  idx,
285  " for input i:",
286  i,
287  " and block size:",
288  block_size);
289  CAFFE_ENFORCE_GE(
290  Input(GRAD).size(),
291  block_size + offsetI,
292  this->debug_def().input(GRAD),
293  ", out of bound idx, idx:",
294  idx,
295  " for input i:",
296  i);
297 #endif
298 
299  const float* w = paramIn + offsetIdx;
300  const float* g = gradIn + offsetI;
301  const float* m1 = moment1In + offsetIdx;
302  const float* m2 = moment2In + idx;
303  float* nw = paramOut + offsetIdx;
304  float* nm1 = moment1Out + offsetIdx;
305  float* nm2 = moment2Out + idx;
306 
307  float m2_sum = 0.;
308  for (auto j = 0; j < block_size; ++j) {
309  float gj = g[j];
310  m2_sum += gj * gj;
311  }
312  float vi = nm2[0] =
313  m2[0] * beta2_ + (m2_sum / block_size) * (1 - beta2_);
314  for (auto j = 0; j < block_size; ++j) {
315  float mi = nm1[j] = m1[j] * beta1_ + g[j] * (1 - beta1_);
316  nw[j] = w[j] + lr[0] * correction * mi / (std::sqrt(vi) + epsilon_);
317  }
318  }
319  }
320  return true;
321  }
322 
323  protected:
324  T beta1_;
325  T beta2_;
326  T epsilon_;
327  INPUT_TAGS(PARAM, MOMENT_1, MOMENT_2, INDICES, GRAD, LR, ITER);
328  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, OUTPUT_MOMENT_2);
329 };
330 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...