3 #include "caffe2/core/operator.h" 4 #include "caffe2/core/timer.h" 8 template <
class Context>
9 void fp16_momentum_sgd_update(
23 template <
typename T,
class Context>
26 USE_OPERATOR_CONTEXT_FUNCTIONS;
29 momentum_(OperatorBase::GetSingleArgument<float>(
"momentum", 0.0)),
31 OperatorBase::GetSingleArgument<float>(
"weight_decay", 0.0)),
32 nesterov_(OperatorBase::GetSingleArgument<int>(
"nesterov", 0)),
35 fp32_update_(OperatorBase::GetSingleArgument<int>(
"fp32_update", 0)) {}
37 bool RunOnDevice()
override {
41 CAFFE_ENFORCE(Input(LR).size() == 1);
42 CAFFE_ENFORCE(Input(GRAD).size() == Input(MOMENTUM).size());
43 Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
44 Output(OUTPUT_MOMENTUM)->ResizeLike(Input(MOMENTUM));
46 fp16_momentum_sgd_update<Context>(
48 Input(GRAD).template data<T>(),
49 Input(MOMENTUM).template data<T>(),
50 Output(OUTPUT_GRAD)->template mutable_data<T>(),
51 Output(OUTPUT_MOMENTUM)->template mutable_data<T>(),
52 Input(LR).template data<float>(),
57 Output(OUTPUT_PARAM)->template mutable_data<T>(),
65 float weight_decay_{0.0};
68 INPUT_TAGS(GRAD, MOMENTUM, LR, PARAM);
69 OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MOMENTUM, OUTPUT_PARAM);
Tensor is the basic class in Caffe2 that stores a contiguous memory with its shape information...
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...