You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

88 lines
3.5 KiB
Python

import unittest
from nltk.corpus import brown
from nltk.jsontags import JSONTaggedDecoder, JSONTaggedEncoder
from nltk.tag import DefaultTagger, RegexpTagger, AffixTagger
from nltk.tag import UnigramTagger, BigramTagger, TrigramTagger, NgramTagger
from nltk.tag import PerceptronTagger
from nltk.tag import BrillTaggerTrainer, BrillTagger
from nltk.tag.brill import nltkdemo18
class TestJSONSerialization(unittest.TestCase):
    """Check that NLTK taggers survive a JSON encode/decode round trip.

    Each test builds a tagger, serializes it with ``JSONTaggedEncoder``,
    rebuilds it with ``JSONTaggedDecoder`` and compares the rebuilt tagger
    with the original.
    """

    def setUp(self):
        """Create the corpus sample, the codec objects and a base tagger."""
        # Only the first 35 tagged Brown sentences are used as training data.
        self.corpus = brown.tagged_sents()[:35]
        self.decoder = JSONTaggedDecoder()
        self.encoder = JSONTaggedEncoder()
        self.default_tagger = DefaultTagger("NN")

    def _roundtrip(self, tagger):
        """Return ``tagger`` after encoding it and decoding the result."""
        encoded = self.encoder.encode(tagger)
        return self.decoder.decode(encoded)

    def test_default_tagger(self):
        """A ``DefaultTagger`` keeps its repr and its tag."""
        decoded = self._roundtrip(self.default_tagger)

        self.assertEqual(repr(self.default_tagger), repr(decoded))
        self.assertEqual(self.default_tagger._tag, decoded._tag)

    def test_regexp_tagger(self):
        """A ``RegexpTagger`` keeps its repr, backoff and regexps."""
        tagger = RegexpTagger([(r".*", "NN")], backoff=self.default_tagger)
        decoded = self._roundtrip(tagger)

        self.assertEqual(repr(tagger), repr(decoded))
        self.assertEqual(repr(tagger.backoff), repr(decoded.backoff))
        self.assertEqual(tagger._regexps, decoded._regexps)

    def test_affix_tagger(self):
        """An ``AffixTagger`` keeps its repr, backoff and trained state."""
        tagger = AffixTagger(self.corpus, backoff=self.default_tagger)
        decoded = self._roundtrip(tagger)

        self.assertEqual(repr(tagger), repr(decoded))
        self.assertEqual(repr(tagger.backoff), repr(decoded.backoff))
        self.assertEqual(tagger._affix_length, decoded._affix_length)
        self.assertEqual(tagger._min_word_length, decoded._min_word_length)
        self.assertEqual(tagger._context_to_tag, decoded._context_to_tag)

    def test_ngram_taggers(self):
        """A 4-gram tagger keeps its whole backoff chain, level by level."""
        unitagger = UnigramTagger(self.corpus, backoff=self.default_tagger)
        bitagger = BigramTagger(self.corpus, backoff=unitagger)
        tritagger = TrigramTagger(self.corpus, backoff=bitagger)
        ntagger = NgramTagger(4, self.corpus, backoff=tritagger)

        decoded = self._roundtrip(ntagger)

        self.assertEqual(repr(ntagger), repr(decoded))

        # Step down the decoded chain one backoff at a time and compare each
        # level with the tagger that was used to build it.
        actual = decoded
        for expected in (tritagger, bitagger, unitagger, self.default_tagger):
            actual = actual.backoff
            self.assertEqual(repr(expected), repr(actual))

    def test_perceptron_tagger(self):
        """A trained ``PerceptronTagger`` keeps weights, tagdict and classes."""
        tagger = PerceptronTagger(load=False)
        tagger.train(self.corpus)

        decoded = self._roundtrip(tagger)

        self.assertEqual(tagger.model.weights, decoded.model.weights)
        self.assertEqual(tagger.tagdict, decoded.tagdict)
        self.assertEqual(tagger.classes, decoded.classes)

    def test_brill_tagger(self):
        """A trained ``BrillTagger`` keeps its initial tagger, rules and stats."""
        trainer = BrillTaggerTrainer(
            self.default_tagger, nltkdemo18(), deterministic=True
        )
        tagger = trainer.train(self.corpus, max_rules=30)

        decoded = self._roundtrip(tagger)

        self.assertEqual(
            repr(tagger._initial_tagger), repr(decoded._initial_tagger)
        )
        self.assertEqual(tagger._rules, decoded._rules)
        self.assertEqual(tagger._training_stats, decoded._training_stats)