Caffe2 - C++ API
A deep learning, cross-platform ML framework
spatial_batch_norm_op.h
1 #ifndef CAFFE2_OPERATORS_SPATIAL_BATCH_NORM_OP_H_
2 #define CAFFE2_OPERATORS_SPATIAL_BATCH_NORM_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 #include "caffe2/utils/math.h"
7 
8 namespace caffe2 {
9 
10 template <class Context>
11 class SpatialBNOp : public Operator<Context> {
12  public:
13  USE_OPERATOR_CONTEXT_FUNCTIONS;
14  SpatialBNOp(const OperatorDef& operator_def, Workspace* ws)
15  : Operator<Context>(operator_def, ws),
16  is_test_(OperatorBase::GetSingleArgument<int>(OpSchema::Arg_IsTest, 0)),
17  epsilon_(OperatorBase::GetSingleArgument<float>("epsilon", 1e-5f)),
18  momentum_(OperatorBase::GetSingleArgument<float>("momentum", 0.9f)),
19  order_(StringToStorageOrder(
20  OperatorBase::GetSingleArgument<string>("order", "NCHW"))),
21  num_batches_(OperatorBase::GetSingleArgument<int>("num_batches", 1)) {
22  // TODO(jiayq): update the input and output size checks.
23  CAFFE_ENFORCE(
24  (is_test_ && OutputSize() == 1) || (!is_test_ && OutputSize() == 5));
25  CAFFE_ENFORCE_GT(epsilon_, 0);
26  CAFFE_ENFORCE_GE(momentum_, 0);
27  CAFFE_ENFORCE_LE(momentum_, 1);
28  }
29  ~SpatialBNOp() {}
30 
31  bool RunOnDevice() override {
32  return true;
33  }
34 
35  protected:
36  bool is_test_;
37  double epsilon_;
38  double momentum_;
39  StorageOrder order_;
40  int num_batches_;
41  INPUT_TAGS(INPUT, SCALE, BIAS, EST_MEAN, EST_VAR, SUMS, SUMSQ);
42  OUTPUT_TAGS(OUTPUT, RUNNING_MEAN, RUNNING_VAR, SAVED_MEAN, SAVED_INV_VAR);
43 };
44 
45 template <class Context>
46 class SpatialBNGradientOp : public Operator<Context> {
47  public:
48  USE_OPERATOR_CONTEXT_FUNCTIONS;
49  SpatialBNGradientOp(const OperatorDef& operator_def, Workspace* ws)
50  : Operator<Context>(operator_def, ws),
51  is_test_(OperatorBase::GetSingleArgument<int>(OpSchema::Arg_IsTest, 0)),
52  epsilon_(OperatorBase::GetSingleArgument<float>("epsilon", 1e-5f)),
53  order_(StringToStorageOrder(
54  OperatorBase::GetSingleArgument<string>("order", "NCHW"))),
55  num_batches_(OperatorBase::GetSingleArgument<int>("num_batches", 1)) {
56  CAFFE_ENFORCE(InputSize() == 5 || InputSize() == 7);
57  CAFFE_ENFORCE(OutputSize() == 3);
58  }
60 
61  bool RunOnDevice() override {
62  return true;
63  }
64 
65  protected:
66  bool is_test_;
67  double epsilon_;
68  StorageOrder order_;
69  int num_batches_;
70 
71  INPUT_TAGS(
72  INPUT,
73  SCALE,
74  OUTPUT_GRAD,
75  SAVED_MEAN,
76  SAVED_INV_VAR,
77  AGGREGATE_SCALE_GRAD,
78  AGGREGATE_BIAS_GRAD);
79  OUTPUT_TAGS(INPUT_GRAD, SCALE_GRAD, BIAS_GRAD);
80 };
81 
82 } // namespace caffe2
83 
84 #endif // CAFFE2_OPERATORS_SPATIAL_BATCH_NORM_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...