Caffe2 - C++ API
A deep learning, cross-platform ML framework
lars_op.h
1 #ifndef CAFFE2_OPERATORS_LARS_OP_H_
2 #define CAFFE2_OPERATORS_LARS_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/logging.h"
6 #include "caffe2/core/operator.h"
7 
8 namespace caffe2 {
9 
10 template <typename T, class Context>
11 class LarsOp final : public Operator<Context> {
12  public:
13  USE_OPERATOR_CONTEXT_FUNCTIONS;
14  LarsOp(const OperatorDef& operator_def, Workspace* ws)
15  : Operator<Context>(operator_def, ws),
16  offset_(OperatorBase::GetSingleArgument<float>("offset", 0.5)) {}
17 
18  bool RunOnDevice() override {
19  auto& X = Input(0);
20  auto& dX = Input(1);
21  CAFFE_ENFORCE(
22  dX.size() == X.size(), "Gradient size doesn't match parameter size.");
23  CAFFE_ENFORCE_GE(offset_, 0);
24 
25  auto* lr_rescale = Output(0);
26  lr_rescale->Resize(vector<TIndex>{1});
27 
28  Compute(
29  dX.size(),
30  X.template data<T>(),
31  dX.template data<T>(),
32  offset_,
33  lr_rescale->template mutable_data<T>());
34 
35  return true;
36  }
37 
38  private:
39  void Compute(
40  TIndex N,
41  const T* X_data,
42  const T* dX_data,
43  T offset,
44  T* lr_rescale_data);
45 
46  T offset_;
47 };
48 
49 } // namespace caffe2
50 
51 #endif // CAFFE2_OPERATORS_LARS_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...