# -*- coding: utf-8 -*-
# Natural Language Toolkit: Language Model Unit Tests
#
# Copyright (C) 2001-2019 NLTK Project
# Author: Ilia Kurenkov
# URL: <http://nltk.org/>
# For license information, see LICENSE.TXT
"""Unit tests for :class:`nltk.lm.Vocabulary`."""

import unittest
from collections import Counter

import six
from nltk.lm import Vocabulary


class NgramModelVocabularyTests(unittest.TestCase):
    """Tests for the Vocabulary class.

    The shared fixture is built from 13 tokens with ``unk_cutoff=2``.
    Only "a", "b", "d" and "e" occur twice. "z", "c", "f", "g" and "w"
    occur once, so they fall below the cutoff and are mapped to the
    unknown label ``"<UNK>"``.
    """

    @classmethod
    def setUpClass(cls):
        cls.vocab = Vocabulary(
            ["z", "a", "b", "c", "f", "d", "e", "g", "a", "d", "b", "e", "w"],
            unk_cutoff=2,
        )

    def test_truthiness(self):
        self.assertTrue(self.vocab)

    def test_cutoff_value_set_correctly(self):
        self.assertEqual(self.vocab.cutoff, 2)

    def test_unable_to_change_cutoff(self):
        # cutoff is read-only once the vocabulary has been built.
        with self.assertRaises(AttributeError):
            self.vocab.cutoff = 3

    def test_cutoff_setter_checks_value(self):
        with self.assertRaises(ValueError) as exc_info:
            Vocabulary("abc", unk_cutoff=0)
        expected_error_msg = "Cutoff value cannot be less than 1. Got: 0"
        self.assertEqual(expected_error_msg, str(exc_info.exception))

    def test_counts_set_correctly(self):
        self.assertEqual(self.vocab.counts["a"], 2)
        self.assertEqual(self.vocab.counts["b"], 2)
        self.assertEqual(self.vocab.counts["c"], 1)

    def test_membership_check_respects_cutoff(self):
        # "a" was seen 2 times, so it should be considered part of the vocabulary
        self.assertTrue("a" in self.vocab)
        # "c" was seen once, it shouldn't be considered part of the vocab
        self.assertFalse("c" in self.vocab)
        # "z" was also seen only once (it is the first fixture token)
        self.assertFalse("z" in self.vocab)
        # "x" was never seen at all, also shouldn't be considered in the vocab
        self.assertFalse("x" in self.vocab)

    def test_vocab_len_respects_cutoff(self):
        # Vocab size is the number of unique tokens that occur at least as often
        # as the cutoff value, plus 1 to account for unknown words.
        self.assertEqual(5, len(self.vocab))

    def test_vocab_iter_respects_cutoff(self):
        # Raw counts keep every token; iteration yields only tokens that meet
        # the cutoff, plus the unknown label.
        vocab_counts = ["a", "b", "c", "d", "e", "f", "g", "w", "z"]
        vocab_items = ["a", "b", "d", "e", "<UNK>"]

        six.assertCountEqual(self, vocab_counts, list(self.vocab.counts.keys()))
        six.assertCountEqual(self, vocab_items, list(self.vocab))

    def test_update_empty_vocab(self):
        empty = Vocabulary(unk_cutoff=2)
        self.assertEqual(len(empty), 0)
        self.assertFalse(empty)
        # The unknown label is always a member, even of an empty vocabulary.
        self.assertIn(empty.unk_label, empty)

        empty.update(list("abcde"))
        self.assertIn(empty.unk_label, empty)

    def test_lookup(self):
        self.assertEqual(self.vocab.lookup("a"), "a")
        self.assertEqual(self.vocab.lookup("c"), "<UNK>")

    def test_lookup_iterables(self):
        # Any iterable of words is looked up element-wise and returned as a tuple.
        self.assertEqual(self.vocab.lookup(["a", "b"]), ("a", "b"))
        self.assertEqual(self.vocab.lookup(("a", "b")), ("a", "b"))
        self.assertEqual(self.vocab.lookup(("a", "c")), ("a", "<UNK>"))
        self.assertEqual(
            self.vocab.lookup(map(str, range(3))), ("<UNK>", "<UNK>", "<UNK>")
        )

    def test_lookup_empty_iterables(self):
        self.assertEqual(self.vocab.lookup(()), ())
        self.assertEqual(self.vocab.lookup([]), ())
        self.assertEqual(self.vocab.lookup(iter([])), ())
        self.assertEqual(self.vocab.lookup(n for n in range(0, 0)), ())

    def test_lookup_recursive(self):
        # Nested iterables are looked up recursively, preserving nesting as tuples.
        self.assertEqual(
            self.vocab.lookup([["a", "b"], ["a", "c"]]), (("a", "b"), ("a", "<UNK>"))
        )
        self.assertEqual(self.vocab.lookup([["a", "b"], "c"]), (("a", "b"), "<UNK>"))
        self.assertEqual(self.vocab.lookup([[[[["a", "b"]]]]]), ((((("a", "b"),),),),))

    def test_lookup_None(self):
        with self.assertRaises(TypeError):
            self.vocab.lookup(None)
        with self.assertRaises(TypeError):
            list(self.vocab.lookup([None, None]))

    def test_lookup_int(self):
        with self.assertRaises(TypeError):
            self.vocab.lookup(1)
        with self.assertRaises(TypeError):
            list(self.vocab.lookup([1, 2]))

    def test_lookup_empty_str(self):
        # The empty string is an ordinary, unseen word, so it maps to the unk label.
        self.assertEqual(self.vocab.lookup(""), "<UNK>")

    def test_eqality(self):
        # Equality takes counts, cutoff and unk_label into account.
        v1 = Vocabulary(["a", "b", "c"], unk_cutoff=1)
        v2 = Vocabulary(["a", "b", "c"], unk_cutoff=1)
        v3 = Vocabulary(["a", "b", "c"], unk_cutoff=1, unk_label="blah")
        v4 = Vocabulary(["a", "b"], unk_cutoff=1)

        self.assertEqual(v1, v2)
        self.assertNotEqual(v1, v3)
        self.assertNotEqual(v1, v4)

    def test_str(self):
        self.assertEqual(
            str(self.vocab),
            "<Vocabulary with cutoff=2 unk_label='<UNK>' and 5 items>",
        )

    def test_creation_with_counter(self):
        # Building from a pre-computed Counter must equal building from tokens.
        self.assertEqual(
            self.vocab,
            Vocabulary(
                Counter(
                    ["z", "a", "b", "c", "f", "d", "e", "g", "a", "d", "b", "e", "w"]
                ),
                unk_cutoff=2,
            ),
        )