# Natural Language Toolkit: Chunk parsing API
#
# Copyright (C) 2001-2020 NLTK Project
# Author: Edward Loper <edloper@gmail.com>
# URL: <http://nltk.org/>
# For license information, see LICENSE.TXT
"""
Named entity chunker
"""
import os, re, pickle
from xml.etree import ElementTree as ET
from nltk.tag import ClassifierBasedTagger, pos_tag
try:
from nltk.classify import MaxentClassifier
except ImportError:
pass
from nltk.tree import Tree
from nltk.tokenize import word_tokenize
from nltk.data import find
from nltk.chunk.api import ChunkParserI
from nltk.chunk.util import ChunkScore
class NEChunkParserTagger(ClassifierBasedTagger):
    """
    The IOB tagger used by the chunk parser.

    A ``ClassifierBasedTagger`` whose classifier is built by
    ``_classifier_builder`` (a maxent model trained with the ``megam``
    algorithm) and whose features come from ``_feature_detector``.
    """

    def __init__(self, train):
        """
        :param train: training data, handed unchanged to
            ``ClassifierBasedTagger`` as its ``train`` argument.
        """
        ClassifierBasedTagger.__init__(
            self, train=train, classifier_builder=self._classifier_builder
        )

    def _classifier_builder(self, train):
        """
        Train and return the classifier for the featuresets in ``train``.
        """
        # NOTE: ``MaxentClassifier`` is imported at module level inside a
        # ``try``/``except ImportError: pass``, so this line raises NameError
        # if that import failed.
        return MaxentClassifier.train(
            train, algorithm="megam", gaussian_prior_sigma=1, trace=2
        )

    def _english_wordlist(self):
        """
        Return the "en-basic" word list as a set, loading it on first use.
        """
        try:
            return self._en_wordlist
        except AttributeError:
            # Imported here rather than at module level, so the corpus reader
            # is only needed once a feature actually asks for the word list.
            from nltk.corpus import words

            self._en_wordlist = set(words.words("en-basic"))
            return self._en_wordlist

    def _feature_detector(self, tokens, index, history):
        """
        Build the feature dictionary for the token at position ``index``.

        :param tokens: sequence of ``(word, pos)`` pairs.
        :param index: position in ``tokens`` of the token being classified.
        :param history: tags already chosen; ``history[index - 1]`` is read
            as the tag of the previous token.
        :return: dict mapping feature names to feature values.
        """
        word = tokens[index][0]
        pos = simplify_pos(tokens[index][1])

        # Left context.  Only the values that end up in the feature dict are
        # computed; nothing further back than one token is used.
        prevword = prevpos = prevtag = prevshape = None
        if index == 1:
            prevword = tokens[index - 1][0].lower()
            prevpos = simplify_pos(tokens[index - 1][1])
            # NOTE(review): for the second token only the first *character*
            # of the previous tag is used, and ``prevshape`` stays None.
            # This looks unintentional, but it is kept because changing it
            # would change the features that trained models were fit on.
            prevtag = history[index - 1][0]
        elif index != 0:
            prevword = tokens[index - 1][0].lower()
            prevpos = simplify_pos(tokens[index - 1][1])
            prevtag = history[index - 1]
            # ``prevword`` is already lower-cased when its shape is taken.
            prevshape = shape(prevword)

        # Right context: one token of look-ahead, none at the last token.
        nextword = nextpos = None
        if index != len(tokens) - 1:
            nextword = tokens[index + 1][0].lower()
            # Unlike ``pos``/``prevpos`` this tag is lower-cased rather than
            # passed through ``simplify_pos``.
            nextpos = tokens[index + 1][1].lower()

        # 89.6
        features = {
            "bias": True,
            "shape": shape(word),
            "wordlen": len(word),
            "prefix3": word[:3].lower(),
            "suffix3": word[-3:].lower(),
            "pos": pos,
            "word": word,
            "en-wordlist": (word in self._english_wordlist()),
            "prevtag": prevtag,
            "prevpos": prevpos,
            "nextpos": nextpos,
            "prevword": prevword,
            "nextword": nextword,
            "word+nextpos": "{0}+{1}".format(word.lower(), nextpos),
            "pos+prevtag": "{0}+{1}".format(pos, prevtag),
            "shape+prevtag": "{0}+{1}".format(prevshape, prevtag),
        }
        return features
class NEChunkParser(ChunkParserI):
    """
    Expected input: list of pos-tagged words
    """

    def __init__(self, train):
        """
        :param train: iterable of chunk-parse trees used to train the tagger.
        """
        self._train(train)

    def parse(self, tokens):
        """
        Each token should be a pos-tagged word
        """
        return self._tagged_to_parse(self._tagger.tag(tokens))

    def _train(self, corpus):
        """
        Flatten every tree in ``corpus`` to IOB-tagged tokens and train the
        tagger on the result.
        """
        iob_corpus = [self._parse_to_tagged(tree) for tree in corpus]
        self._tagger = NEChunkParserTagger(train=iob_corpus)

    def _tagged_to_parse(self, tagged_tokens):
        """
        Convert a list of tagged tokens to a chunk-parse tree.
        """
        sent = Tree("S", [])
        for tok, tag in tagged_tokens:
            if tag == "O":
                # Outside any chunk: the token hangs directly off the root.
                sent.append(tok)
                continue
            if tag.startswith("B-"):
                # A "B-" tag always opens a fresh chunk.
                sent.append(Tree(tag[2:], [tok]))
                continue
            if not tag.startswith("I-"):
                # Any other tag contributes nothing to the tree.
                continue
            label = tag[2:]
            last = sent[-1] if sent else None
            if isinstance(last, Tree) and last.label() == label:
                # Continue the chunk that was just built.
                last.append(tok)
            else:
                # An "I-" tag with no matching open chunk starts a new one.
                sent.append(Tree(label, [tok]))
        return sent

    @staticmethod
    def _parse_to_tagged(sent):
        """
        Convert a chunk-parse tree to a list of tagged tokens.
        """
        tagged = []
        for node in sent:
            if not isinstance(node, Tree):
                tagged.append((node, "O"))
                continue
            if len(node) == 0:
                print("Warning -- empty chunk in sentence")
                continue
            label = node.label()
            # First token of the chunk gets "B-", the remaining ones "I-".
            tagged.append((node[0], "B-{0}".format(label)))
            tagged.extend((tok, "I-{0}".format(label)) for tok in node[1:])
        return tagged
def shape(word):
    """
    Classify the surface form of ``word``.

    :param word: the token text.
    :return: one of ``"number"``, ``"punct"``, ``"upcase"``, ``"downcase"``,
        ``"mixedcase"`` or ``"other"``.
    """
    # The patterns are raw strings so that ``\.``, ``\W`` and ``\w`` reach
    # the regex engine as written instead of being invalid string escapes.
    #
    # NOTE(review): in the number pattern only the second alternative is
    # anchored with ``$``, so any word that merely *starts* with a digit
    # (e.g. "1st") is reported as "number".  Left as is on purpose: this
    # value is used as a classifier feature, and changing it would alter
    # what already-trained models see.
    if re.match(r"[0-9]+(\.[0-9]*)?|[0-9]*\.[0-9]+$", word, re.UNICODE):
        return "number"
    if re.match(r"\W+$", word, re.UNICODE):
        return "punct"
    if re.match(r"\w+$", word, re.UNICODE):
        if word.istitle():
            return "upcase"
        if word.islower():
            return "downcase"
        return "mixedcase"
    # Empty strings and words mixing word and non-word characters.
    return "other"
def simplify_pos(s):
    """
    Reduce a part-of-speech tag to a coarser form.

    Any tag starting with "V" becomes "V"; every other tag is cut off at its
    first "-" (tags without a "-" are returned unchanged).
    """
    if not s.startswith("V"):
        return s.partition("-")[0]
    return "V"
def postag_tree(tree):
    """
    Return a new tree in which every leaf is a ``(word, pos)`` pair.

    The leaves of ``tree`` are tagged in one ``pos_tag`` call.  The result is
    always rooted at "S"; each direct subtree of ``tree`` is rebuilt with the
    same label and its children paired with their tags.
    """
    # Part-of-speech tagging.
    tags = (pos for (_, pos) in pos_tag(tree.leaves()))

    tagged = Tree("S", [])
    for child in tree:
        if not isinstance(child, Tree):
            tagged.append((child, next(tags)))
            continue
        chunk = Tree(child.label(), [])
        tagged.append(chunk)
        # Tags are consumed in leaf order, so they line up with the words.
        for leaf in child:
            chunk.append((leaf, next(tags)))
    return tagged
def load_ace_data(roots, fmt="binary", skip_bnews=True):
    """
    Generate the sentences of every ``.sgm`` file found under ``roots``.

    :param roots: iterable of directories to walk recursively.
    :param fmt: passed through to ``load_ace_file``.
    :param skip_bnews: if true, files located directly in a directory whose
        path ends in "bnews" are ignored.
    """
    for top in roots:
        for dirpath, _dirnames, filenames in os.walk(top):
            if skip_bnews and dirpath.endswith("bnews"):
                continue
            for filename in filenames:
                if not filename.endswith(".sgm"):
                    continue
                yield from load_ace_file(os.path.join(dirpath, filename), fmt)
def load_ace_file(textfile, fmt):
    """
    Generate the chunk tree for one ACE text file.

    The annotations are read from ``textfile + ".tmx.rdc.xml"``; only entity
    mentions whose ``TYPE`` attribute is "NAME" are used.  A single
    ``Tree("S", ...)`` is yielded in which each entity is a subtree and the
    text between entities is tokenized with ``word_tokenize``.

    :param textfile: path of the ``.sgm`` text file.
    :param fmt: "binary" to label every entity chunk "NE", or "multiclass"
        to label each chunk with its entity type.
    :raise ValueError: if ``fmt`` is neither of those.  As this is a
        generator, the error surfaces on first iteration, after both files
        have been read.
    """
    print(" - {0}".format(os.path.split(textfile)[1]))
    annfile = textfile + ".tmx.rdc.xml"

    # Read the xml file, and get a list of entities
    entities = []
    with open(annfile, "r") as infile:
        xml = ET.parse(infile).getroot()
    for entity in xml.findall("document/entity"):
        typ = entity.find("entity_type").text
        for mention in entity.findall("entity_mention"):
            if mention.get("TYPE") != "NAME":
                continue  # only NEs
            s = int(mention.find("head/charseq/start").text)
            # +1 turns the annotated end offset into an exclusive slice bound.
            e = int(mention.find("head/charseq/end").text) + 1
            entities.append((s, e, typ))

    # Read the text file, and mark the entities.
    with open(textfile, "r") as infile:
        text = infile.read()

    # Strip XML tags, since they don't count towards the indices
    text = re.sub(r"<(?!/?TEXT)[^>]+>", "", text)

    # Blank out anything before/after <TEXT>
    def subfunc(m):
        # One space per matched character, minus the 6 characters of the
        # "<TEXT>" tag itself.
        return " " * (m.end() - m.start() - 6)

    text = re.sub(r"[\s\S]*<TEXT>", subfunc, text)
    text = re.sub(r"</TEXT>[\s\S]*", "", text)

    # Simplify quotes
    text = re.sub("``", ' "', text)
    text = re.sub("''", '" ', text)

    # Binary distinction (NE or not NE) vs. multiclass distinction (NE type):
    # the two formats differ only in the label given to each chunk.
    if fmt not in ("binary", "multiclass"):
        raise ValueError("bad fmt value")
    binary = fmt == "binary"

    i = 0
    toks = Tree("S", [])
    for (s, e, typ) in sorted(entities):
        if s < i:
            s = i  # Overlapping!  Deal with this better?
        if e <= s:
            continue
        toks.extend(word_tokenize(text[i:s]))
        toks.append(Tree("NE" if binary else typ, text[s:e].split()))
        i = e
    toks.extend(word_tokenize(text[i:]))
    yield toks
# This probably belongs in a more general-purpose location (as does
# the parse_to_tagged function).
def cmp_chunks(correct, guessed):
    """
    Print a token-by-token comparison of two chunk parses.

    Both trees are flattened to IOB-tagged tokens and printed as rows of
    correct tag, guessed tag and word.  For a run of tokens where both tags
    are "O", only the first token is printed, followed by one "..." row.

    :param correct: the reference chunk-parse tree.
    :param guessed: the predicted chunk-parse tree.
    """
    correct = NEChunkParser._parse_to_tagged(correct)
    guessed = NEChunkParser._parse_to_tagged(guessed)
    # All three fields use automatic numbering.  Mixing "{:15}" with an
    # explicit "{2}" makes str.format raise ValueError.
    row = " {:15} {:15} {}"
    ellipsis = False
    for (_, ct), (w, gt) in zip(correct, guessed):
        if ct == gt == "O":
            if not ellipsis:
                print(row.format(ct, gt, w))
                print(row.format("...", "...", "..."))
                ellipsis = True
        else:
            ellipsis = False
            print(row.format(ct, gt, w))
def build_model(fmt="binary"):
    """
    Train an NE chunker on the ACE data, evaluate it, and pickle it.

    Progress, the first three evaluation comparisons and the final chunk
    score are printed.  The trained parser is written to
    ``/tmp/ne_chunker_<fmt>.pickle``.

    :param fmt: "binary" or "multiclass"; passed to ``load_ace_data``.
    :return: the trained ``NEChunkParser``.
    """

    def tagged_trees(resource_names):
        # Locate the corpora, then load and POS-tag every sentence tree.
        paths = [find(name) for name in resource_names]
        return [postag_tree(tree) for tree in load_ace_data(paths, fmt)]

    print("Loading training data...")
    train_data = tagged_trees(
        [
            "corpora/ace_data/ace.dev",
            "corpora/ace_data/ace.heldout",
            "corpora/ace_data/bbn.dev",
            "corpora/ace_data/muc.dev",
        ]
    )
    print("Training...")
    cp = NEChunkParser(train_data)
    # Drop the training set before the evaluation set is loaded.
    del train_data

    print("Loading eval data...")
    eval_data = tagged_trees(["corpora/ace_data/ace.eval"])
    print("Evaluating...")
    chunkscore = ChunkScore()
    for i, correct in enumerate(eval_data):
        guess = cp.parse(correct.leaves())
        chunkscore.score(correct, guess)
        # Show a detailed comparison for the first three sentences only.
        if i < 3:
            cmp_chunks(correct, guess)
    print(chunkscore)

    outfilename = "/tmp/ne_chunker_{0}.pickle".format(fmt)
    print("Saving chunker to {0}...".format(outfilename))
    with open(outfilename, "wb") as outfile:
        pickle.dump(cp, outfile, -1)
    return cp
if __name__ == "__main__":
    # Make sure that the pickled object has the right class name:
    # build through the importable module rather than this script's own
    # ``__main__`` namespace.
    from nltk.chunk.named_entity import build_model

    for fmt in ("binary", "multiclass"):
        build_model(fmt)