3 #include "caffe2/core/operator.h" 7 template <
typename Context>
// Element-wise momentum-SGD update kernel, templated on Context.
// NOTE(review): this view is an excerpt -- the parameter list, the branch
// that selects between the two update forms below, and the closing braces
// are elided, so the comments only describe the visible statements.
8 void momentum_sgd_update(
// The learning rate is taken from the first element of `lr`.
19 const float LR = lr[0];
// One iteration per element, for i in [0, N).
20 for (
auto i = 0; i < N; ++i) {
// First visible update form: the same value, LR * g + momentum * m, is
// written to both the new-momentum output (nm) and the new-gradient
// output (ng).
22 const float adjusted_gradient = LR * g[i] + momentum * m[i];
23 nm[i] = adjusted_gradient;
24 ng[i] = adjusted_gradient;
// Second visible update form: mi_new is the same momentum expression, but
// ng receives (1 + momentum) * mi_new - momentum * mi instead.
// NOTE(review): presumably the Nesterov branch, with nm[i] = mi_new on an
// elided line -- confirm against the full file.
26 const float mi = m[i];
27 const float mi_new = momentum * mi + LR * g[i];
29 ng[i] = (1 + momentum) * mi_new - momentum * mi;
// Operator fragment templated on element type T and Context; it takes GRAD,
// MOMENTUM and LR inputs and produces OUTPUT_GRAD and OUTPUT_MOMENTUM.
// NOTE(review): the class header and constructor signature are elided in
// this view, so the class name is not confirmed here.
38 template <
typename T,
class Context>
41 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor initializers: "momentum" is read as T with default 0.0 and
// "nesterov" is read as int with default 0.
44 momentum_(OperatorBase::GetSingleArgument<T>(
"momentum", 0.0)),
45 nesterov_(OperatorBase::GetSingleArgument<int>(
"nesterov", 0)) {}
// Validates the inputs, sizes the outputs, then calls momentum_sgd_update.
47 bool RunOnDevice()
override {
// LR must hold exactly one element; GRAD and MOMENTUM must have equal sizes.
// NOTE(review): these checks use CAFFE_ENFORCE(a == b), while the other
// operator fragments in this file use CAFFE_ENFORCE_EQ for the same checks.
51 CAFFE_ENFORCE(Input(LR).size() == 1);
52 CAFFE_ENFORCE(Input(GRAD).size() == Input(MOMENTUM).size());
// Each output is resized to match its corresponding input.
53 Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
54 Output(OUTPUT_MOMENTUM)->ResizeLike(Input(MOMENTUM));
// Visible arguments, in order: GRAD data, MOMENTUM data, OUTPUT_GRAD buffer,
// OUTPUT_MOMENTUM buffer, LR data. The leading and trailing arguments are
// elided in this view.
56 momentum_sgd_update<Context>(
58 Input(GRAD).template data<T>(),
59 Input(MOMENTUM).template data<T>(),
60 Output(OUTPUT_GRAD)->template mutable_data<T>(),
61 Output(OUTPUT_MOMENTUM)->template mutable_data<T>(),
62 Input(LR).template data<T>(),
// Declares the input and output tags used by Input(...) / Output(...) above.
73 INPUT_TAGS(GRAD, MOMENTUM, LR);
74 OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MOMENTUM);
// Operator fragment templated on element type T and Context; compared with
// the GRAD/MOMENTUM/LR-only operator it additionally takes a PARAM input and
// produces an OUTPUT_PARAM output.
// NOTE(review): the class header and constructor signature are elided in
// this view, so the class name is not confirmed here.
77 template <
typename T,
class Context>
80 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor initializers: "momentum" is read as T with default 0.0 and
// "nesterov" is read as int with default 0.
83 momentum_(OperatorBase::GetSingleArgument<T>(
"momentum", 0.0)),
84 nesterov_(OperatorBase::GetSingleArgument<int>(
"nesterov", 0)) {}
// Validates the inputs, sizes the outputs, then calls momentum_sgd_update.
86 bool RunOnDevice()
override {
// LR must hold exactly one element; GRAD and MOMENTUM must have equal sizes.
90 CAFFE_ENFORCE_EQ(Input(LR).size(), 1);
91 CAFFE_ENFORCE_EQ(Input(GRAD).size(), Input(MOMENTUM).size());
// Each of these outputs is resized to match its corresponding input.
// NOTE(review): no ResizeLike for OUTPUT_PARAM and no size check on PARAM
// appear in the visible lines -- confirm whether the elided lines cover
// this or the op relies on running in place.
92 Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
93 Output(OUTPUT_MOMENTUM)->ResizeLike(Input(MOMENTUM));
// Visible arguments: GRAD data, MOMENTUM data, OUTPUT_GRAD buffer,
// OUTPUT_MOMENTUM buffer, LR data and, after elided arguments, the
// OUTPUT_PARAM buffer.
95 momentum_sgd_update<Context>(
97 Input(GRAD).template data<T>(),
98 Input(MOMENTUM).template data<T>(),
99 Output(OUTPUT_GRAD)->template mutable_data<T>(),
100 Output(OUTPUT_MOMENTUM)->template mutable_data<T>(),
101 Input(LR).template data<T>(),
104 Output(OUTPUT_PARAM)->template mutable_data<T>(),
// Declares the input and output tags used by Input(...) / Output(...) above.
112 INPUT_TAGS(GRAD, MOMENTUM, LR, PARAM);
113 OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MOMENTUM, OUTPUT_PARAM);
// Operator fragment templated on element type T and Context; this variant
// reads an INDICES input in addition to GRAD, MOMENTUM, LR and PARAM.
// NOTE(review): the class header and constructor signature are elided in
// this view, so the class name is not confirmed here.
116 template <
typename T,
class Context>
119 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor initializers: "momentum" is read as T with default 0.0 and
// "nesterov" is read as int with default 0.
122 momentum_(OperatorBase::GetSingleArgument<T>(
"momentum", 0.0)),
123 nesterov_(OperatorBase::GetSingleArgument<int>(
"nesterov", 0)) {}
// Sizes OUTPUT_GRAD, validates the input shapes, then hands off together
// with the INDICES input.
125 bool RunOnDevice()
override {
// OUTPUT_GRAD is resized to match GRAD.
127 Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
// LR must hold exactly one element; PARAM and MOMENTUM must have equal sizes.
130 CAFFE_ENFORCE_EQ(Input(LR).size(), 1);
131 CAFFE_ENFORCE_EQ(Input(PARAM).size(), Input(MOMENTUM).size());
// PARAM's size from dimension 1 onward must equal GRAD's size from dimension
// INDICES.ndim() onward.
132 CAFFE_ENFORCE_EQ(Input(PARAM).size_from_dim(1),
133 Input(GRAD).size_from_dim(Input(INDICES).ndim()));
// NOTE(review): the callee name is elided; presumably a dispatch on the
// INDICES element type that ends up in DoRunWithType<SIndex> -- confirm.
136 this, Input(INDICES));
// Per-index update, templated on the INDICES element type SIndex: for each
// entry of INDICES, runs momentum_sgd_update on the MOMENTUM / PARAM block
// that entry selects.
// NOTE(review): several call arguments and the closing braces are elided in
// this view.
139 template <
typename SIndex>
140 bool DoRunWithType() {
// block_size is the number of PARAM elements per unit of PARAM's first
// dimension; n is the number of such blocks contained in GRAD.
// NOTE(review): no visible guard against PARAM.dim(0) == 0 or
// block_size == 0 before these divisions -- confirm the inputs rule it out.
141 auto block_size = Input(PARAM).size() / Input(PARAM).dim(0);
142 auto n = Input(GRAD).size() / block_size;
// Read-only input pointers.
144 const auto* gradIn = Input(GRAD).template data<T>();
145 const auto* momentumIn = Input(MOMENTUM).template data<T>();
146 const auto* lr = Input(LR).template data<T>();
147 const auto* paramIn = Input(PARAM).template data<T>();
148 const auto* indices = Input(INDICES).template data<SIndex>();
// Writable output pointers.
150 auto* gradOut = Output(OUTPUT_GRAD)->template mutable_data<T>();
151 auto* momentumOut = Output(OUTPUT_MOMENTUM)->template mutable_data<T>();
152 auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<T>();
// NOTE(review): `auto i = 0` deduces int, while n takes its type from
// size() / block_size -- confirm the two types cannot diverge for large
// tensors.
154 for (
auto i = 0; i < n; ++i) {
// offsetI addresses the i-th block of GRAD; offsetIdx addresses the block
// of PARAM / MOMENTUM selected by indices[i].
155 auto idx = indices[i];
156 auto offsetI = i * block_size;
157 auto offsetIdx = idx * block_size;
// Upper-bound checks so that neither block runs past its tensor.
// NOTE(review): no lower-bound check on idx is visible -- a negative index
// would pass the first check; confirm indices are validated elsewhere.
159 CAFFE_ENFORCE(offsetIdx + block_size <= Input(PARAM).size());
160 CAFFE_ENFORCE(offsetI + block_size <= Input(GRAD).size());
// Visible arguments address the MOMENTUM input, MOMENTUM output and PARAM
// output at offsetIdx; the remaining arguments are elided in this view.
162 momentum_sgd_update<Context>(
165 momentumIn + offsetIdx,
167 momentumOut + offsetIdx,
171 paramOut + offsetIdx,
// Declares the input and output tags used by Input(...) / Output(...) above.
180 INPUT_TAGS(GRAD, MOMENTUM, LR, PARAM, INDICES);
181 OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MOMENTUM, OUTPUT_PARAM);
Tensor is the basic class in Caffe2 that stores a contiguous block of memory together with its shape information...
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...