// NOTE(review): fragmentary extraction of a source listing. The leading
// numbers on each line are the original file's line numbers. Gaps in that
// numbering (e.g. 1-2, 5-7, 10-21) are lines not shown here, so this is not
// compilable as-is.
3 #include "caffe2/core/operator.h" 4 #include "caffe2/core/timer.h" 8 template <
class Context>
// Start of fp32_momentum_sgd_update, a function template on the device
// Context. Its parameter list and body (original lines 10-21) are not visible.
// The operator's RunOnDevice further down calls it as
// fp32_momentum_sgd_update<Context>(...).
9 void fp32_momentum_sgd_update(
// NOTE(review): fragment of an operator class template, parameterized on an
// element type T and a device Context. The class header, the constructor
// signature, some call arguments and the closing lines are missing from this
// listing (original lines 23-24, 26-27, 29, 32, 34-36, 41, 43, 49-51, 53-59,
// 61). Only the visible statements are documented below.
22 template <
typename T,
class Context>
25 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor member-initializers. Operator arguments are read with these
// defaults: "momentum" (float, 0.0), "weight_decay" (float, 0.0) and
// "nesterov" (int, 0). The member being initialized from "weight_decay" is
// named on a missing line; presumably weight_decay_ (declared below), but
// this should be confirmed.
28 momentum_(OperatorBase::GetSingleArgument<float>(
"momentum", 0.0)),
30 OperatorBase::GetSingleArgument<float>(
"weight_decay", 0.0)),
31 nesterov_(OperatorBase::GetSingleArgument<int>(
"nesterov", 0)) {}
// Validates the inputs, sizes the grad/momentum outputs, and delegates to
// fp32_momentum_sgd_update<Context>. Original lines 34-36 of this body are
// not visible here.
33 bool RunOnDevice()
override {
// The learning-rate input must hold exactly one element.
37 CAFFE_ENFORCE(Input(LR).size() == 1);
// Gradient and momentum inputs must have the same number of elements.
38 CAFFE_ENFORCE(Input(GRAD).size() == Input(MOMENTUM).size());
// Outputs take the shapes of their corresponding inputs.
39 Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
40 Output(OUTPUT_MOMENTUM)->ResizeLike(Input(MOMENTUM));
// Some arguments of this call (original lines 43, 49-51, 53 onward) are not
// visible. The LR buffer is read as float regardless of T.
// NOTE(review): OUTPUT_PARAM is written through mutable_data<T>(), but no
// ResizeLike for it appears in the visible lines. Presumably it aliases PARAM
// in place or is sized on a missing line; this should be confirmed.
42 fp32_momentum_sgd_update<Context>(
44 Input(GRAD).template data<T>(),
45 Input(MOMENTUM).template data<T>(),
46 Output(OUTPUT_GRAD)->template mutable_data<T>(),
47 Output(OUTPUT_MOMENTUM)->template mutable_data<T>(),
48 Input(LR).template data<float>(),
52 Output(OUTPUT_PARAM)->template mutable_data<T>(),
// Weight-decay coefficient (in-class default 0.0).
60 float weight_decay_{0.0};
// Named input/output slots referenced by the Input()/Output() calls above.
// The order presumably matches the operator's positional inputs/outputs.
62 INPUT_TAGS(GRAD, MOMENTUM, LR, PARAM);
63 OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MOMENTUM, OUTPUT_PARAM);
Tensor is the basic class in Caffe2 that stores a contiguous block of memory together with its shape information...
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about which Caffe2 modules have been loaded in the current ...