1 #include "rmsprop_op.h" 3 #include "caffe2/utils/math.h" 8 void rmsprop_update<CPUContext>(
// Visible body of the CPU rmsprop_update specialization (signature not shown here).
// Read-only Eigen array views of length N over the input buffers g, ms and mom.
21 ConstEigenVectorArrayMap<float> gVec(g, N);
22 ConstEigenVectorArrayMap<float> msVec(ms, N);
23 ConstEigenVectorArrayMap<float> momVec(mom, N);
// Writable view over nms; element-wise it receives
//   ms + (1 - decay) * (g * g - ms).
25 EigenVectorArrayMap<float> nmsVec(nms, N);
26 nmsVec = msVec + (1.0f - decay) * (gVec * gVec - msVec);
// Writable view over nmom; element-wise it receives
//   momentum * mom + lr[0] * g / sqrt(epsilon + nms)
// using the nms values written just above. epsilon is added inside the
// square root, and only the first element of lr is read.
28 EigenVectorArrayMap<float> nmomVec(nmom, N);
29 nmomVec = momVec * momentum + lr[0] * gVec / (epsilon + nmsVec).sqrt();
// The ng buffer is filled with a copy of the newly computed nmom values.
31 EigenVectorArrayMap<float>(ng, N) = nmomVec;
// Registers RmsPropOp<float, CPUContext> under the operator name RmsProp
// via the CPU operator registration macro.
34 REGISTER_CPU_OPERATOR(RmsProp, RmsPropOp<float, CPUContext>);
// Opens the schema declaration for RmsProp; the calls below are chained onto it.
35 OPERATOR_SCHEMA(RmsProp)
// NOTE(review): the pairs {0,0}, {1,1}, {2,2} are presumably (input index,
// output index) pairs permitted to share storage — confirm against the
// OperatorSchema::AllowInplace documentation.
38 .AllowInplace({{0, 0}, {1, 1}, {2, 2}})
40 Computes the RMSProp update 41 (http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf). 42 Concretely, given inputs (grad, mean_squares, mom, lr), computes: 44 mean_squares_o = mean_squares + (1 - decay) * (square(grad) - mean_squares) 45 mom_o = momentum * mom + lr * grad / sqrt(epsilon + mean_squares_o) 48 Returns (grad_o, mean_squares_o, mom_o). 50 SHOULD_NOT_DO_GRADIENT(RmsProp); A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...