Caffe2 - C++ API
A deep learning, cross platform ML framework
mean_op.h
1 #ifndef CAFFE2_OPERATORS_MEAN_OPS_H_
2 #define CAFFE2_OPERATORS_MEAN_OPS_H_
3 
4 #include "caffe2/core/common_omp.h"
5 #include "caffe2/core/context.h"
6 #include "caffe2/core/logging.h"
7 #include "caffe2/core/operator.h"
8 #include "caffe2/core/types.h"
9 #include "caffe2/utils/math.h"
10 #include "caffe2/utils/proto_utils.h"
11 
12 namespace caffe2 {
13 
14 template <class Context>
15 class MeanOp final : public Operator<Context> {
16  public:
17  USE_OPERATOR_CONTEXT_FUNCTIONS;
18  USE_SIMPLE_CTOR_DTOR(MeanOp)
19 
20  template <typename T>
21  bool DoRunWithType() {
22  auto& input0 = Input(0);
23  auto* output = Output(0);
24 
25  output->ResizeLike(input0);
26  output->CopyFrom(input0, &context_);
27 
28  if (InputSize() == 1) {
29  return true;
30  }
31 
32  // Dimension checking
33  for (int i = 1; i < InputSize(); ++i) {
34  if (output->dims() != Input(i).dims()) {
35  CAFFE_THROW(
36  "Check failed: output->dims() == Input(i).dims().",
37  "Description: Input #",
38  i,
39  ", input dimension:",
40  Input(i).dims(),
41  " should match output dimension: ",
42  output->dims());
43  }
44  }
45 
46  T* output_data = output->template mutable_data<T>();
47  for (int i = 1; i < InputSize(); ++i) {
48  math::Add(
49  output->size(),
50  output_data,
51  Input(i).template data<T>(),
52  output_data,
53  &context_);
54  }
55 
56  math::Scale(
57  output->size(),
58  1.0f / InputSize(),
59  output_data,
60  output_data,
61  &context_);
62 
63  return true;
64  }
65 
66  bool RunOnDevice() override {
67  if (Input(0).template IsType<float>()) {
68  return DoRunWithType<float>();
69  } else {
70  CAFFE_THROW(
71  "Mean operator only supports 32-bit float, but",
72  " input was of type ",
73  Input(0).meta().name());
74  }
75  }
76 };
77 
78 template <class Context>
79 class MeanGradientOp : public Operator<Context> {
80  public:
81  USE_OPERATOR_CONTEXT_FUNCTIONS;
82 
83  MeanGradientOp(const OperatorDef& operator_def, Workspace* ws)
84  : Operator<Context>(operator_def, ws) {}
85 
86  template <typename T>
87  bool DoRunWithType() {
88  auto& dY = Input(0);
89  const auto* dY_data = dY.template data<T>();
90  int size = dY.size();
91 
92  int num_inputs = OutputSize();
93  float scale = 1.0f / num_inputs;
94 
95  // dX0 = scale * dY
96  auto* dX0 = Output(0);
97  dX0->ResizeLike(dY);
98  math::Scale(
99  size, scale, dY_data, dX0->template mutable_data<T>(), &context_);
100 
101  // Copy the rest dX
102  for (int i = 1; i < num_inputs; i++) {
103  auto* cur_dX = Output(i);
104  cur_dX->ResizeLike(dY);
105  cur_dX->CopyFrom(*dX0, &context_);
106  }
107 
108  return true;
109  }
110 
111  bool RunOnDevice() override {
112  if (Input(0).template IsType<float>()) {
113  return DoRunWithType<float>();
114  } else {
115  CAFFE_THROW(
116  "Mean operator only supports 32-bit float, but",
117  " input was of type ",
118  Input(0).meta().name());
119  }
120  }
121 };
122 
123 } // namespace caffe2
124 
125 #endif // CAFFE2_OPERATORS_MEAN_OPS_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...