1 #include "caffe2/operators/arg_ops.h" 5 #include "caffe2/operators/arg_ops_eigen.h" 6 #include "caffe2/utils/math.h" 12 template <
typename T,
class Compare,
class Context>
// Generic (non-Eigen) arg-reduction helper. This is presumably the
// ComputeArgImpl template invoked as
// ComputeArgImpl(X, prev_size, next_size, n, comp, Y, &context_) by the
// fallback branches of ArgMaxOp/ArgMinOp::Compute. TODO confirm: the name
// line is not visible in this fragment.
// The indexing below treats X as a contiguous [prev_size, n, next_size] block
// reduced over the middle dimension, and Y as [prev_size, next_size].
// NOTE(review): this is an extracted fragment. The original line numbers are
// fused into the text, and lines are missing here: the name line, the remaining
// parameters, and the body of the `if` plus the closing braces.
15 const TIndex prev_size,
16 const TIndex next_size,
// Initialise every output index to 0, so element 0 along the reduced axis is
// the starting candidate for each (i, j) position.
21 math::Set<TIndex, Context>(prev_size * next_size, TIndex(0), Y, context);
22 for (TIndex i = 0; i < prev_size; ++i) {
// Scanning starts at slice k = 1 of outer block i, because slice 0 is already
// the initial candidate. The line that advances cur_X is not shown;
// presumably it moves one element per (k, j) step. TODO confirm.
23 const T* cur_X = X + i * n * next_size + next_size;
24 for (TIndex k = 1; k < n; ++k) {
25 for (TIndex j = 0; j < next_size; ++j) {
26 TIndex* cur_Y = Y + i * next_size + j;
// Compare the candidate X[i][k][j] with the current best X[i][*cur_Y][j].
// The call sites pass strict comparators (std::greater / std::less). If the
// missing branch body records k, ties therefore keep the earliest index.
27 if (comp(*cur_X, X[i * n * next_size + *cur_Y * next_size + j])) {
// ArgMaxOp<T, Context>::Compute writes into Y the index of the largest element
// along the reduced axis.
// - With Eigen >= 3.3 it delegates to arg_ops_eigen::ComputeArgMaxEigen on
//   Eigen::DefaultDevice().
// - Otherwise it calls the generic ComputeArgImpl helper with std::greater<T>
//   as the comparator.
// NOTE(review): the X / n / Y parameter lines and the return statement are not
// visible in this fragment.
38 template <
typename T,
class Context>
39 bool ArgMaxOp<T, Context>::Compute(
41 const TIndex prev_size,
42 const TIndex next_size,
// NOTE(review): in this extracted text, the fallback call (original line 49)
// is fused onto the `#else` line after its `//`, so as written it is commented
// out. In the source it is its own line inside the #else branch.
45 #if EIGEN_VERSION_AT_LEAST(3, 3, 0) 46 arg_ops_eigen::ComputeArgMaxEigen(
47 Eigen::DefaultDevice(), X, prev_size, next_size, n, Y);
48 #else // EIGEN_VERSION_AT_LEAST(3, 3, 0) 49 ComputeArgImpl(X, prev_size, next_size, n, std::greater<T>(), Y, &context_);
// The next line closes the EIGEN_VERSION_AT_LEAST block of
// ArgMaxOp::Compute. It also opens the template header of
// ArgMinOp<T, Context>::Compute.
50 #endif // EIGEN_VERSION_AT_LEAST(3, 3, 0) 54 template <
typename T,
class Context>
// ArgMinOp<T, Context>::Compute mirrors ArgMaxOp::Compute but selects the
// smallest element.
// - With Eigen >= 3.3 it uses arg_ops_eigen::ComputeArgMinEigen.
// - Otherwise it calls the generic ComputeArgImpl helper with std::less<T>.
// NOTE(review): the X / n / Y parameter lines and the return statement are not
// visible. The fallback call below is fused onto the `#else` line after its
// `//`, so as written it is commented out.
55 bool ArgMinOp<T, Context>::Compute(
57 const TIndex prev_size,
58 const TIndex next_size,
61 #if EIGEN_VERSION_AT_LEAST(3, 3, 0) 62 arg_ops_eigen::ComputeArgMinEigen(
63 Eigen::DefaultDevice(), X, prev_size, next_size, n, Y);
64 #else // EIGEN_VERSION_AT_LEAST(3, 3, 0) 65 ComputeArgImpl(X, prev_size, next_size, n, std::less<T>(), Y, &context_);
// Register the float CPU instantiations of ArgMax and ArgMin.
// NOTE(review): the ArgMax registration (original line 70) is fused onto the
// `#endif` line below, after its `//`.
66 #endif // EIGEN_VERSION_AT_LEAST(3, 3, 0) 70 REGISTER_CPU_OPERATOR(ArgMax, ArgMaxOp<float, CPUContext>);
71 REGISTER_CPU_OPERATOR(ArgMin, ArgMinOp<float, CPUContext>);
// Shape inference shared by the ArgMax and ArgMin schemas. It produces one
// INT64 output whose dims come from the input dims, with the reduced axis
// handled according to "keepdims". Per the schema doc text, that axis is kept
// as size 1 when keepdims is true and removed otherwise.
// NOTE(review): lines are missing from this fragment. They are presumably a
// guard on the axis assignment, the keep_dims branch between the two loops,
// and `return out;`. TODO confirm against the full file.
75 std::vector<TensorShape> InferTensor(
76 const OperatorDef& def,
77 const std::vector<TensorShape>& in) {
78 std::vector<TensorShape> out(1);
79 ArgumentHelper helper(def);
// Operator arguments: "axis" (default -1) and "keepdims" (default true).
// keep_dims is not used by any line visible here.
80 int axis = helper.GetSingleArgument(
"axis", -1);
81 const bool keep_dims = helper.GetSingleArgument(
"keepdims",
true);
// NOTE(review): dims_size() is one past the last valid axis index. Whenever
// this assignment runs, the first loop copies every input dim and the second
// copies none, so the reduced axis is never dropped from the inferred shape.
// If this line handles axis == -1 ("last axis"), it presumably should be
// dims_size() - 1. Confirm against how the operator resolves axis at run time.
83 axis = in[0].dims_size();
85 const auto& in_dims = in[0].dims();
86 auto* out_dims = out[0].mutable_dims();
// Copy the dims that precede the reduced axis.
87 for (
int i = 0; i < axis; ++i) {
88 out_dims->Add(in_dims.Get(i));
// Copy the dims that follow the reduced axis.
93 for (
int i = axis + 1; i < in_dims.size(); ++i) {
94 out_dims->Add(in_dims.Get(i));
// The output holds indices, so its element type is always INT64.
96 out[0].set_data_type(TensorProto::INT64);
// Schema for ArgMax. It declares input 0 "X", output 0 "Indices" (shape and
// INT64 type from InferTensor), and the "axis" and "keepdims" arguments.
// NOTE(review): this fragment is missing lines, e.g. the doc-string opener and
// closer around the description text. The user-facing strings also contain
// typos to fix in the source: "Retrive" -> "Retrieve" and
// "Tenor" -> "Tensor".
102 OPERATOR_SCHEMA(ArgMax)
105 .TensorInferenceFunction(InferTensor)
107 Retrive the argmax of the axis dimension. Given an input tensor of shape 108 [a_0, a_1, ..., a_{n-1}] and two arguments axis as int and keepdims as bool, 110 - Index tensor which contains the indices of the largest element. It has the 111 same dims as X.dims() with the dimension along axis equals 1 when 112 keepdims == true otherwise removed. 114 .Input(0, "X",
"Tenor of shape [a_0, a_1, ..., a_{n-1}].")
115 .Output(0,
"Indices",
"Tensor of indices for the largest values.")
116 .Arg(
"axis",
"The axis to get argmax.")
117 .Arg(
"keepdims",
"Whether to keep the axis dim in the output.");
// Schema for ArgMin. It mirrors the ArgMax schema: input 0 "X", output 0
// "Indices" (shape and INT64 type from InferTensor), and the "axis" and
// "keepdims" arguments.
// NOTE(review): the description and the "Indices" output text were copied from
// ArgMax and still say "largest element" / "largest values". For ArgMin they
// should say "smallest". The same typos as in ArgMax also need fixing in the
// source: "Retrive" -> "Retrieve" and "Tenor" -> "Tensor".
119 OPERATOR_SCHEMA(ArgMin)
122 .TensorInferenceFunction(InferTensor)
124 Retrive the argmin of the axis dimension. Given an input tensor of shape 125 [a_0, a_1, ..., a_{n-1}] and two arguments axis as int and keepdims as bool, 127 - Index tensor which contains the indices of the largest element. It has the 128 same dims as X.dims() with the dimension along axis equals 1 when 129 keepdims == true otherwise removed. 131 .Input(0, "X",
"Tenor of shape [a_0, a_1, ..., a_{n-1}].")
132 .Output(0,
"Indices",
"Tensor of indices for the largest values.")
133 .Arg(
"axis",
"The axis to get argmin.")
134 .Arg(
"keepdims",
"Whether to keep the axis dim in the output.");
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...