// Fragment of a dense update helper templated on Context. Its signature and
// several body lines are not visible in this excerpt.
// NOTE(review): presumably this is the adagrad_update<Context> invoked by the
// operator further down in the file -- confirm against the full source.
3 #include "caffe2/core/operator.h" 7 template <
typename Context>
// Element-wise pass over N entries.
19 for (
auto i = 0; i < N; ++i) {
// New accumulator value: previous accumulator h[i] scaled by decay, plus the
// squared gradient term. It is written to nh[i] and kept in hi for the step.
// NOTE(review): gi is defined on a line not shown here -- presumably g[i]; confirm.
21 float hi = nh[i] = decay * h[i] + gi * gi;
// New weight: the scaled gradient is ADDED to w[i]; the scale is lr[0] divided
// by (sqrt(hi) + epsilon).
// NOTE(review): because the step is added, a descent update presumably relies
// on lr[0] being negative -- verify against callers.
22 nw[i] = w[i] + lr[0] * gi / (std::sqrt(hi) + epsilon);
// Fragment of an operator class templated on value type T and Context (the
// class header, constructor head and closing lines are not visible here).
// RunOnDevice reads the PARAM, MOMENT_1, GRAD and LR inputs and fills
// OUTPUT_PARAM and OUTPUT_MOMENT_1 through adagrad_update<Context>.
26 template <
typename T,
class Context>
29 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor initializers: the "epsilon" argument defaults to 1e-5f and the
// "decay" argument defaults to 1.0f when not supplied.
32 epsilon_(OperatorBase::GetSingleArgument<T>(
"epsilon", 1e-5f)),
33 decay_(OperatorBase::GetSingleArgument<T>(
"decay", 1.0f)) {}
35 bool RunOnDevice()
override {
// GRAD must report the same size() as MOMENT_1 and as PARAM.
36 CAFFE_ENFORCE(Input(GRAD).size() == Input(MOMENT_1).size());
37 CAFFE_ENFORCE(Input(GRAD).size() == Input(PARAM).size());
// Each output is resized to match its corresponding input.
38 Output(OUTPUT_PARAM)->ResizeLike(Input(PARAM));
39 Output(OUTPUT_MOMENT_1)->ResizeLike(Input(MOMENT_1));
// Pass raw data pointers to the update helper. Some arguments of this call
// sit on lines that are not part of this excerpt.
40 adagrad_update<Context>(
42 Input(PARAM).template data<T>(),
43 Input(GRAD).template data<T>(),
44 Input(MOMENT_1).template data<T>(),
45 Output(OUTPUT_PARAM)->template mutable_data<T>(),
46 Output(OUTPUT_MOMENT_1)->template mutable_data<T>(),
49 Input(LR).template data<T>(),
// Names for the input and output slots, in positional order.
57 INPUT_TAGS(PARAM, MOMENT_1, GRAD, LR);
58 OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1);
// Fragment of an operator class templated on T and Context (class header and
// several lines are not visible here). It updates only the entries of PARAM
// and MOMENT_1 selected by INDICES, using the matching slices of GRAD.
61 template <
typename T,
class Context>
64 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor initializer: the "epsilon" argument defaults to 1e-5f.
67 epsilon_(OperatorBase::GetSingleArgument<float>(
"epsilon", 1e-5f)) {}
69 bool RunOnDevice()
override {
// PARAM and MOMENT_1 must have equal size(); LR must hold exactly one value.
71 CAFFE_ENFORCE_EQ(Input(PARAM).size(), Input(MOMENT_1).size());
72 CAFFE_ENFORCE_EQ(Input(LR).size(), 1);
// NOTE(review): the next two lines are the arguments of a check whose head is
// not shown -- presumably an equality check between the per-row size of PARAM
// and the per-index slice size of GRAD; confirm.
74 Input(PARAM).size_from_dim(1),
75 Input(GRAD).size_from_dim(Input(INDICES).ndim()));
// NOTE(review): tail of a call whose head is not shown -- presumably a
// dispatch on the element type of INDICES into DoRunWithType<SIndex>; confirm.
78 this, Input(INDICES));
81 template <
typename SIndex>
82 bool DoRunWithType() {
// Raw pointers to the learning rate, indices, gradient, and the current and
// updated parameter / moment buffers.
83 const auto* lr = Input(LR).template data<T>();
84 const auto* indices = Input(INDICES).template data<SIndex>();
85 const auto* gradIn = Input(GRAD).template data<T>();
86 const auto* paramIn = Input(PARAM).template data<T>();
87 const auto* momentIn = Input(MOMENT_1).template data<T>();
88 auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<T>();
89 auto* momentOut = Output(OUTPUT_MOMENT_1)->template mutable_data<T>();
// n is the number of index entries.
91 auto n = Input(INDICES).size();
// Number of gradient values per index entry.
// NOTE(review): this divides by n; any n == 0 guard would be on lines not
// visible in this excerpt -- confirm one exists.
96 auto block_size = Input(GRAD).size() / n;
97 for (
auto i = 0; i < n; ++i) {
98 auto idx = indices[i];
// Scalar path: one gradient value per index, so update the single element at
// idx directly. The accumulator gains gi * gi (no decay factor here) and the
// scaled gradient is added to the parameter.
99 if (block_size == 1) {
100 float gi = gradIn[i];
101 float hi = momentOut[idx] = momentIn[idx] + gi * gi;
102 paramOut[idx] = paramIn[idx] + lr[0] * gi / (std::sqrt(hi) + epsilon_);
// Start of the i-th slice in GRAD and of the idx-th slice in PARAM / MOMENT_1.
// NOTE(review): presumably the block_size != 1 path; the branch line itself is
// not visible here.
104 auto offsetI = i * block_size;
105 auto offsetIdx = idx * block_size;
// NOTE(review): the lines below are argument fragments of checks whose heads
// are not shown -- they look like bounds checks on the end of the PARAM slice
// and of the GRAD slice, with the input names used in the message; confirm.
110 block_size + offsetIdx,
111 this->debug_def().input(PARAM),
112 ", out of bound, idx:",
120 block_size + offsetI,
121 this->debug_def().input(GRAD),
122 ", out of bound idx, idx:",
// NOTE(review): slice pointers passed to a call whose head is not shown --
// presumably the per-block update helper; confirm.
131 momentIn + offsetIdx,
132 paramOut + offsetIdx,
133 momentOut + offsetIdx,
// Names for the input and output slots, in positional order.
145 INPUT_TAGS(PARAM, MOMENT_1, INDICES, GRAD, LR);
146 OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1);
// Fragment of an operator class templated on T and Context (class header and
// several lines are not visible here). Like the index-driven update above it
// touches only the rows selected by INDICES, but it keeps a single moment
// value per row: MOMENT_1 is indexed by idx alone, not by idx * block_size.
149 template <
typename T,
class Context>
152 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor initializer: the "epsilon" argument defaults to 1e-5f.
155 epsilon_(OperatorBase::GetSingleArgument<float>(
"epsilon", 1e-5f)) {}
157 bool RunOnDevice()
override {
// MOMENT_1 must hold as many values as PARAM has entries along its first
// dimension; LR must hold exactly one value.
159 CAFFE_ENFORCE_EQ(Input(PARAM).dims()[0], Input(MOMENT_1).size());
160 CAFFE_ENFORCE_EQ(Input(LR).size(), 1);
// NOTE(review): the next two lines are the arguments of a check whose head is
// not shown -- presumably an equality check between the per-row size of PARAM
// and the per-index slice size of GRAD; confirm.
162 Input(PARAM).size_from_dim(1),
163 Input(GRAD).size_from_dim(Input(INDICES).ndim()));
// NOTE(review): tail of a call whose head is not shown -- presumably a
// dispatch on the element type of INDICES into DoRunWithType<SIndex>; confirm.
166 this, Input(INDICES));
169 template <
typename SIndex>
170 bool DoRunWithType() {
// Raw pointers to the learning rate, indices, gradient, and the current and
// updated parameter / moment buffers.
171 const auto* lr = Input(LR).template data<T>();
172 const auto* indices = Input(INDICES).template data<SIndex>();
173 const auto* gradIn = Input(GRAD).template data<T>();
174 const auto* paramIn = Input(PARAM).template data<T>();
175 const auto* momentIn = Input(MOMENT_1).template data<T>();
176 auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<T>();
177 auto* momentOut = Output(OUTPUT_MOMENT_1)->template mutable_data<T>();
// n is the number of index entries.
179 auto n = Input(INDICES).size();
// Number of gradient values per index entry.
// NOTE(review): this divides by n; any n == 0 guard would be on lines not
// visible in this excerpt -- confirm one exists.
184 auto block_size = Input(GRAD).size() / n;
186 for (
auto i = 0; i < n; ++i) {
187 auto idx = indices[i];
// Scalar path: one gradient value per index, so the row has a single element
// and the moment is updated with gi * gi directly.
188 if (block_size == 1) {
189 float gi = gradIn[i];
190 float hi = momentOut[idx] = momentIn[idx] + gi * gi;
191 paramOut[idx] = paramIn[idx] + lr[0] * gi / (std::sqrt(hi) + epsilon_);
// Start of the i-th slice in GRAD and of the idx-th row in PARAM.
// NOTE(review): presumably the block_size != 1 path; the branch line itself is
// not visible here.
193 auto offsetI = i * block_size;
194 auto offsetIdx = idx * block_size;
// NOTE(review): the lines below are argument fragments of checks whose heads
// are not shown -- they look like bounds checks on the end of the PARAM row
// and of the GRAD slice, with the input names used in the message; confirm.
199 block_size + offsetIdx,
200 this->debug_def().input(PARAM),
201 ", out of bound, idx:",
209 block_size + offsetI,
210 this->debug_def().input(GRAD),
211 ", out of bound idx, idx:",
// Row views: w / nw and g span block_size values, while h / nh point at the
// single moment value for row idx.
// NOTE(review): these are declared as float pointers although the buffers are
// fetched with data<T>() -- this appears to assume T is float; confirm.
217 const float* w = paramIn + offsetIdx;
218 const float* g = gradIn + offsetI;
219 const float* h = momentIn + idx;
220 float* nw = paramOut + offsetIdx;
221 float* nh = momentOut + idx;
// NOTE(review): the body of this loop is not visible -- presumably it
// accumulates hs from the row's gradient values; confirm.
223 for (
auto j = 0; j < block_size; ++j) {
// Row moment: previous value plus hs averaged over the row length. The result
// is stored to nh[0] and kept in hi.
227 float hi = nh[0] = h[0] + hs / block_size;
// One shared step scale for the whole row.
228 float step = lr[0] / (std::sqrt(hi) + epsilon_);
// Apply the shared scale to every element of the row; the scaled gradient is
// added to the current weight.
229 for (
auto j = 0; j < block_size; ++j) {
230 nw[j] = w[j] + g[j] * step;
// Names for the input and output slots, in positional order.
239 INPUT_TAGS(PARAM, MOMENT_1, INDICES, GRAD, LR);
240 OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1);
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...