You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

130 lines
4.3 KiB
Python

# Natural Language Toolkit: evaluation of dependency parser
#
# Author: Long Duong <longdt219@gmail.com>
#
# Copyright (C) 2001-2020 NLTK Project
# URL: <http://nltk.org/>
# For license information, see LICENSE.TXT
import unicodedata
class DependencyEvaluator(object):
    """
    Class for measuring labelled and unlabelled attachment score for
    dependency parsing. Note that the evaluation ignores punctuation.

    >>> from nltk.parse import DependencyGraph, DependencyEvaluator
    >>> gold_sent = DependencyGraph(\"""
    ... Pierre NNP 2 NMOD
    ... Vinken NNP 8 SUB
    ... , , 2 P
    ... 61 CD 5 NMOD
    ... years NNS 6 AMOD
    ... old JJ 2 NMOD
    ... , , 2 P
    ... will MD 0 ROOT
    ... join VB 8 VC
    ... the DT 11 NMOD
    ... board NN 9 OBJ
    ... as IN 9 VMOD
    ... a DT 15 NMOD
    ... nonexecutive JJ 15 NMOD
    ... director NN 12 PMOD
    ... Nov. NNP 9 VMOD
    ... 29 CD 16 NMOD
    ... . . 9 VMOD
    ... \""")
    >>> parsed_sent = DependencyGraph(\"""
    ... Pierre NNP 8 NMOD
    ... Vinken NNP 1 SUB
    ... , , 3 P
    ... 61 CD 6 NMOD
    ... years NNS 6 AMOD
    ... old JJ 2 NMOD
    ... , , 3 AMOD
    ... will MD 0 ROOT
    ... join VB 8 VC
    ... the DT 11 AMOD
    ... board NN 9 OBJECT
    ... as IN 9 NMOD
    ... a DT 15 NMOD
    ... nonexecutive JJ 15 NMOD
    ... director NN 12 PMOD
    ... Nov. NNP 9 VMOD
    ... 29 CD 16 NMOD
    ... . . 9 VMOD
    ... \""")
    >>> de = DependencyEvaluator([parsed_sent],[gold_sent])
    >>> las, uas = de.eval()
    >>> las
    0.6...
    >>> uas
    0.8...
    >>> abs(uas - 0.8) < 0.00001
    True
    """

    # Unicode general categories treated as punctuation. Built once here
    # instead of on every ``_remove_punct`` call.
    _PUNCT_CATEGORIES = frozenset(["Pc", "Pd", "Ps", "Pe", "Pi", "Pf", "Po"])

    def __init__(self, parsed_sents, gold_sents):
        """
        :param parsed_sents: the list of parsed_sents as the output of parser
        :type parsed_sents: list(DependencyGraph)
        :param gold_sents: the reference sentences, in the same order as
            ``parsed_sents``; ``eval`` compares them pairwise
        :type gold_sents: list(DependencyGraph)
        """
        self._parsed_sents = parsed_sents
        self._gold_sents = gold_sents

    def _remove_punct(self, inStr):
        """
        Function to remove punctuation from Unicode string.

        :param inStr: the input string
        :return: Unicode string after remove all punctuation
        """
        punct = self._PUNCT_CATEGORIES
        return "".join(x for x in inStr if unicodedata.category(x) not in punct)

    def eval(self):
        """
        Return the Labeled Attachment Score (LAS) and Unlabeled Attachment
        Score (UAS).

        Nodes whose word is ``None`` and words made up only of punctuation
        are skipped. A token counts towards UAS when its head matches the
        gold head, and towards LAS when its relation label matches as well.

        :return: ``(las, uas)``; ``(0.0, 0.0)`` when no token was scored
        :rtype: tuple(float, float)
        :raises ValueError: if the number of parsed and gold sentences
            differs, if a sentence pair has a different number of nodes, or
            if the words at the same address differ
        """
        if len(self._parsed_sents) != len(self._gold_sents):
            raise ValueError(
                " Number of parsed sentence is different with number of gold sentence."
            )

        corr = 0
        corrL = 0
        total = 0

        # The lengths were checked above, so zip pairs every sentence.
        for parsed_sent, gold_sent in zip(self._parsed_sents, self._gold_sents):
            parsed_sent_nodes = parsed_sent.nodes
            gold_sent_nodes = gold_sent.nodes

            if len(parsed_sent_nodes) != len(gold_sent_nodes):
                raise ValueError("Sentences must have equal length.")

            for parsed_node_address, parsed_node in parsed_sent_nodes.items():
                gold_node = gold_sent_nodes[parsed_node_address]

                if parsed_node["word"] is None:
                    continue
                if parsed_node["word"] != gold_node["word"]:
                    raise ValueError("Sentence sequence is not matched.")

                # Ignore the token if nothing is left once punctuation
                # characters are stripped from it.
                if self._remove_punct(parsed_node["word"]) == "":
                    continue

                total += 1
                if parsed_node["head"] == gold_node["head"]:
                    corr += 1
                    # The label only counts when the head is right too.
                    if parsed_node["rel"] == gold_node["rel"]:
                        corrL += 1

        if total == 0:
            # Nothing was scored (no sentences, or only root/punctuation
            # nodes); avoid dividing by zero.
            return 0.0, 0.0

        return corrL / total, corr / total