Caffe2 - C++ API
A deep-learning, cross-platform ML framework
learning_rate_op.h
1 #ifndef CAFFE2_SGD_LEARNING_RATE_OP_H_
2 #define CAFFE2_SGD_LEARNING_RATE_OP_H_
3 
4 #include <cfloat>
5 #include <cmath>
6 #include "caffe2/core/context.h"
7 #include "caffe2/core/operator.h"
8 #include "caffe2/sgd/learning_rate_functors.h"
9 
10 namespace caffe2 {
11 
12 template <typename T, class Context>
13 class LearningRateOp final : public Operator<Context> {
14  public:
15  LearningRateOp(const OperatorDef& operator_def, Workspace* ws)
16  : Operator<Context>(operator_def, ws),
17  functor_(nullptr),
18  base_lr_(OperatorBase::template GetSingleArgument<float>(
19  "base_lr",
20  FLT_MAX)) {
21  CAFFE_ENFORCE_NE(base_lr_, FLT_MAX, "Base learning rate must be set.");
22  const string policy = OperatorBase::GetSingleArgument<string>("policy", "");
23  CAFFE_ENFORCE(policy.size(), "Must specify a learning rate policy.");
24  if (policy == "fixed") {
25  functor_.reset(new FixedLearningRate<T>());
26  } else if (policy == "alter") {
27  bool active_first =
28  OperatorBase::template GetSingleArgument<bool>("active_first", true);
29  int64_t active_period = OperatorBase::template GetSingleArgument<int64_t>(
30  "active_period", -1);
31  int64_t inactive_period =
32  OperatorBase::template GetSingleArgument<int64_t>(
33  "inactive_period", -1);
34  DCHECK_GE(active_period, 0);
35  DCHECK_GE(inactive_period, 0);
36  functor_.reset(new AlternateLearningRate<T>(
37  active_period, inactive_period, active_first));
38  } else if (policy == "hill") {
39  int64_t num_iter =
40  OperatorBase::template GetSingleArgument<int>("num_iter", 0);
41  DCHECK_GT(num_iter, 0);
42  T start_multiplier = OperatorBase::template GetSingleArgument<float>(
43  "start_multiplier", 0.);
44  DCHECK_GE(start_multiplier, 0); // start_multiplier in range [0, 1]
45  DCHECK_LE(start_multiplier, 1);
46  T gamma = OperatorBase::template GetSingleArgument<float>("gamma", 0);
47  DCHECK_GT(gamma, 0);
48  T power = OperatorBase::template GetSingleArgument<float>("power", 0);
49  DCHECK_GT(power, 0);
50  T end_multiplier =
51  OperatorBase::template GetSingleArgument<float>("end_multiplier", 0);
52  DCHECK_GE(end_multiplier, 0); // end_multiplier in range [0, 1]
53  DCHECK_LE(end_multiplier, 1);
54  functor_.reset(new HillLearningRate<T>(
55  num_iter, start_multiplier, gamma, power, end_multiplier));
56  } else if (policy == "step") {
57  int stepsize =
58  OperatorBase::template GetSingleArgument<int>("stepsize", 0);
59  T gamma = OperatorBase::template GetSingleArgument<float>("gamma", 0);
60  DCHECK_GT(stepsize, 0);
61  DCHECK_GT(gamma, 0);
62  functor_.reset(new StepLearningRate<T>(stepsize, gamma));
63  } else if (policy == "exp") {
64  T gamma = OperatorBase::template GetSingleArgument<float>("gamma", 0);
65  DCHECK_GT(gamma, 0);
66  functor_.reset(new ExpLearningRate<T>(gamma));
67  } else if (policy == "inv") {
68  T gamma = OperatorBase::template GetSingleArgument<float>("gamma", 0);
69  T power = OperatorBase::template GetSingleArgument<float>("power", 0);
70  DCHECK_GT(gamma, 0);
71  DCHECK_GT(power, 0);
72  functor_.reset(new InvLearningRate<T>(gamma, power));
73  } else if (policy == "poly") {
74  int max_iter = OperatorBase::template GetSingleArgument<int>("max_iter", -1);
75  T power = OperatorBase::template GetSingleArgument<float>("power", 0);
76  DCHECK_GT(power, 0);
77  functor_.reset(new PolyLearningRate<T>(power, max_iter));
78  } else if (policy == "linearWarmup") {
79  T start_multiplier = OperatorBase::template GetSingleArgument<float>(
80  "start_multiplier", 0.);
81  int num_iter =
82  OperatorBase::template GetSingleArgument<int>("num_iter", 0);
83  DCHECK_GT(start_multiplier, 0);
84  functor_.reset(
85  new LinearWarmupLearningRate<T>(start_multiplier, num_iter));
86  } else if (policy == "constantWarmup") {
87  T multiplier =
88  OperatorBase::template GetSingleArgument<float>("multiplier", 0.5);
89  int num_iter =
90  OperatorBase::template GetSingleArgument<int>("num_iter", 0);
91  DCHECK_GT(multiplier, 0);
92  functor_.reset(new ConstantWarmupLearningRate<T>(multiplier, num_iter));
93  } else {
94  LOG(FATAL) << "Unknown learning rate policy: " << policy;
95  }
96  }
97  USE_OPERATOR_CONTEXT_FUNCTIONS;
98 
99  bool RunOnDevice() override {
100  int64_t iter =
101  OperatorBase::Input<TensorCPU>(0).template data<int64_t>()[0];
102  T learning_rate = base_lr_ * (*functor_)(iter);
103  // Write to output.
104  auto* output = Output(0);
105  output->Resize(vector<TIndex>());
106  context_.template Copy<T, CPUContext, Context>(
107  1, &learning_rate, Output(0)->template mutable_data<T>());
108  return true;
109  }
110 
111  private:
112  unique_ptr<LearningRateFunctor<T> > functor_;
113  T base_lr_;
114 
115 };
116 
117 } // namespace caffe2
118 
119 #endif // CAFFE2_SGD_LEARNING_RATE_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...