Caffe2 - C++ API
A deep learning, cross-platform ML framework
arg_ops.cc
1 #include "caffe2/operators/arg_ops.h"
2 
3 #include <functional>
4 
5 #include "caffe2/operators/arg_ops_eigen.h"
6 #include "caffe2/utils/math.h"
7 
8 namespace caffe2 {
9 
10 namespace {
11 
12 template <typename T, class Compare, class Context>
13 void ComputeArgImpl(
14  const T* X,
15  const TIndex prev_size,
16  const TIndex next_size,
17  const TIndex n,
18  const Compare& comp,
19  TIndex* Y,
20  Context* context) {
21  math::Set<TIndex, Context>(prev_size * next_size, TIndex(0), Y, context);
22  for (TIndex i = 0; i < prev_size; ++i) {
23  const T* cur_X = X + i * n * next_size + next_size;
24  for (TIndex k = 1; k < n; ++k) {
25  for (TIndex j = 0; j < next_size; ++j) {
26  TIndex* cur_Y = Y + i * next_size + j;
27  if (comp(*cur_X, X[i * n * next_size + *cur_Y * next_size + j])) {
28  *cur_Y = k;
29  }
30  ++cur_X;
31  }
32  }
33  }
34 }
35 
36 } // namespace
37 
38 template <typename T, class Context>
39 bool ArgMaxOp<T, Context>::Compute(
40  const T* X,
41  const TIndex prev_size,
42  const TIndex next_size,
43  const TIndex n,
44  TIndex* Y) {
45 #if EIGEN_VERSION_AT_LEAST(3, 3, 0)
46  arg_ops_eigen::ComputeArgMaxEigen(
47  Eigen::DefaultDevice(), X, prev_size, next_size, n, Y);
48 #else // EIGEN_VERSION_AT_LEAST(3, 3, 0)
49  ComputeArgImpl(X, prev_size, next_size, n, std::greater<T>(), Y, &context_);
50 #endif // EIGEN_VERSION_AT_LEAST(3, 3, 0)
51  return true;
52 }
53 
54 template <typename T, class Context>
55 bool ArgMinOp<T, Context>::Compute(
56  const T* X,
57  const TIndex prev_size,
58  const TIndex next_size,
59  const TIndex n,
60  TIndex* Y) {
61 #if EIGEN_VERSION_AT_LEAST(3, 3, 0)
62  arg_ops_eigen::ComputeArgMinEigen(
63  Eigen::DefaultDevice(), X, prev_size, next_size, n, Y);
64 #else // EIGEN_VERSION_AT_LEAST(3, 3, 0)
65  ComputeArgImpl(X, prev_size, next_size, n, std::less<T>(), Y, &context_);
66 #endif // EIGEN_VERSION_AT_LEAST(3, 3, 0)
67  return true;
68 }
69 
// CPU registrations. Only the float instantiations are registered here;
// other element types would need their own REGISTER_CPU_OPERATOR lines.
70 REGISTER_CPU_OPERATOR(ArgMax, ArgMaxOp<float, CPUContext>);
71 REGISTER_CPU_OPERATOR(ArgMin, ArgMinOp<float, CPUContext>);
72 
73 namespace {
74 
75 std::vector<TensorShape> InferTensor(
76  const OperatorDef& def,
77  const std::vector<TensorShape>& in) {
78  std::vector<TensorShape> out(1);
79  ArgumentHelper helper(def);
80  int axis = helper.GetSingleArgument("axis", -1);
81  const bool keep_dims = helper.GetSingleArgument("keepdims", true);
82  if (axis == -1) {
83  axis = in[0].dims_size();
84  }
85  const auto& in_dims = in[0].dims();
86  auto* out_dims = out[0].mutable_dims();
87  for (int i = 0; i < axis; ++i) {
88  out_dims->Add(in_dims.Get(i));
89  }
90  if (keep_dims) {
91  out_dims->Add(1);
92  }
93  for (int i = axis + 1; i < in_dims.size(); ++i) {
94  out_dims->Add(in_dims.Get(i));
95  }
96  out[0].set_data_type(TensorProto::INT64);
97  return out;
98 }
99 
100 } // namespace
101 
102 OPERATOR_SCHEMA(ArgMax)
103  .NumInputs(1)
104  .NumOutputs(1)
105  .TensorInferenceFunction(InferTensor)
106  .SetDoc(R"DOC(
107 Retrive the argmax of the axis dimension. Given an input tensor of shape
108 [a_0, a_1, ..., a_{n-1}] and two arguments axis as int and keepdims as bool,
109 returns one output:
110 - Index tensor which contains the indices of the largest element. It has the
111  same dims as X.dims() with the dimension along axis equals 1 when
112  keepdims == true otherwise removed.
113  )DOC")
114  .Input(0, "X", "Tenor of shape [a_0, a_1, ..., a_{n-1}].")
115  .Output(0, "Indices", "Tensor of indices for the largest values.")
116  .Arg("axis", "The axis to get argmax.")
117  .Arg("keepdims", "Whether to keep the axis dim in the output.");
118 
119 OPERATOR_SCHEMA(ArgMin)
120  .NumInputs(1)
121  .NumOutputs(1)
122  .TensorInferenceFunction(InferTensor)
123  .SetDoc(R"DOC(
124 Retrive the argmin of the axis dimension. Given an input tensor of shape
125 [a_0, a_1, ..., a_{n-1}] and two arguments axis as int and keepdims as bool,
126 returns one output:
127 - Index tensor which contains the indices of the largest element. It has the
128  same dims as X.dims() with the dimension along axis equals 1 when
129  keepdims == true otherwise removed.
130  )DOC")
131  .Input(0, "X", "Tenor of shape [a_0, a_1, ..., a_{n-1}].")
132  .Output(0, "Indices", "Tensor of indices for the largest values.")
133  .Arg("axis", "The axis to get argmin.")
134  .Arg("keepdims", "Whether to keep the axis dim in the output.");
135 
// Neither operator has a gradient: their outputs are INT64 indices (see the
// data type set in InferTensor), not differentiable functions of X.
136 NO_GRADIENT(ArgMax);
137 NO_GRADIENT(ArgMin);
138 
139 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...