Caffe2 - C++ API
A deep-learning, cross-platform ML framework
rmsprop_op.cc
1 #include "rmsprop_op.h"
2 
3 #include "caffe2/utils/math.h"
4 
5 namespace caffe2 {
6 
7 template <>
8 void rmsprop_update<CPUContext>(
9  int N,
10  const float* g,
11  const float* ms,
12  const float* mom,
13  float* ng,
14  float* nms,
15  float* nmom,
16  float decay,
17  float momentum,
18  float epsilon,
19  const float* lr,
20  CPUContext* /*context*/) {
21  ConstEigenVectorArrayMap<float> gVec(g, N);
22  ConstEigenVectorArrayMap<float> msVec(ms, N);
23  ConstEigenVectorArrayMap<float> momVec(mom, N);
24  // Update new mean square estimate
25  EigenVectorArrayMap<float> nmsVec(nms, N);
26  nmsVec = msVec + (1.0f - decay) * (gVec * gVec - msVec);
27  // Update momentum estimate
28  EigenVectorArrayMap<float> nmomVec(nmom, N);
29  nmomVec = momVec * momentum + lr[0] * gVec / (epsilon + nmsVec).sqrt();
30  // New gradient is the momentum
31  EigenVectorArrayMap<float>(ng, N) = nmomVec;
32 }
33 
34 REGISTER_CPU_OPERATOR(RmsProp, RmsPropOp<float, CPUContext>);
35 OPERATOR_SCHEMA(RmsProp)
36  .NumInputs(4)
37  .NumOutputs(3)
38  .AllowInplace({{0, 0}, {1, 1}, {2, 2}})
39  .SetDoc(R"DOC(
40 Computes the RMSProp update
41 (http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf).
42 Concretely, given inputs (grad, mean_squares, mom, lr), computes:
43 
44  mean_squares_o = mean_squares + (1 - decay) * (square(grad) - mean_squares)
45  mom_o = momentum * mom + lr * grad / sqrt(epsilon + mean_squares_o)
46  grad_o = mom_o
47 
48 Returns (grad_o, mean_squares_o, mom_o).
49 )DOC");
50 SHOULD_NOT_DO_GRADIENT(RmsProp);
51 
52 }
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...