You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
219 lines
7.5 KiB
Python
219 lines
7.5 KiB
Python
5 years ago
|
# Natural Language Toolkit: Confusion Matrices
|
||
|
#
|
||
|
# Copyright (C) 2001-2019 NLTK Project
|
||
|
# Author: Edward Loper <edloper@gmail.com>
|
||
|
# Steven Bird <stevenbird1@gmail.com>
|
||
|
# URL: <http://nltk.org/>
|
||
|
# For license information, see LICENSE.TXT
|
||
|
from __future__ import print_function, unicode_literals
|
||
|
from nltk.probability import FreqDist
|
||
|
from nltk.compat import python_2_unicode_compatible
|
||
|
|
||
|
|
||
|
@python_2_unicode_compatible
class ConfusionMatrix(object):
    """
    The confusion matrix between a list of reference values and a
    corresponding list of test values.  Entry *[r,t]* of this
    matrix is a count of the number of times that the reference value
    *r* corresponds to the test value *t*.  E.g.:

        >>> from nltk.metrics import ConfusionMatrix
        >>> ref = 'DET NN VB DET JJ NN NN IN DET NN'.split()
        >>> test = 'DET VB VB DET NN NN NN IN DET NN'.split()
        >>> cm = ConfusionMatrix(ref, test)
        >>> print(cm['NN', 'NN'])
        3

    Note that the diagonal entries *Ri=Tj* of this matrix
    correspond to correct values, and the off-diagonal entries
    correspond to incorrect values.
    """

    def __init__(self, reference, test, sort_by_count=False):
        """
        Construct a new confusion matrix from a list of reference
        values and a corresponding list of test values.

        :type reference: list
        :param reference: An ordered list of reference values.
        :type test: list
        :param test: A list of values to compare against the
            corresponding reference values.
        :param sort_by_count: If true, order the matrix's values by their
            combined number of occurrences in ``reference`` and ``test``
            (most frequent first); otherwise use their natural sort order.
        :raise ValueError: If ``reference`` and ``test`` do not have
            the same length.
        """
        if len(reference) != len(test):
            raise ValueError('Lists must have the same length.')

        # Get the set of all values.  A set union (rather than
        # ``set(reference + test)``) also works when the two sequences
        # are of different types, e.g. a list and a tuple.
        all_values = set(reference) | set(test)
        if sort_by_count:
            ref_fdist = FreqDist(reference)
            test_fdist = FreqDist(test)

            def key(v):
                return -(ref_fdist[v] + test_fdist[v])

            values = sorted(all_values, key=key)
        else:
            values = sorted(all_values)

        # Construct a value->index dictionary.
        indices = {val: i for i, val in enumerate(values)}

        # Make a confusion matrix table (rows = reference, cols = test).
        confusion = [[0] * len(values) for _ in values]
        max_conf = 0  # Maximum confusion
        for w, g in zip(reference, test):
            i, j = indices[w], indices[g]
            confusion[i][j] += 1
            max_conf = max(max_conf, confusion[i][j])

        #: A list of all values in ``reference`` or ``test``.
        self._values = values
        #: A dictionary mapping values in ``self._values`` to their indices.
        self._indices = indices
        #: The confusion matrix itself (as a list of lists of counts).
        self._confusion = confusion
        #: The greatest count in ``self._confusion`` (used for printing).
        self._max_conf = max_conf
        #: The total number of values in the confusion matrix.
        self._total = len(reference)
        #: The number of correct (on-diagonal) values in the matrix.
        self._correct = sum(confusion[i][i] for i in range(len(values)))

    def __getitem__(self, li_lj_tuple):
        """
        :return: The number of times that value ``li`` was expected and
            value ``lj`` was given.
        :rtype: int
        :raise KeyError: If ``li`` or ``lj`` does not occur in the matrix.
        """
        (li, lj) = li_lj_tuple
        i = self._indices[li]
        j = self._indices[lj]
        return self._confusion[i][j]

    def __repr__(self):
        return '<ConfusionMatrix: %s/%s correct>' % (self._correct, self._total)

    def __str__(self):
        return self.pretty_format()

    def pretty_format(
        self,
        show_percents=False,
        values_in_chart=True,
        truncate=None,
        sort_by_count=False,
    ):
        """
        :return: A multi-line string representation of this confusion matrix.
        :param show_percents: If true, show each entry as a percentage of
            the total number of values instead of as a raw count.
        :param values_in_chart: If true, label the rows and columns with
            the values themselves; otherwise label them with 1-based
            numbers and append a value key mapping numbers to values.
        :type truncate: int
        :param truncate: If specified, then only show the specified
            number of values.  Any sorting (e.g., sort_by_count)
            will be performed before truncation.
        :param sort_by_count: If true, then sort by the count of each
            label in the reference data.  I.e., labels that occur more
            frequently in the reference label will be towards the left
            edge of the matrix, and labels that occur less frequently
            will be towards the right edge.

        @todo: add marginals?
        """
        confusion = self._confusion
        indices = self._indices

        values = self._values
        if sort_by_count:
            # A row sum is the number of times that value occurs in the
            # reference data.
            values = sorted(values, key=lambda v: -sum(confusion[indices[v]]))

        if truncate:
            values = values[:truncate]

        if values_in_chart:
            value_strings = ["%s" % val for val in values]
        else:
            value_strings = [str(n + 1) for n in range(len(values))]

        # Construct a format string for row values.  ``max()`` of an empty
        # sequence raises ValueError, so use width 0 when there are no values.
        valuelen = max([len(val) for val in value_strings] or [0])
        value_format = '%' + repr(valuelen) + 's | '

        # Construct a format string for matrix entries.
        if show_percents:
            entrylen = 6
            entry_format = '%5.1f%%'
        else:
            entrylen = len(repr(self._max_conf))
            entry_format = '%' + repr(entrylen) + 'd'
        # Zero entries print as a '.' right-aligned to the same width as
        # every other entry, so the columns stay aligned in both modes.
        zerostr = ' ' * (entrylen - 1) + '.'

        divider = '%s-+-%s+\n' % ('-' * valuelen, '-' * ((entrylen + 1) * len(values)))

        # Write the column values vertically, one character per header
        # line, bottom-aligned so every label ends on the last header line.
        s = ''
        for i in range(valuelen):
            s += (' ' * valuelen) + ' |'
            for val in value_strings:
                if i >= valuelen - len(val):
                    s += val[i - valuelen + len(val)].rjust(entrylen + 1)
                else:
                    s += ' ' * (entrylen + 1)
            s += ' |\n'

        # Write a dividing line
        s += divider

        # Write the entries.
        for val, li in zip(value_strings, values):
            i = indices[li]
            s += value_format % val
            for lj in values:
                j = indices[lj]
                if confusion[i][j] == 0:
                    s += zerostr
                elif show_percents:
                    s += entry_format % (100.0 * confusion[i][j] / self._total)
                else:
                    s += entry_format % confusion[i][j]
                if i == j:
                    # Mark the diagonal cell as <entry>: the last space
                    # before the entry's text becomes '<', and '>' takes
                    # the place of the trailing separator space.
                    prevspace = s.rfind(' ')
                    s = s[:prevspace] + '<' + s[prevspace + 1 :] + '>'
                else:
                    s += ' '
            s += '|\n'

        # Write a dividing line
        s += divider

        # Write a key
        s += '(row = reference; col = test)\n'
        if not values_in_chart:
            s += 'Value key:\n'
            for i, value in enumerate(values):
                s += '%6d: %s\n' % (i + 1, value)

        return s

    def key(self):
        """
        :return: A listing of every value in the matrix, one per line,
            next to its zero-based position in ``self._values``.
        :rtype: str
        """
        values = self._values
        indexlen = len(repr(len(values) - 1))
        key_format = ' %' + repr(indexlen) + 'd: %s\n'
        lines = [key_format % (i, val) for i, val in enumerate(values)]
        return 'Value key:\n' + ''.join(lines)
|
||
|
|
||
|
|
||
|
def demo():
    """
    Demonstrate ``ConfusionMatrix`` on a short tagging example: print the
    two tag sequences, then the matrix in its default value order and
    again sorted by reference count.
    """
    reference = 'DET NN VB DET JJ NN NN IN DET NN'.split()
    test = 'DET VB VB DET NN NN NN IN DET NN'.split()

    for caption, tags in (('Reference =', reference), ('Test =', test)):
        print(caption, tags)

    print('Confusion matrix:')
    # The constructor has no side effects, so one instance serves both views.
    matrix = ConfusionMatrix(reference, test)
    print(matrix)
    print(matrix.pretty_format(sort_by_count=True))
|
||
|
|
||
|
|
||
|
# Print the demonstration output when this module is run as a script.
if __name__ == '__main__':
    demo()
|