1 #include "caffe2/operators/arg_max_op.h" 6 bool RowWiseArgMaxOp<CPUContext>::RunOnDevice() {
8 auto* result = Output(0);
9 CAFFE_ENFORCE(X.ndim() == 2,
"Input should be a 2D tensor");
10 const int N = X.dim32(0);
11 const int D = X.dim32(1);
12 const float* X_data = X.data<
float>();
14 int* result_data = result->mutable_data<
int>();
15 for (
int n = 0; n < N; ++n) {
16 float mx = X_data[n * D];
18 for (
int d = 1; d < D; ++d) {
20 if (X_data[idx] > mx) {
24 result_data[n] = argmx - (n * D);
31 REGISTER_CPU_OPERATOR(RowWiseArgMax, RowWiseArgMaxOp<CPUContext>);
32 OPERATOR_SCHEMA(RowWiseArgMax)
36 Given a 2D (N X D) input tensor, this operator returns a 2D (N X 1) output 37 tensor with with the index of the maximum value in each row. If there are 38 duplicate max values in a row the index of the first occurence is returned. 40 .Input(0, "X",
"2D (N X D) input tensor")
41 .Output(0,
"Z",
"2D (N X 1) output tensor");
43 NO_GRADIENT(RowWiseArgMax);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...