Caffe2 - C++ API
A deep learning, cross-platform ML framework
reduce_ops.h
1 #ifndef CAFFE2_OPERATORS_REDUCE_OPS_H_
2 #define CAFFE2_OPERATORS_REDUCE_OPS_H_
3 
#include <algorithm>
#include <numeric>
#include <vector>

#include "caffe2/core/common_omp.h"
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/types.h"
#include "caffe2/utils/math.h"
#include "caffe2/utils/proto_utils.h"
10 
11 namespace caffe2 {
12 
13 template <typename T, class Context>
14 class ReduceOpBase : public Operator<Context> {
15  public:
16  USE_OPERATOR_CONTEXT_FUNCTIONS;
17 
18  ReduceOpBase(const OperatorDef& operator_def, Workspace* ws)
19  : Operator<Context>(operator_def, ws) {
20  axes_ = OperatorBase::GetRepeatedArgument<int>("axes");
21  keepdims_ = OperatorBase::GetSingleArgument<int>("keepdims", 1);
22  }
23 
24  bool RunOnDevice() override {
25  int ndim = Input(0).ndim();
26 
27  if (axes_.empty()) {
28  axes_.resize(ndim);
29  std::iota(axes_.begin(), axes_.end(), 0);
30  } else {
31  std::sort(axes_.begin(), axes_.end());
32  CAFFE_ENFORCE(axes_.front() >= 0, "Axes ids must be non-negative.");
33  CAFFE_ENFORCE(
34  axes_.back() < ndim,
35  "Axes ids must be smaller than the dimensions of input.");
36  }
37 
38  auto& X = Input(0);
39  auto* Y = Output(0);
40 
41  vector<TIndex> y_dims = X.dims();
42  TIndex Y_size = X.size();
43  for (TIndex id = axes_.size() - 1; id >= 0; id--) {
44  TIndex reduced_axis = axes_[id];
45  Y_size /= y_dims[reduced_axis];
46  if (keepdims_) {
47  y_dims[reduced_axis] = 1;
48  } else {
49  y_dims.erase(y_dims.begin() + reduced_axis);
50  }
51  }
52  Y->Resize(y_dims);
53 
54  return this->Compute(
55  X.template data<T>(),
56  X.size(),
57  const_cast<vector<TIndex>&>(X.dims()),
58  Y->template mutable_data<T>(),
59  Y_size,
60  axes_,
61  y_dims,
62  keepdims_);
63  }
64 
65  protected:
66  virtual bool Compute(
67  const T* X_data,
68  const TIndex X_size,
69  vector<TIndex>& dims,
70  T* Y_data,
71  const TIndex Y_size,
72  vector<int>& axes,
73  vector<TIndex>& Y_dims,
74  int keepdims) = 0;
75 
76  private:
77  std::vector<int> axes_;
78  int keepdims_;
79 };
80 
81 template <typename T, class Context>
82 class ReduceSumOp : public ReduceOpBase<T, Context> {
83  public:
84  USE_OPERATOR_CONTEXT_FUNCTIONS;
85 
86  ReduceSumOp(const OperatorDef& operator_def, Workspace* ws)
87  : ReduceOpBase<T, Context>(operator_def, ws) {}
88 
89  protected:
90  bool Compute(
91  const T* X_data,
92  const TIndex X_size,
93  vector<TIndex>& dims,
94  T* Y_data,
95  const TIndex Y_size,
96  vector<int>& axes,
97  vector<TIndex>& Y_dims,
98  int keepdims) override;
99 };
100 
101 template <typename T, class Context>
102 class ReduceMeanOp : public ReduceOpBase<T, Context> {
103  public:
104  USE_OPERATOR_CONTEXT_FUNCTIONS;
105 
106  ReduceMeanOp(const OperatorDef& operator_def, Workspace* ws)
107  : ReduceOpBase<T, Context>(operator_def, ws) {}
108 
109  protected:
110  bool Compute(
111  const T* X_data,
112  const TIndex X_size,
113  vector<TIndex>& dims,
114  T* Y_data,
115  const TIndex Y_size,
116  vector<int>& axes,
117  vector<TIndex>& Y_dims,
118  int keepdims) override;
119 };
120 
121 } // namespace caffe2
122 
123 #endif // CAFFE2_OPERATORS_REDUCE_OPS_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...