# -*- coding: utf-8 -*-
# Natural Language Toolkit: Language Model Unit Tests
#
# Copyright (C) 2001-2019 NLTK Project
# Author: Ilia Kurenkov
# URL:
# For license information, see LICENSE.TXT

import unittest

import six

from nltk import FreqDist
from nltk.lm import NgramCounter
from nltk.util import everygrams


class NgramCounterTests(unittest.TestCase):
    """Tests for NgramCounter that only involve lookup, no modification."""

    @classmethod
    def setUpClass(cls):
        # Two tiny "sentences" of single-character tokens; the counters are
        # built once and shared by every test in this class.
        text = [list("abcd"), list("egdbe")]
        cls.trigram_counter = NgramCounter(
            (everygrams(sent, max_len=3) for sent in text)
        )
        cls.bigram_counter = NgramCounter(
            (everygrams(sent, max_len=2) for sent in text)
        )

    def test_N(self):
        # N() is compared against the number of ngrams fed to each counter.
        self.assertEqual(self.bigram_counter.N(), 16)
        self.assertEqual(self.trigram_counter.N(), 21)

    def test_counter_len_changes_with_lookup(self):
        # Looking up an order that was never counted grows the counter, so
        # this test uses its own counter: the shared class-level fixture is
        # left untouched and the test can be re-run within one class setup.
        counter = NgramCounter(
            (everygrams(sent, max_len=2) for sent in [list("abcd"), list("egdbe")])
        )
        self.assertEqual(len(counter), 2)
        _ = counter[50]
        self.assertEqual(len(counter), 3)

    def test_ngram_order_access_unigrams(self):
        # Indexing with order 1 must give the same counts as `.unigrams`.
        self.assertEqual(self.bigram_counter[1], self.bigram_counter.unigrams)

    def test_ngram_conditional_freqdist(self):
        expected_trigram_contexts = [
            ("a", "b"),
            ("b", "c"),
            ("e", "g"),
            ("g", "d"),
            ("d", "b"),
        ]
        expected_bigram_contexts = [("a",), ("b",), ("d",), ("e",), ("c",), ("g",)]

        bigrams = self.trigram_counter[2]
        trigrams = self.trigram_counter[3]

        # Order of the conditions is irrelevant, only their multiset matters.
        six.assertCountEqual(self, expected_bigram_contexts, bigrams.conditions())
        six.assertCountEqual(self, expected_trigram_contexts, trigrams.conditions())

    def test_bigram_counts_seen_ngrams(self):
        b_given_a_count = 1
        c_given_b_count = 1

        self.assertEqual(b_given_a_count, self.bigram_counter[["a"]]["b"])
        self.assertEqual(c_given_b_count, self.bigram_counter[["b"]]["c"])

    def test_bigram_counts_unseen_ngrams(self):
        z_given_b_count = 0

        self.assertEqual(z_given_b_count, self.bigram_counter[["b"]]["z"])

    def test_unigram_counts_seen_words(self):
        # "b" occurs once in each of the two training sentences.
        expected_count_b = 2

        self.assertEqual(expected_count_b, self.bigram_counter["b"])

    def test_unigram_counts_completely_unseen_words(self):
        unseen_count = 0

        self.assertEqual(unseen_count, self.bigram_counter["z"])


class NgramCounterTrainingTests(unittest.TestCase):
    """Tests for how NgramCounter is populated from different inputs."""

    def setUp(self):
        self.counter = NgramCounter()

    def test_empty_string(self):
        test = NgramCounter("")
        self.assertNotIn(2, test)
        self.assertEqual(test[1], FreqDist())

    def test_empty_list(self):
        test = NgramCounter([])
        self.assertNotIn(2, test)
        self.assertEqual(test[1], FreqDist())

    def test_None(self):
        test = NgramCounter(None)
        self.assertNotIn(2, test)
        self.assertEqual(test[1], FreqDist())

    def test_train_on_unigrams(self):
        words = list("abcd")
        counter = NgramCounter([[(w,) for w in words]])

        self.assertFalse(bool(counter[3]))
        self.assertFalse(bool(counter[2]))
        six.assertCountEqual(self, words, counter[1].keys())

    def test_train_on_illegal_sentences(self):
        # Sentences must be sequences of tuples: bare strings and lists of
        # tokens are both expected to be rejected with TypeError.
        str_sent = ["Check", "this", "out", "!"]
        list_sent = [["Check", "this"], ["this", "out"], ["out", "!"]]

        with self.assertRaises(TypeError):
            NgramCounter([str_sent])

        with self.assertRaises(TypeError):
            NgramCounter([list_sent])

    def test_train_on_bigrams(self):
        bigram_sent = [("a", "b"), ("c", "d")]
        counter = NgramCounter([bigram_sent])
        bigram_contexts = [("a",), ("c",)]

        # Only the bigram order should have been populated.
        self.assertFalse(bool(counter[3]))
        self.assertFalse(bool(counter[1]))
        six.assertCountEqual(self, bigram_contexts, counter[2].keys())

    def test_train_on_mix(self):
        mixed_sent = [("a", "b"), ("c", "d"), ("e", "f", "g"), ("h",)]
        counter = NgramCounter([mixed_sent])
        unigrams = ["h"]
        bigram_contexts = [("a",), ("c",)]
        trigram_contexts = [("e", "f")]

        # Each ngram must land in the table matching its own length.
        six.assertCountEqual(self, unigrams, counter[1].keys())
        six.assertCountEqual(self, bigram_contexts, counter[2].keys())
        six.assertCountEqual(self, trigram_contexts, counter[3].keys())