Caffe2 - C++ API
A deep learning, cross-platform ML framework
instance_norm_op.h
1 #ifndef CAFFE2_OPERATORS_INSTANCE_NORM_OP_H_
2 #define CAFFE2_OPERATORS_INSTANCE_NORM_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 #include "caffe2/utils/math.h"
7 
8 namespace caffe2 {
9 
10 template <typename T, class Context>
11 class InstanceNormOp : public Operator<Context> {
12  public:
13  USE_OPERATOR_CONTEXT_FUNCTIONS;
14  InstanceNormOp(const OperatorDef& operator_def, Workspace* ws)
15  : Operator<Context>(operator_def, ws),
16  epsilon_(OperatorBase::GetSingleArgument<T>("epsilon", 1e-5f)),
17  order_(StringToStorageOrder(
18  OperatorBase::GetSingleArgument<string>("order", "NCHW"))) {
19  CAFFE_ENFORCE(epsilon_ >= 0, "Must pass a nonnegative epsilon.");
20  }
21  ~InstanceNormOp() {}
22 
23  bool RunOnDevice() {
24  switch (order_) {
25  case StorageOrder::NHWC:
26  return RunOnDeviceWithOrderNHWC();
27  case StorageOrder::NCHW:
28  return RunOnDeviceWithOrderNCHW();
29  default:
30  CAFFE_THROW("Unknown storage order: ", order_);
31  }
32  }
33 
34  bool RunOnDeviceWithOrderNHWC();
35  bool RunOnDeviceWithOrderNCHW();
36 
37  protected:
38  // parameters
39  T epsilon_;
40  StorageOrder order_;
41 
42  // temp results that get passed to the gradient, but are otherwise stored here
43  Tensor<Context> mean_;
44  Tensor<Context> inv_stdev_;
45 
46  INPUT_TAGS(INPUT, SCALE, BIAS);
47  OUTPUT_TAGS(OUTPUT, MEAN, INV_STDEV);
48 };
49 
50 template <typename T, class Context>
51 class InstanceNormGradientOp : public Operator<Context> {
52  public:
53  USE_OPERATOR_CONTEXT_FUNCTIONS;
54  InstanceNormGradientOp(const OperatorDef& operator_def, Workspace* ws)
55  : Operator<Context>(operator_def, ws),
56  epsilon_(OperatorBase::GetSingleArgument<T>("epsilon", 1e-5f)),
57  order_(StringToStorageOrder(
58  OperatorBase::GetSingleArgument<string>("order", "NCHW"))) {
59  CAFFE_ENFORCE(epsilon_ >= 0, "Must pass a nonnegative epsilon.");
60  }
62 
63  bool RunOnDevice() {
64  switch (order_) {
65  case StorageOrder::NHWC:
66  return RunOnDeviceWithOrderNHWC();
67  case StorageOrder::NCHW:
68  return RunOnDeviceWithOrderNCHW();
69  default:
70  CAFFE_THROW("Unknown storage order: ", order_);
71  }
72  }
73 
74  bool RunOnDeviceWithOrderNHWC();
75  bool RunOnDeviceWithOrderNCHW();
76 
77  protected:
78  // parameters
79  T epsilon_;
80  StorageOrder order_;
81 
82  // temp results that could get passed through to this gradient, but if not,
83  // are stored here
84  Tensor<Context> mean_;
85  Tensor<Context> inv_stdev_;
86 
87  INPUT_TAGS(INPUT, SCALE, BIAS, OUTPUT_GRAD, MEAN, INV_STDEV);
88  OUTPUT_TAGS(INPUT_GRAD, SCALE_GRAD, BIAS_GRAD);
89 };
90 
91 } // namespace caffe2
92 
93 #endif // CAFFE2_OPERATORS_INSTANCE_NORM_OP_H_
Tensor is the basic class in Caffe2 that stores contiguous memory along with its shape information...
Definition: tensor.h:93
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...