3 #include "caffe2/core/operator.h" 7 template <
typename Context>
// NOTE(review): this is an excerpt with lines omitted; the function name and
// its parameter list are not visible here. What is visible is an element-wise
// loop over N entries that refreshes both moment buffers and emits a step.
22 for (
auto i = 0; i < N; ++i) {
// `gi` is defined on a line not shown in this excerpt (presumably g[i] -- confirm).
// First moment: exponential blend of the previous m[i] and gi, weighted by
// beta1; the result is stored to nm[i] and kept in mi.
24 float mi = nm[i] = m[i] * beta1 + gi * (1 - beta1);
// Second moment: same blend applied to gi * gi, weighted by beta2; the result
// is stored to nv[i] and kept in vi.
25 float vi = nv[i] = v[i] * beta2 + gi * gi * (1 - beta2);
// The scaled step lr[0] * correction * mi / (sqrt(vi) + eps_hat) is written
// to ng[i]. eps_hat is added after the square root, keeping the denominator
// away from zero when vi is zero.
26 ng[i] = lr[0] * correction * mi / (std::sqrt(vi) + eps_hat);
30 template <
typename Context>
// NOTE(review): excerpt with lines omitted; the function name and parameter
// list are not visible here. The moment updates below match the preceding
// kernel line for line; the difference visible here is that the step is held
// in a local `ng` instead of being stored to an array element.
46 for (
auto i = 0; i < N; ++i) {
// `gi` is defined on a line not shown in this excerpt (presumably g[i] -- confirm).
// First moment: blend of m[i] and gi weighted by beta1, stored to nm[i].
48 float mi = nm[i] = m[i] * beta1 + gi * (1 - beta1);
// Second moment: blend of v[i] and gi * gi weighted by beta2, stored to nv[i].
49 float vi = nv[i] = v[i] * beta2 + gi * gi * (1 - beta2);
// Local step value; the line that consumes it is not shown
// (presumably it is added to the parameter -- confirm).
50 float ng = lr[0] * correction * mi / (std::sqrt(vi) + eps_hat);
// NOTE(review): excerpt of an operator class templated on <T, Context>; the
// class header, constructor head, member declarations and closing braces are
// on lines not shown here. The visible code applies a dense Adam step through
// adam_compute.
55 template <
typename T,
class Context>
58 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Hyper-parameters are read as float operator arguments, with these defaults
// when the argument is absent: beta1 = 0.9, beta2 = 0.999, epsilon = 1e-5.
61 beta1_(OperatorBase::GetSingleArgument<float>(
"beta1", 0.9f)),
62 beta2_(OperatorBase::GetSingleArgument<float>(
"beta2", 0.999f)),
63 epsilon_(OperatorBase::GetSingleArgument<float>(
"epsilon", 1e-5f)) {}
64 bool RunOnDevice()
override {
// Input validation: ITER must be held as a TensorCPU, LR must contain exactly
// one element, and GRAD must have the same element count as PARAM, MOMENT_1
// and MOMENT_2.
66 CAFFE_ENFORCE(OperatorBase::InputIsType<TensorCPU>(ITER));
67 CAFFE_ENFORCE(Input(LR).size() == 1);
68 CAFFE_ENFORCE(Input(GRAD).size() == Input(PARAM).size());
69 CAFFE_ENFORCE(Input(GRAD).size() == Input(MOMENT_1).size());
70 CAFFE_ENFORCE(Input(GRAD).size() == Input(MOMENT_2).size());
// Each output is resized to match its corresponding input.
71 Output(OUTPUT_PARAM)->ResizeLike(Input(PARAM));
72 Output(OUTPUT_MOMENT_1)->ResizeLike(Input(MOMENT_1));
73 Output(OUTPUT_MOMENT_2)->ResizeLike(Input(MOMENT_2));
// Reads element 0 of the ITER tensor as int64_t. The left-hand side of this
// statement is on an omitted line (presumably `const auto iter =` -- confirm;
// `iter` is used just below).
76 OperatorBase::Input<TensorCPU>(ITER).
template data<int64_t>()[0];
// t is the stored iteration count plus one.
78 const auto t = iter + 1;
// correction = sqrt(1 - beta2^t) / (1 - beta1^t), folded into the step size.
79 const auto correction =
80 std::sqrt(T(1.) - std::pow(beta2_, t)) / (T(1.) - std::pow(beta1_, t));
// Kernel call. Some arguments are on omitted lines; the visible ones are the
// four input data pointers, the three output data pointers, and the LR pointer.
81 adam_compute<Context>(
83 Input(PARAM).template data<T>(),
84 Input(GRAD).template data<T>(),
85 Input(MOMENT_1).template data<T>(),
86 Input(MOMENT_2).template data<T>(),
87 Output(OUTPUT_PARAM)->template mutable_data<T>(),
88 Output(OUTPUT_MOMENT_1)->template mutable_data<T>(),
89 Output(OUTPUT_MOMENT_2)->template mutable_data<T>(),
94 Input(LR).template data<T>(),
// Positional order of the operator's inputs and outputs.
103 INPUT_TAGS(PARAM, MOMENT_1, MOMENT_2, GRAD, LR, ITER);
104 OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, OUTPUT_MOMENT_2);
// NOTE(review): excerpt of an operator class templated on <T, Context>; the
// class header, constructor head, several macro/call lines and closing braces
// are omitted. The visible code applies an Adam step only to the slices of
// PARAM selected by INDICES, with GRAD supplying one slice per index.
107 template <
typename T,
class Context>
110 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Hyper-parameters read as float operator arguments; defaults are
// beta1 = 0.9, beta2 = 0.999, epsilon = 1e-5.
113 beta1_(OperatorBase::GetSingleArgument<float>(
"beta1", 0.9f)),
114 beta2_(OperatorBase::GetSingleArgument<float>(
"beta2", 0.999f)),
115 epsilon_(OperatorBase::GetSingleArgument<float>(
"epsilon", 1e-5f)) {}
117 bool RunOnDevice()
override {
// Both moment tensors must have the same element count as PARAM.
119 CAFFE_ENFORCE_EQ(Input(PARAM).size(), Input(MOMENT_1).size());
120 CAFFE_ENFORCE_EQ(Input(PARAM).size(), Input(MOMENT_2).size());
// The next two lines are the operands of a check whose opening line is omitted
// (presumably CAFFE_ENFORCE_EQ -- confirm): the per-row size of PARAM against
// the size of GRAD past the leading INDICES.ndim() dimensions.
122 Input(PARAM).size_from_dim(1),
123 Input(GRAD).size_from_dim(Input(INDICES).ndim()));
124 CAFFE_ENFORCE_EQ(Input(LR).size(), 1);
// Tail of a call whose head is omitted; presumably dispatches to
// DoRunWithType<SIndex> based on the element type of INDICES -- confirm.
127 this, Input(INDICES));
130 template <
typename SIndex>
131 bool DoRunWithType() {
132 const auto* lr = Input(LR).template data<T>();
// Element 0 of ITER read as int64_t; the left-hand side is on an omitted line
// (presumably `const auto iter =` -- confirm).
134 OperatorBase::Input<TensorCPU>(ITER).
template data<int64_t>()[0];
136 const auto t = iter + 1;
// correction = sqrt(1 - beta2^t) / (1 - beta1^t).
137 const auto correction =
138 std::sqrt(T(1.) - std::pow(beta2_, t)) / (T(1.) - std::pow(beta1_, t));
// block_size: elements per first-dimension row of PARAM.
// n: number of such rows contained in GRAD.
// NOTE(review): no guard against Input(PARAM).dim(0) == 0 is visible in this
// excerpt before the division -- confirm one exists upstream.
140 auto block_size = Input(PARAM).size() / Input(PARAM).dim(0);
141 auto n = Input(GRAD).size() / block_size;
143 const auto* paramIn = Input(PARAM).template data<T>();
144 const auto* indices = Input(INDICES).template data<SIndex>();
145 const auto* gradIn = Input(GRAD).template data<T>();
146 const auto* moment1In = Input(MOMENT_1).template data<T>();
147 const auto* moment2In = Input(MOMENT_2).template data<T>();
148 auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<T>();
149 auto* moment1Out = Output(OUTPUT_MOMENT_1)->template mutable_data<T>();
150 auto* moment2Out = Output(OUTPUT_MOMENT_2)->template mutable_data<T>();
// NOTE(review): `auto i = 0` deduces int, while n is derived from tensor sizes;
// the comparison mixes types and i would overflow if n exceeded INT_MAX.
152 for (
auto i = 0; i < n; ++i) {
153 auto idx = indices[i];
// Scalar path for one-element rows: the gradient is addressed by i, while the
// parameter and both moments are addressed by idx.
155 if (block_size == 1) {
156 float gi = gradIn[i];
157 float mi = moment1Out[idx] =
158 moment1In[idx] * beta1_ + gi * (1 - beta1_);
159 float vi = moment2Out[idx] =
160 moment2In[idx] * beta2_ + gi * gi * (1 - beta2_);
// Right-hand side of an assignment whose target is on an omitted line
// (presumably paramOut[idx] -- confirm).
162 paramIn[idx] + lr[0] * correction * mi / (std::sqrt(vi) + epsilon_);
// General path: element offsets of the i-th gradient row and the idx-th
// parameter row.
165 auto offsetI = i * block_size;
166 auto offsetIdx = idx * block_size;
// Fragments of two checks whose opening lines are omitted; judging by the
// message text they reject row ranges that run out of bounds for PARAM and
// GRAD respectively -- confirm against the full source.
171 block_size + offsetIdx,
172 this->debug_def().input(PARAM),
173 ", out of bound, idx:",
181 block_size + offsetI,
182 this->debug_def().input(GRAD),
183 ", out of bound idx, idx:",
// Arguments of a call whose head is omitted (presumably adam_compute --
// confirm): every pointer is advanced to the idx-th row.
193 moment1In + offsetIdx,
194 moment2In + offsetIdx,
195 paramOut + offsetIdx,
196 moment1Out + offsetIdx,
197 moment2Out + offsetIdx,
// Positional order of the operator's inputs and outputs.
213 INPUT_TAGS(PARAM, MOMENT_1, MOMENT_2, INDICES, GRAD, LR, ITER);
214 OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, OUTPUT_MOMENT_2);
// NOTE(review): excerpt of an operator class templated on <T, Context>; the
// class header, constructor head, several macro/call lines and closing braces
// are omitted. The visible code is a sparse Adam step in which MOMENT_2 holds
// a single value per first-dimension row of PARAM rather than one per element.
217 template <
typename T,
class Context>
220 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Hyper-parameters read as float operator arguments; defaults are
// beta1 = 0.9, beta2 = 0.999, epsilon = 1e-5.
223 beta1_(OperatorBase::GetSingleArgument<float>(
"beta1", 0.9f)),
224 beta2_(OperatorBase::GetSingleArgument<float>(
"beta2", 0.999f)),
225 epsilon_(OperatorBase::GetSingleArgument<float>(
"epsilon", 1e-5f)) {}
227 bool RunOnDevice()
override {
// MOMENT_1 matches PARAM element for element; MOMENT_2 must have exactly as
// many elements as PARAM has first-dimension rows.
229 CAFFE_ENFORCE_EQ(Input(PARAM).size(), Input(MOMENT_1).size());
230 CAFFE_ENFORCE_EQ(Input(PARAM).dims()[0], Input(MOMENT_2).size());
// Operands of a check whose opening line is omitted (presumably
// CAFFE_ENFORCE_EQ -- confirm): per-row size of PARAM against the size of GRAD
// past the leading INDICES.ndim() dimensions.
232 Input(PARAM).size_from_dim(1),
233 Input(GRAD).size_from_dim(Input(INDICES).ndim()));
234 CAFFE_ENFORCE_EQ(Input(LR).size(), 1);
// Tail of a call whose head is omitted; presumably dispatches to
// DoRunWithType<SIndex> based on the element type of INDICES -- confirm.
237 this, Input(INDICES));
240 template <
typename SIndex>
241 bool DoRunWithType() {
242 const auto* lr = Input(LR).template data<T>();
// Element 0 of ITER read as int64_t; the left-hand side is on an omitted line
// (presumably `const auto iter =` -- confirm).
244 OperatorBase::Input<TensorCPU>(ITER).
template data<int64_t>()[0];
246 const auto t = iter + 1;
// correction = sqrt(1 - beta2^t) / (1 - beta1^t).
247 const auto correction =
248 std::sqrt(T(1.) - std::pow(beta2_, t)) / (T(1.) - std::pow(beta1_, t));
// block_size: elements per first-dimension row of PARAM.
// n: number of such rows contained in GRAD.
// NOTE(review): no guard against Input(PARAM).dim(0) == 0 is visible in this
// excerpt before the division -- confirm one exists upstream.
250 auto block_size = Input(PARAM).size() / Input(PARAM).dim(0);
251 auto n = Input(GRAD).size() / block_size;
253 const auto* paramIn = Input(PARAM).template data<T>();
254 const auto* indices = Input(INDICES).template data<SIndex>();
255 const auto* gradIn = Input(GRAD).template data<T>();
256 const auto* moment1In = Input(MOMENT_1).template data<T>();
257 const auto* moment2In = Input(MOMENT_2).template data<T>();
258 auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<T>();
259 auto* moment1Out = Output(OUTPUT_MOMENT_1)->template mutable_data<T>();
260 auto* moment2Out = Output(OUTPUT_MOMENT_2)->template mutable_data<T>();
// NOTE(review): `auto i = 0` deduces int, while n is derived from tensor sizes;
// the comparison mixes types and i would overflow if n exceeded INT_MAX.
262 for (
auto i = 0; i < n; ++i) {
263 auto idx = indices[i];
// Scalar path for one-element rows: with a single element per row, the
// per-row second moment at idx is also the per-element one.
265 if (block_size == 1) {
266 float gi = gradIn[i];
267 float mi = moment1Out[idx] =
268 moment1In[idx] * beta1_ + gi * (1 - beta1_);
269 float vi = moment2Out[idx] =
270 moment2In[idx] * beta2_ + gi * gi * (1 - beta2_);
// Right-hand side of an assignment whose target is on an omitted line
// (presumably paramOut[idx] -- confirm).
272 paramIn[idx] + lr[0] * correction * mi / (std::sqrt(vi) + epsilon_);
// General path: element offsets of the i-th gradient row and the idx-th
// parameter row.
275 auto offsetI = i * block_size;
276 auto offsetIdx = idx * block_size;
// Fragments of two checks whose opening lines are omitted; judging by the
// message text they reject row ranges that run out of bounds for PARAM and
// GRAD respectively -- confirm against the full source.
281 block_size + offsetIdx,
282 this->debug_def().input(PARAM),
283 ", out of bound, idx:",
291 block_size + offsetI,
292 this->debug_def().input(GRAD),
293 ", out of bound idx, idx:",
// Row views. The parameter, gradient and first moment are offset by whole
// rows; the second moment is offset by idx alone because it stores one value
// per row.
// NOTE(review): these pointers are spelled `float` although the source
// pointers come from data<T>() / mutable_data<T>(); this only lines up when
// T is float.
299 const float* w = paramIn + offsetIdx;
300 const float* g = gradIn + offsetI;
301 const float* m1 = moment1In + offsetIdx;
302 const float* m2 = moment2In + idx;
303 float* nw = paramOut + offsetIdx;
304 float* nm1 = moment1Out + offsetIdx;
305 float* nm2 = moment2Out + idx;
// First pass over the row; its body is omitted (presumably it accumulates
// m2_sum from the squared gradient entries -- confirm).
308 for (
auto j = 0; j < block_size; ++j) {
// Right-hand side of the row-level second-moment update: the previous value
// blended with the row mean m2_sum / block_size, weighted by beta2_. The
// assignment target is on an omitted line (presumably `float vi = nm2[0] =`
// -- confirm; `vi` is used in the loop below).
313 m2[0] * beta2_ + (m2_sum / block_size) * (1 - beta2_);
// Second pass: per-element first-moment update, then the parameter step using
// the single row-level vi for every element of the row.
314 for (
auto j = 0; j < block_size; ++j) {
315 float mi = nm1[j] = m1[j] * beta1_ + g[j] * (1 - beta1_);
316 nw[j] = w[j] + lr[0] * correction * mi / (std::sqrt(vi) + epsilon_);
// Positional order of the operator's inputs and outputs.
327 INPUT_TAGS(PARAM, MOMENT_1, MOMENT_2, INDICES, GRAD, LR, ITER);
328 OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, OUTPUT_MOMENT_2);
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...