Caffe2 - C++ API
A deep learning, cross platform ML framework
multi_class_accuracy_op.cc
1 #include "caffe2/operators/multi_class_accuracy_op.h"
2 
3 namespace caffe2 {
4 
5 template <>
6 bool MultiClassAccuracyOp<float, CPUContext>::RunOnDevice() {
7  auto& X = Input(PREDICTION);
8  auto& label = Input(LABEL);
9  auto* Y0 = Output(0);
10  auto* Y1 = Output(1);
11  DCHECK_EQ(X.ndim(), 2);
12  // amount, number of instances
13  int N = X.dim32(0);
14  // dimension, number of classes
15  int D = X.dim32(1);
16  DCHECK_EQ(label.ndim(), 1);
17  DCHECK_EQ(label.dim32(0), N);
18  Y0->Resize(D);
19  Y1->Resize(D);
20 
21  const auto* Xdata = X.data<float>();
22  const auto* labeldata = label.data<int>();
23  auto* accuracies = Y0->mutable_data<float>();
24  auto* amounts = Y1->mutable_data<int>();
25  std::fill(accuracies, accuracies + D, 0);
26  std::fill(amounts, amounts + D, 0);
27 
28  for (int i = 0; i < N; ++i) {
29  float maxval = std::numeric_limits<float>::lowest();
30  int maxid = 0;
31  for (int j = 0; j < D; ++j) {
32  if (Xdata[i * D + j] > maxval) {
33  maxval = Xdata[i * D + j];
34  maxid = j;
35  }
36  }
37  int labelid = labeldata[i];
38  DCHECK_LT(labelid, D);
39  if (maxid == labelid) {
40  accuracies[labelid]++;
41  }
42  amounts[labelid]++;
43  }
44 
45  for (int i = 0; i < D; ++i) {
46  int amount = amounts[i];
47  if (amount) {
48  accuracies[i] /= amount;
49  }
50  }
51 
52  return true;
53 }
54 
55 REGISTER_CPU_OPERATOR(
56  MultiClassAccuracy, MultiClassAccuracyOp<float, CPUContext>);
57 
58 OPERATOR_SCHEMA(MultiClassAccuracy)
59  .NumInputs(2)
60  .NumOutputs(2)
61  .SetDoc(R"DOC(
62 Respectively compute accuracy score for each class given a number of instances
63 and predicted scores of each class for each instance.
64 )DOC")
65  .Input(
66  0,
67  "prediction",
68  "2-D float tensor (N,D,) of predicted scores of each class for "
69  "each data. N is the number of instances, i.e., batch size. D is number of "
70  "possible classes/labels.")
71  .Input(
72  1,
73  "labels",
74  "1-D int tensor (N,) of labels for each instance.")
75  .Output(
76  0,
77  "accuracies",
78  "1-D float tensor (D,) of accuracy for each class. If a class has no "
79  "instance in the batch, its accuracy score is set to zero.")
80  .Output(
81  1,
82  "amounts",
83  "1-D int tensor (D,) of number of instances for each class in the batch.");
84 
85 SHOULD_NOT_DO_GRADIENT(MultiClassAccuracy);
86 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...