Caffe2 - C++ API
A deep learning, cross-platform ML framework
arg_max_op.cc
1 #include "caffe2/operators/arg_max_op.h"
2 
3 namespace caffe2 {
4 
5 template <>
6 bool RowWiseArgMaxOp<CPUContext>::RunOnDevice() {
7  auto& X = Input(0);
8  auto* result = Output(0);
9  CAFFE_ENFORCE(X.ndim() == 2, "Input should be a 2D tensor");
10  const int N = X.dim32(0);
11  const int D = X.dim32(1);
12  const float* X_data = X.data<float>();
13  result->Resize(N, 1);
14  int* result_data = result->mutable_data<int>();
15  for (int n = 0; n < N; ++n) {
16  float mx = X_data[n * D];
17  int argmx = n * D;
18  for (int d = 1; d < D; ++d) {
19  int idx = n * D + d;
20  if (X_data[idx] > mx) {
21  mx = X_data[idx];
22  argmx = idx;
23  }
24  result_data[n] = argmx - (n * D);
25  }
26  }
27  return true;
28 }
29 
30 // RowWiseArgMax
31 REGISTER_CPU_OPERATOR(RowWiseArgMax, RowWiseArgMaxOp<CPUContext>);
32 OPERATOR_SCHEMA(RowWiseArgMax)
33  .NumInputs(1)
34  .NumOutputs(1)
35  .SetDoc(R"DOC(
36  Given a 2D (N X D) input tensor, this operator returns a 2D (N X 1) output
37  tensor with with the index of the maximum value in each row. If there are
38  duplicate max values in a row the index of the first occurence is returned.
39  )DOC")
40  .Input(0, "X", "2D (N X D) input tensor")
41  .Output(0, "Z", "2D (N X 1) output tensor");
42 
43 NO_GRADIENT(RowWiseArgMax);
44 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...