Caffe2 - C++ API
A deep learning, cross platform ML framework
arg_ops.h
1 #ifndef CAFFE2_OPERATORS_ARG_OPS_H_
2 #define CAFFE2_OPERATORS_ARG_OPS_H_
3 
4 #include <algorithm>
5 #include <iterator>
6 #include <vector>
7 
8 #include "caffe2/core/context.h"
9 #include "caffe2/core/operator.h"
10 #include "caffe2/core/types.h"
11 
12 namespace caffe2 {
13 
14 template <typename T, class Context>
15 class ArgOpBase : public Operator<Context> {
16  public:
17  USE_OPERATOR_CONTEXT_FUNCTIONS;
18 
19  ArgOpBase(const OperatorDef& operator_def, Workspace* ws)
20  : Operator<Context>(operator_def, ws),
21  OP_SINGLE_ARG(int, "axis", axis_, -1),
22  OP_SINGLE_ARG(bool, "keepdims", keep_dims_, true) {}
23 
24  bool RunOnDevice() override {
25  const auto& X = Input(0);
26  auto* Y = Output(0);
27  const int ndim = X.ndim();
28  if (axis_ == -1) {
29  axis_ = ndim - 1;
30  }
31  CAFFE_ENFORCE_GE(axis_, 0);
32  CAFFE_ENFORCE_LT(axis_, ndim);
33  const std::vector<TIndex>& X_dims = X.dims();
34  std::vector<TIndex> Y_dims;
35  Y_dims.reserve(ndim);
36  TIndex prev_size = 1;
37  TIndex next_size = 1;
38  for (int i = 0; i < axis_; ++i) {
39  Y_dims.push_back(X_dims[i]);
40  prev_size *= X_dims[i];
41  }
42  if (keep_dims_) {
43  Y_dims.push_back(1);
44  }
45  for (int i = axis_ + 1; i < ndim; ++i) {
46  Y_dims.push_back(X_dims[i]);
47  next_size *= X_dims[i];
48  }
49  Y->Resize(Y_dims);
50  const TIndex n = X_dims[axis_];
51  return Compute(
52  X.template data<T>(),
53  prev_size,
54  next_size,
55  n,
56  Y->template mutable_data<TIndex>());
57  }
58 
59  protected:
60  virtual bool Compute(
61  const T* X,
62  const TIndex prev_size,
63  const TIndex next_size,
64  const TIndex n,
65  TIndex* Y) = 0;
66 
67  private:
68  int axis_;
69  const bool keep_dims_;
70 };
71 
72 template <typename T, class Context>
73 class ArgMaxOp final : public ArgOpBase<T, Context> {
74  public:
75  USE_OPERATOR_CONTEXT_FUNCTIONS;
76 
77  ArgMaxOp(const OperatorDef& operator_def, Workspace* ws)
78  : ArgOpBase<T, Context>(operator_def, ws) {}
79 
80  protected:
81  bool Compute(
82  const T* X,
83  const TIndex prev_size,
84  const TIndex next_size,
85  const TIndex n,
86  TIndex* Y) override;
87 };
88 
89 template <typename T, class Context>
90 class ArgMinOp final : public ArgOpBase<T, Context> {
91  public:
92  USE_OPERATOR_CONTEXT_FUNCTIONS;
93 
94  ArgMinOp(const OperatorDef& operator_def, Workspace* ws)
95  : ArgOpBase<T, Context>(operator_def, ws) {}
96 
97  protected:
98  bool Compute(
99  const T* X,
100  const TIndex prev_size,
101  const TIndex next_size,
102  const TIndex n,
103  TIndex* Y) override;
104 };
105 
106 } // namespace caffe2
107 
108 #endif // CAFFE2_OPERATORS_ARG_OPS_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...