Caffe2 - C++ API
A deep learning, cross-platform ML framework
lpnorm_op.h
1 #ifndef CAFFE2_OPERATORS_LPNORM_OP_H_
2 #define CAFFE2_OPERATORS_LPNORM_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 #include "caffe2/utils/math.h"
7 
8 namespace caffe2 {
9 
// LpNormOp: operator class template parameterized on an element type T and a
// device Context, deriving from Operator<Context>. Per the tags below it takes
// one input (X) and produces one output (Norm). Only the declaration lives in
// this header; RunOnDevice() is defined elsewhere.
//
// Operator arguments read at construction:
//   "p"       (int)  - order of the norm; must be 1 or 2. 2 is passed as the
//                      fallback value when the argument is not supplied.
//   "average" (bool) - false is passed as the fallback value.
//                      NOTE(review): presumably selects an averaged norm
//                      (norm divided by element count); the effect is not
//                      visible in this header - confirm against RunOnDevice().
10 template <typename T, class Context>
11 class LpNormOp : public Operator<Context> {
12  public:
13  USE_OPERATOR_CONTEXT_FUNCTIONS;
   // Forwards def/ws to the Operator<Context> base, caches the "p" and
   // "average" arguments, and rejects any p other than 1 or 2 up front so
   // an unsupported norm order fails at construction rather than at run time.
14  LpNormOp(const OperatorDef& def, Workspace* ws)
15  : Operator<Context>(def, ws),
16  p_(OperatorBase::GetSingleArgument<int>("p", 2)),
17  average_(OperatorBase::GetSingleArgument<bool>("average", false)) {
18  CAFFE_ENFORCE(p_ == 1 || p_ == 2, "p should be either 1 or 2.");
19  }
20 
   // Declared only; the body is not part of this header.
21  bool RunOnDevice() override;
22 
23  protected:
   // Norm order taken from the "p" argument; the constructor guarantees 1 or 2.
24  int p_;
   // Value of the "average" argument (see the NOTE(review) in the class comment).
25  bool average_;
   // Symbolic names for the single input and single output.
   // NOTE(review): macro expansions are not visible here; presumably they map
   // these names to input/output indices - confirm in operator.h.
26  INPUT_TAGS(X_IN);
27  OUTPUT_TAGS(OUT);
28  // Input: X; Output: Norm
29 };
30 
// LpNormGradientOp: operator class template parameterized on an element type T
// and a device Context, deriving from Operator<Context>. Per the tags below it
// takes two inputs (X and dNorm) and produces one output (dX). Only the
// declaration lives in this header; RunOnDevice() is defined elsewhere.
//
// Operator arguments read at construction (same names, types and fallback
// values as the constructor of LpNormOp):
//   "p"       (int)  - must be 1 or 2; 2 is passed as the fallback value.
//   "average" (bool) - false is passed as the fallback value.
//                      NOTE(review): presumably must match the setting used by
//                      the forward operator; the effect is not visible in this
//                      header - confirm against RunOnDevice().
31 template <typename T, class Context>
32 class LpNormGradientOp : public Operator<Context> {
33  public:
34  USE_OPERATOR_CONTEXT_FUNCTIONS;
   // Forwards def/ws to the Operator<Context> base, caches the "p" and
   // "average" arguments, and rejects any p other than 1 or 2 up front.
35  LpNormGradientOp(const OperatorDef& def, Workspace* ws)
36  : Operator<Context>(def, ws),
37  p_(OperatorBase::GetSingleArgument<int>("p", 2)),
38  average_(OperatorBase::GetSingleArgument<bool>("average", false)) {
39  CAFFE_ENFORCE(p_ == 1 || p_ == 2, "p should be either 1 or 2.");
40  }
41 
   // Declared only; the body is not part of this header.
42  bool RunOnDevice() override;
43 
44  protected:
   // Norm order taken from the "p" argument; the constructor guarantees 1 or 2.
45  int p_;
   // Value of the "average" argument (see the NOTE(review) in the class comment).
46  bool average_;
   // Symbolic names for the two inputs and the single output.
   // NOTE(review): macro expansions are not visible here; presumably they map
   // these names to input/output indices in the listed order - confirm in
   // operator.h.
47  INPUT_TAGS(X_IN, DER_NORM_IN);
48  OUTPUT_TAGS(DER_X_OUT);
49  // Input: X, dNorm; Output: dX
50 };
51 
52 } // namespace caffe2
53 
54 #endif // CAFFE2_OPERATORS_LPNORM_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...