// CPU float specialization of MultiClassAccuracyOp::RunOnDevice.
//
// Reads a prediction tensor X and a label tensor, and writes two per-class
// buffers: `accuracies` (float) and `amounts` (int). Per the schema text later
// in this file, X is (N, D) with N instances and D classes, and labels are (N,).
//
// NOTE(review): several statements of this function are not visible in this
// excerpt (the definitions of N, D, Y0, Y1 and maxid, the closing braces and
// the return). Comments below describe only what the visible lines do.
1 #include "caffe2/operators/multi_class_accuracy_op.h" 6 bool MultiClassAccuracyOp<float, CPUContext>::RunOnDevice() {
7 auto& X = Input(PREDICTION);
8 auto& label = Input(LABEL);
// Shape checks: X must have 2 dimensions, label must have 1 dimension whose
// length equals N.
// NOTE(review): DCHECK_* macros are commonly compiled out of release builds —
// confirm whether these checks are active in production.
11 DCHECK_EQ(X.ndim(), 2);
16 DCHECK_EQ(label.ndim(), 1);
17 DCHECK_EQ(label.dim32(0), N);
// Raw pointers into the input and output buffers.
// Y0 / Y1 are presumably the two operator outputs — their definition is not
// visible here; TODO confirm.
21 const auto* Xdata = X.data<
float>();
22 const auto* labeldata = label.data<
int>();
23 auto* accuracies = Y0->mutable_data<
float>();
24 auto* amounts = Y1->mutable_data<
int>();
// Zero the first D entries of both output buffers before accumulating.
25 std::fill(accuracies, accuracies + D, 0);
26 std::fill(amounts, amounts + D, 0);
// Pass 1: for each instance i, scan its D scores for the maximum.
// X is indexed as Xdata[i * D + j], i.e. row i, column j of a row-major (N, D)
// layout.
28 for (
int i = 0; i < N; ++i) {
29 float maxval = std::numeric_limits<float>::lowest();
31 for (
int j = 0; j < D; ++j) {
// Strict '>' means a later score only replaces the running maximum when it is
// strictly larger.
32 if (Xdata[i * D + j] > maxval) {
33 maxval = Xdata[i * D + j];
// Compare the predicted class (maxid, assigned on a line not visible here)
// with the ground-truth label and count a hit for that label's class.
// NOTE(review): only the upper bound (labelid < D) is checked in the visible
// code; a negative label would index before the start of `accuracies`.
// Confirm labels are guaranteed non-negative by the caller.
// NOTE(review): the per-class increment of `amounts` is not visible in this
// excerpt — presumably it happens right after this comparison; confirm.
37 int labelid = labeldata[i];
38 DCHECK_LT(labelid, D);
39 if (maxid == labelid) {
40 accuracies[labelid]++;
// Pass 2: turn per-class hit counts into accuracies by dividing by the number
// of instances seen for that class.
// NOTE(review): the schema text says a class with no instances gets accuracy
// zero, so a guard on `amount` presumably precedes the division on a line not
// visible here. Without one this is a float division by zero — confirm.
45 for (
int i = 0; i < D; ++i) {
46 int amount = amounts[i];
48 accuracies[i] /= amount;
// Associates the operator name MultiClassAccuracy with the
// MultiClassAccuracyOp<float, CPUContext> implementation. Per the macro name
// this is the CPU registration — see the macro definition for exact semantics.
55 REGISTER_CPU_OPERATOR(
56 MultiClassAccuracy, MultiClassAccuracyOp<float, CPUContext>);
// Schema for MultiClassAccuracy. The visible text documents, in order:
//   - the operator summary (per-class accuracy from predicted scores),
//   - a 2-D float tensor (N, D) of predicted scores,
//   - a 1-D int tensor (N,) of labels,
//   - a 1-D float tensor (D,) of per-class accuracy (zero for classes with no
//     instance in the batch),
//   - a 1-D int tensor (D,) of per-class instance counts.
// NOTE(review): the builder calls that attach these strings to specific
// inputs/outputs are not visible in this excerpt; the mapping above follows
// the order and wording of the strings only — confirm against the full file.
58 OPERATOR_SCHEMA(MultiClassAccuracy)
62 Respectively compute accuracy score for each class given a number of instances 63 and predicted scores of each class for each instance. 68 "2-D float tensor (N,D,) of predicted scores of each class for " 69 "each data. N is the number of instances, i.e., batch size. D is number of " 70 "possible classes/labels.")
74 "1-D int tensor (N,) of labels for each instance.")
78 "1-D float tensor (D,) of accuracy for each class. If a class has no " 79 "instance in the batch, its accuracy score is set to zero.")
83 "1-D int tensor (D,) of number of instances for each class in the batch.");
// Per the macro name, declares that no gradient should be computed for
// MultiClassAccuracy — confirm exact behavior against the macro definition.
85 SHOULD_NOT_DO_GRADIENT(MultiClassAccuracy);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...