// Learning-rate functor header (fragment): include guard, includes, and the
// base functor's call operator.
1 #ifndef CAFFE2_SGD_LEARNING_RATE_FUNCTORS_H_ 2 #define CAFFE2_SGD_LEARNING_RATE_FUNCTORS_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 15 virtual T operator()(
const int64_t iter)
const = 0;
// Pure-virtual call operator: maps an iteration count `iter` to a value of
// type T. Each functor below provides its own `override`.
// Override whose iteration parameter is unnamed, so the value it returns
// cannot depend on `iter`.
22 T operator()(
const int64_t )
const override {
// NOTE(review): the body of this override is not present in this excerpt.
// Constructor (tail): stores the length of the active phase, the length of
// the inactive phase, and whether the active phase comes first in each cycle.
34 const int64_t active_period,
35 const int64_t inactive_period,
36 const bool active_first)
37 : active_period_(active_period),
38 inactive_period_(inactive_period),
39 active_first_(active_first) {}
// Returns 1 during the active phase and 0 during the inactive phase.
// The cycle length is active_period_ + inactive_period_. `iter` modulo that
// length gives the position within the current cycle. The first phase of
// each cycle lasts active_period_ iterations when active_first_ is true,
// otherwise inactive_period_ iterations.
// NOTE(review): if active_period_ + inactive_period_ == 0 the modulo is an
// integer division by zero (undefined behavior); confirm callers reject it.
40 T operator()(
const int64_t iter)
const override {
41 if (iter % (active_period_ + inactive_period_) <
42 (active_first_ ? active_period_ : inactive_period_)) {
// Inside the first phase of the cycle.
43 return active_first_ ? 1. : 0.;
// Inside the second phase of the cycle.
45 return active_first_ ? 0. : 1.;
// Phase lengths, in iterations.
49 int64_t active_period_;
50 int64_t inactive_period_;
// Constructor (tail): stores the step size and the per-step decay factor.
59 : stepsize_(stepsize), gamma_(gamma) {}
// Step decay: returns gamma_ raised to the power (iter / stepsize_).
// NOTE(review): stepsize_'s declaration is not shown. If it is an integral
// type, iter / stepsize_ truncates, so the value changes only once every
// stepsize_ iterations. stepsize_ == 0 would then be a division by zero.
60 T operator()(
const int64_t iter)
const override {
61 return std::pow(gamma_, static_cast<T>(iter / stepsize_));
// Exponential decay: returns gamma_ raised to the power iter.
73 T operator()(
const int64_t iter)
const override {
74 return std::pow(gamma_, static_cast<T>(iter));
// Constructor (tail): stores the decay rate gamma and the exponent power.
85 : gamma_(gamma), power_(power) {}
// Inverse decay: returns (1 + gamma_ * iter) ^ (-power_).
86 T operator()(
const int64_t iter)
const override {
87 return std::pow(T(1) + gamma_ * iter, -power_);
// Constructor (tail): stores the exponent and the iteration count at which
// the base of the power reaches zero.
98 : power_(power), max_iter_(max_iter) {}
// Polynomial decay: returns (1 - iter / max_iter_) ^ power_. The value is
// 1 at iter == 0 and 0 at iter == max_iter_ (for positive power_).
// NOTE(review): nothing here limits iter to max_iter_. For iter > max_iter_
// the base is negative, and std::pow of a negative finite base with a
// non-integer exponent returns NaN. For floating-point T, max_iter_ == 0
// makes the division produce inf/NaN.
99 T operator()(
const int64_t iter)
const override {
100 return std::pow(1 - T(iter) / T(max_iter_), power_);
// Functor templated on the value type T.
107 template <
typename T>
// Constructor (tail): stores the starting multiplier and the number of
// iterations over which it is interpolated.
111 : start_multiplier_(start_multiplier), num_iter_(num_iter) {}
// For iter < num_iter_: interpolates linearly from start_multiplier_ at
// iter == 0 toward 1 at iter == num_iter_.
// NOTE(review): the value returned for iter >= num_iter_ is on lines not
// present in this excerpt.
112 T operator()(
const int64_t iter)
const override {
113 if (iter >= num_iter_) {
116 return start_multiplier_ + (1. - start_multiplier_) * T(iter) / T(num_iter_);
// Functor templated on the value type T.
123 template <
typename T>
// Constructor (tail): stores a constant multiplier and the number of
// iterations it applies to.
127 : multiplier_(multiplier), num_iter_(num_iter) {}
// Returns multiplier_ (converted to T) on the path after the
// `iter >= num_iter_` check, i.e. presumably for iter < num_iter_.
// NOTE(review): the body of the `iter >= num_iter_` branch is not present in
// this excerpt; confirm it returns.
128 T operator()(
const int64_t iter)
const override {
129 if (iter >= num_iter_) {
132 return T(multiplier_);
// Functor templated on the value type T.
142 template <
typename T>
// Constructor (partial parameter list): builds linear_warmup_lr_ from
// (start_multiplier, num_iter) and inv_lr_ from (gamma, power), and stores
// end_multiplier, which operator() uses as a lower bound.
146 const int64_t num_iter,
147 const T start_multiplier,
150 const T end_multiplier)
151 : linear_warmup_lr_(start_multiplier, num_iter),
152 inv_lr_(gamma, power),
154 end_multiplier_(end_multiplier) {}
// Two phases:
//  - iter < num_iter_: delegates to linear_warmup_lr_(iter).
//  - otherwise: evaluates inv_lr_ on the iteration count shifted by
//    num_iter_ (so the decay starts from 0 after the first phase) and clamps
//    the result from below at end_multiplier_ via std::max.
155 T operator()(
const int64_t iter)
const override {
156 if (iter < num_iter_) {
157 return linear_warmup_lr_(iter);
159 return std::max(end_multiplier_, inv_lr_(iter - num_iter_));
170 #endif // CAFFE2_SGD_LEARNING_RATE_FUNCTORS_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...