# Natural Language Toolkit: Confusion Matrices
#
# Copyright (C) 2001-2012 NLTK Project
# Author: Edward Loper <edloper@gradient.cis.upenn.edu>
#         Steven Bird <sb@csse.unimelb.edu.au>
# URL: <http://www.nltk.org/>
# For license information, see LICENSE.TXT

from __future__ import print_function

from nltk.probability import FreqDist

class ConfusionMatrix(object):
    """
    The confusion matrix between a list of reference values and a
    corresponding list of test values.  Entry *[r,t]* of this
    matrix is a count of the number of times that the reference value
    *r* corresponds to the test value *t*.  E.g.:

        >>> from nltk.metrics import ConfusionMatrix
        >>> ref  = 'DET NN VB DET JJ NN NN IN DET NN'.split()
        >>> test = 'DET VB VB DET NN NN NN IN DET NN'.split()
        >>> cm = ConfusionMatrix(ref, test)
        >>> print(cm['NN', 'NN'])
        3

    Note that the diagonal entries (*r* == *t*) of this matrix
    correspond to correct values, and the off-diagonal entries
    correspond to incorrect values.
    """

    def __init__(self, reference, test, sort_by_count=False):
        """
        Construct a new confusion matrix from a list of reference
        values and a corresponding list of test values.

        :type reference: list
        :param reference: An ordered list of reference values.
        :type test: list
        :param test: A list of values to compare against the
            corresponding reference values.
        :type sort_by_count: bool
        :param sort_by_count: If true, then order the values by their
            combined frequency in ``reference`` and ``test`` (most
            frequent first); otherwise, use sorted order.
        :raise ValueError: If ``reference`` and ``test`` do not have
            the same length.
        """
        if len(reference) != len(test):
            raise ValueError('Lists must have the same length.')

        # Get a list of all values.
        if sort_by_count:
            ref_fdist = FreqDist(reference)
            test_fdist = FreqDist(test)
            def key(v):
                return -(ref_fdist[v] + test_fdist[v])
            values = sorted(set(reference + test), key=key)
        else:
            values = sorted(set(reference + test))

        # Construct a value->index dictionary.
        indices = dict((val, i) for (i, val) in enumerate(values))

        # Make a confusion matrix table.
        confusion = [[0 for _ in values] for _ in values]
        max_conf = 0  # Maximum confusion
        for w, g in zip(reference, test):
            confusion[indices[w]][indices[g]] += 1
            max_conf = max(max_conf, confusion[indices[w]][indices[g]])

        #: A list of all values in ``reference`` or ``test``.
        self._values = values
        #: A dictionary mapping values in ``self._values`` to their indices.
        self._indices = indices
        #: The confusion matrix itself (as a list of lists of counts).
        self._confusion = confusion
        #: The greatest count in ``self._confusion`` (used for printing).
        self._max_conf = max_conf
        #: The total number of values in the confusion matrix.
        self._total = len(reference)
        #: The number of correct (on-diagonal) values in the matrix.
        self._correct = sum(confusion[i][i] for i in range(len(values)))
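
    # Illustrative note (not part of the original source): for the doctest
    # data in the class docstring, the default ordering gives
    # self._values == ['DET', 'IN', 'JJ', 'NN', 'VB'], while
    # sort_by_count=True yields ['NN', 'DET', 'VB', 'IN', 'JJ'];
    # in both cases self._correct is 8 and self._total is 10.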

 

    def __getitem__(self, li_lj_tuple):
        """
        :return: The number of times that value ``li`` was expected and
            value ``lj`` was given.
        :rtype: int
        """
        (li, lj) = li_lj_tuple
        i = self._indices[li]
        j = self._indices[lj]
        return self._confusion[i][j]
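
    # Illustrative example (not part of the original source), using the
    # ref/test lists from the class docstring: cm['NN', 'VB'] == 1, since
    # 'NN' in the reference was tagged as 'VB' once in the test data, while
    # cm['NN', 'NN'] == 3 counts the correctly tagged occurrences.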

 

    def __repr__(self):
        return '<ConfusionMatrix: %s/%s correct>' % (self._correct,
                                                     self._total)

    def __str__(self):
        return self.pp()
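
    # Illustrative note (not part of the original source): for the doctest
    # data in the class docstring, repr(cm) is '<ConfusionMatrix: 8/10
    # correct>', while str(cm) (and hence print(cm)) renders the full chart
    # via pp().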

 

    def pp(self, show_percents=False, values_in_chart=True,
           truncate=None, sort_by_count=False):
        """
        :return: A multi-line string representation of this confusion matrix.
        :type show_percents: bool
        :param show_percents: If true, then show each entry as a
            percentage of the total number of values instead of as a
            raw count.
        :type values_in_chart: bool
        :param values_in_chart: If true, then label the chart's rows and
            columns with the values themselves; otherwise, use numeric
            indices and append a value key.
        :type truncate: int
        :param truncate: If specified, then only show the specified
            number of values.  Any sorting (e.g., sort_by_count)
            will be performed before truncation.
        :type sort_by_count: bool
        :param sort_by_count: If true, then sort by the count of each
            label in the reference data.  I.e., labels that occur more
            frequently in the reference data will be towards the left
            edge of the matrix, and labels that occur less frequently
            will be towards the right edge.

        @todo: add marginals?
        """
        confusion = self._confusion

        values = self._values
        if sort_by_count:
            values = sorted(values, key=lambda v:
                            -sum(self._confusion[self._indices[v]]))

        if truncate:
            values = values[:truncate]

        if values_in_chart:
            value_strings = [str(val) for val in values]
        else:
            value_strings = [str(n+1) for n in range(len(values))]

        # Construct a format string for row values
        valuelen = max(len(val) for val in value_strings)
        value_format = '%' + repr(valuelen) + 's | '
        # Construct a format string for matrix entries
        if show_percents:
            entrylen = 6
            entry_format = '%5.1f%%'
            zerostr = '     .'
        else:
            entrylen = len(repr(self._max_conf))
            entry_format = '%' + repr(entrylen) + 'd'
            zerostr = ' '*(entrylen-1) + '.'

        # Write the column values.
        s = ''
        for i in range(valuelen):
            s += (' '*valuelen)+' |'
            for val in value_strings:
                if i >= valuelen-len(val):
                    s += val[i-valuelen+len(val)].rjust(entrylen+1)
                else:
                    s += ' '*(entrylen+1)
            s += ' |\n'

        # Write a dividing line
        s += '%s-+-%s+\n' % ('-'*valuelen, '-'*((entrylen+1)*len(values)))

        # Write the entries.
        for val, li in zip(value_strings, values):
            i = self._indices[li]
            s += value_format % val
            for lj in values:
                j = self._indices[lj]
                if confusion[i][j] == 0:
                    s += zerostr
                elif show_percents:
                    s += entry_format % (100.0*confusion[i][j]/self._total)
                else:
                    s += entry_format % confusion[i][j]
                if i == j:
                    # Mark the diagonal (correct) entries with '<...>'.
                    prevspace = s.rfind(' ')
                    s = s[:prevspace] + '<' + s[prevspace+1:] + '>'
                else:
                    s += ' '
            s += '|\n'

        # Write a dividing line
        s += '%s-+-%s+\n' % ('-'*valuelen, '-'*((entrylen+1)*len(values)))

        # Write a key
        s += '(row = reference; col = test)\n'
        if not values_in_chart:
            s += 'Value key:\n'
            for i, value in enumerate(values):
                s += '%6d: %s\n' % (i+1, value)

        return s

    def key(self):
        """
        :return: A multi-line string listing each value in the matrix
            next to its numeric index.
        :rtype: str
        """
        values = self._values
        s = 'Value key:\n'
        indexlen = len(repr(len(values)-1))
        key_format = '  %'+repr(indexlen)+'d: %s\n'
        for i in range(len(values)):
            s += key_format % (i, values[i])

        return s
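
# Illustrative sketch (not part of the original module) showing how pp()'s
# options compose; the helper name `_pp_options_demo` is hypothetical.
def _pp_options_demo():
    reference = 'DET NN VB DET JJ NN NN IN DET NN'.split()
    test = 'DET VB VB DET NN NN NN IN DET NN'.split()
    cm = ConfusionMatrix(reference, test)
    # Show percentages of the total instead of raw counts.
    print(cm.pp(show_percents=True))
    # Put the most frequent labels first and keep only the first three.
    print(cm.pp(sort_by_count=True, truncate=3))
    # pp(values_in_chart=False) numbers the rows/columns and appends its own
    # legend; key() builds a similar legend as a standalone string.
    print(cm.pp(values_in_chart=False))
    print(cm.key())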

 

def demo():
    reference = 'DET NN VB DET JJ NN NN IN DET NN'.split()
    test = 'DET VB VB DET NN NN NN IN DET NN'.split()
    print('Reference =', reference)
    print('Test      =', test)
    print('Confusion matrix:')
    print(ConfusionMatrix(reference, test))
    print(ConfusionMatrix(reference, test).pp(sort_by_count=True))
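
# Illustrative sketch (not part of the original module): overall accuracy can
# be recovered through __getitem__ alone; the helper name `_accuracy_sketch`
# is hypothetical.  For the demo data above it returns 0.8 (8/10 correct).
def _accuracy_sketch(reference, test):
    cm = ConfusionMatrix(reference, test)
    correct = sum(cm[v, v] for v in set(reference + test))
    return correct / float(len(reference))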

 

if __name__ == '__main__':
    demo()