// LearningRateOp<T, Context>: produces a learning rate equal to a base rate
// scaled by a policy-specific LearningRateFunctor<T> (see functor_ / RunOnDevice).
// NOTE(review): this text is an extraction with many original lines missing
// (class declaration, constructor signature, most functor_.reset(...) calls,
// several declaration prefixes, the final `else {`). It does not compile as-is.
// The comments below describe only what the visible lines show.
1 #ifndef CAFFE2_SGD_LEARNING_RATE_OP_H_ 2 #define CAFFE2_SGD_LEARNING_RATE_OP_H_ 6 #include "caffe2/core/context.h" 7 #include "caffe2/core/operator.h" 8 #include "caffe2/sgd/learning_rate_functors.h" 12 template <
typename T,
class Context>
// Constructor member initializer: base_lr_ is read as a float argument. The
// argument name and default are on a line not visible here.
18 base_lr_(OperatorBase::template GetSingleArgument<float>(
// FLT_MAX is rejected, so it presumably serves as the "not provided" default.
21 CAFFE_ENFORCE_NE(base_lr_, FLT_MAX,
"Base learning rate must be set.");
// The "policy" string argument (default "") selects the functor to build.
// An empty policy is rejected.
22 const string policy = OperatorBase::GetSingleArgument<string>(
"policy",
"");
23 CAFFE_ENFORCE(policy.size(),
"Must specify a learning rate policy.");
// "fixed": no arguments are read on the visible lines.
24 if (policy ==
"fixed") {
26 }
// "alter": reads active_first (bool, default true), active_period (int64,
// name/default on a hidden line) and inactive_period (int64, default -1).
// Both periods must be >= 0, so the -1 default only passes in release builds.
else if (policy ==
"alter") {
28 OperatorBase::template GetSingleArgument<bool>(
"active_first",
true);
29 int64_t active_period = OperatorBase::template GetSingleArgument<int64_t>(
31 int64_t inactive_period =
32 OperatorBase::template GetSingleArgument<int64_t>(
33 "inactive_period", -1);
// NOTE(review): DCHECK_* checks are debug-only (compiled out under NDEBUG),
// so these user-argument validations do not run in release builds.
// CAFFE_ENFORCE (used above) checks in all builds. This applies to every
// DCHECK in this constructor.
34 DCHECK_GE(active_period, 0);
35 DCHECK_GE(inactive_period, 0);
// Tail of the functor construction for this policy (start not visible).
37 active_period, inactive_period, active_first));
38 }
// "hill": num_iter (int, default 0, must be > 0), start_multiplier in [0, 1]
// (default 0.), gamma and power (default 0), end_multiplier in [0, 1]
// (default 0). Float arguments are converted to T.
else if (policy ==
"hill") {
40 OperatorBase::template GetSingleArgument<int>(
"num_iter", 0);
41 DCHECK_GT(num_iter, 0);
42 T start_multiplier = OperatorBase::template GetSingleArgument<float>(
43 "start_multiplier", 0.);
44 DCHECK_GE(start_multiplier, 0);
45 DCHECK_LE(start_multiplier, 1);
46 T gamma = OperatorBase::template GetSingleArgument<float>(
"gamma", 0);
48 T power = OperatorBase::template GetSingleArgument<float>(
"power", 0);
51 OperatorBase::template GetSingleArgument<float>(
"end_multiplier", 0);
52 DCHECK_GE(end_multiplier, 0);
53 DCHECK_LE(end_multiplier, 1);
// Tail of the functor construction for this policy (start not visible).
55 num_iter, start_multiplier, gamma, power, end_multiplier));
56 }
// "step": stepsize (int, default 0, must be > 0) and gamma (default 0).
// The default stepsize therefore fails the debug check, so the argument is
// effectively required.
else if (policy ==
"step") {
58 OperatorBase::template GetSingleArgument<int>(
"stepsize", 0);
59 T gamma = OperatorBase::template GetSingleArgument<float>(
"gamma", 0);
60 DCHECK_GT(stepsize, 0);
63 }
// "exp": gamma (default 0).
else if (policy ==
"exp") {
64 T gamma = OperatorBase::template GetSingleArgument<float>(
"gamma", 0);
67 }
// "inv": gamma and power (both default 0). Any validation is not visible here.
else if (policy ==
"inv") {
68 T gamma = OperatorBase::template GetSingleArgument<float>(
"gamma", 0);
69 T power = OperatorBase::template GetSingleArgument<float>(
"power", 0);
73 }
// "poly": max_iter (int, default -1) and power (default 0). Any validation
// is not visible here.
else if (policy ==
"poly") {
74 int max_iter = OperatorBase::template GetSingleArgument<int>(
"max_iter", -1);
75 T power = OperatorBase::template GetSingleArgument<float>(
"power", 0);
78 }
// "linearWarmup": start_multiplier (default 0.) and num_iter (int, default 0).
// NOTE(review): the default start_multiplier of 0 fails
// DCHECK_GT(start_multiplier, 0) in debug builds but is accepted in release
// builds. The "hill" branch uses DCHECK_GE for its start_multiplier.
// Confirm which bound is intended.
else if (policy ==
"linearWarmup") {
79 T start_multiplier = OperatorBase::template GetSingleArgument<float>(
80 "start_multiplier", 0.);
82 OperatorBase::template GetSingleArgument<int>(
"num_iter", 0);
83 DCHECK_GT(start_multiplier, 0);
86 }
// "constantWarmup": multiplier (float, default 0.5, must be > 0) and
// num_iter (int, default 0).
else if (policy ==
"constantWarmup") {
88 OperatorBase::template GetSingleArgument<float>(
"multiplier", 0.5);
90 OperatorBase::template GetSingleArgument<int>(
"num_iter", 0);
91 DCHECK_GT(multiplier, 0);
// Any other policy string aborts. This is presumably the body of the final
// `else` (that line is not visible here).
94 LOG(FATAL) <<
"Unknown learning rate policy: " << policy;
// Presumably brings context_ and related operator helpers into scope. The
// macro is defined outside this file. TODO confirm.
97 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Reads the current iteration from input 0 and scales base_lr_ by the policy
// functor's value at that iteration. The result is written to output 0.
// NOTE(review): the `int64_t iter =` line and the function's `return true; }`
// are not visible in this extraction.
99 bool RunOnDevice()
override {
// Iteration counter: element 0 of input 0, read as int64_t from a CPU tensor.
101 OperatorBase::Input<TensorCPU>(0).
template data<int64_t>()[0];
// Requires functor_ to have been set by the constructor's policy dispatch.
102 T learning_rate = base_lr_ * (*functor_)(iter);
104 auto* output = Output(0);
// An empty dims vector gives the output a zero-dimensional shape.
105 output->Resize(vector<TIndex>());
// Copy the single value computed on the host (CPUContext) into the output's
// storage under this operator's Context.
// NOTE(review): this re-fetches Output(0) instead of reusing `output`.
106 context_.template Copy<T, CPUContext, Context>(
107 1, &learning_rate, Output(0)->template mutable_data<T>());
// Policy functor selected in the constructor; owned by this operator.
112 unique_ptr<LearningRateFunctor<T> > functor_;
119 #endif // CAFFE2_SGD_LEARNING_RATE_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...