You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

136 lines
4.2 KiB
Python

# -*- coding: utf-8 -*-
# Natural Language Toolkit: Language Model Unit Tests
#
# Copyright (C) 2001-2019 NLTK Project
# Author: Ilia Kurenkov <ilia.kurenkov@gmail.com>
# URL: <http://nltk.org/>
# For license information, see LICENSE.TXT
import unittest
import six
from nltk import FreqDist
from nltk.lm import NgramCounter
from nltk.util import everygrams
class NgramCounterTests(unittest.TestCase):
    """Tests for NgramCounter that only involve lookup, no modification.

    Two counters are built once for the whole class from the same pair of
    tokenized sentences: one from everygrams of length up to 2 and one from
    everygrams of length up to 3.  Tests must not change these shared
    counters; a test that needs to alter a counter builds its own.
    """

    @classmethod
    def setUpClass(cls):
        """Build the shared bigram and trigram counters from a tiny corpus."""
        # Kept on the class so individual tests can build private counters
        # from the very same data.
        cls.text = [list("abcd"), list("egdbe")]
        cls.trigram_counter = NgramCounter(
            (everygrams(sent, max_len=3) for sent in cls.text)
        )
        cls.bigram_counter = NgramCounter(
            (everygrams(sent, max_len=2) for sent in cls.text)
        )

    def test_N(self):
        """N() reports the total number of ngrams counted, over all orders."""
        # "abcd" gives 4 unigrams + 3 bigrams, "egdbe" gives 5 + 4.
        self.assertEqual(self.bigram_counter.N(), 16)
        # The trigram counter additionally sees 2 + 3 trigrams.
        self.assertEqual(self.trigram_counter.N(), 21)

    def test_counter_len_changes_with_lookup(self):
        """Looking up an order that was never counted grows the counter."""
        # Use a private counter: the lookup below adds an entry, and doing
        # that on the class-level fixture would leak state into other tests.
        counter = NgramCounter(
            (everygrams(sent, max_len=2) for sent in self.text)
        )
        self.assertEqual(len(counter), 2)
        _ = counter[50]
        self.assertEqual(len(counter), 3)

    def test_ngram_order_access_unigrams(self):
        """Indexing with order 1 gives the same object as ``.unigrams``."""
        self.assertEqual(self.bigram_counter[1], self.bigram_counter.unigrams)

    def test_ngram_conditional_freqdist(self):
        """Orders 2 and 3 expose the expected contexts as conditions."""
        expected_trigram_contexts = [
            ("a", "b"),
            ("b", "c"),
            ("e", "g"),
            ("g", "d"),
            ("d", "b"),
        ]
        expected_bigram_contexts = [("a",), ("b",), ("d",), ("e",), ("c",), ("g",)]

        bigrams = self.trigram_counter[2]
        trigrams = self.trigram_counter[3]

        # Order of conditions is irrelevant, only the contents matter.
        six.assertCountEqual(self, expected_bigram_contexts, bigrams.conditions())
        six.assertCountEqual(self, expected_trigram_contexts, trigrams.conditions())

    def test_bigram_counts_seen_ngrams(self):
        """Bigrams present in the corpus are counted once each."""
        b_given_a_count = 1
        c_given_b_count = 1

        self.assertEqual(b_given_a_count, self.bigram_counter[["a"]]["b"])
        self.assertEqual(c_given_b_count, self.bigram_counter[["b"]]["c"])

    def test_bigram_counts_unseen_ngrams(self):
        """A word never seen after a known context has a count of zero."""
        z_given_b_count = 0
        self.assertEqual(z_given_b_count, self.bigram_counter[["b"]]["z"])

    def test_unigram_counts_seen_words(self):
        """Indexing with a plain string looks up the unigram count."""
        # "b" occurs once in each of the two sentences.
        expected_count_b = 2
        self.assertEqual(expected_count_b, self.bigram_counter["b"])

    def test_unigram_counts_completely_unseen_words(self):
        """A word absent from the corpus has a unigram count of zero."""
        unseen_count = 0
        self.assertEqual(unseen_count, self.bigram_counter["z"])
class NgramCounterTrainingTests(unittest.TestCase):
    """Tests for building a NgramCounter from different kinds of input."""

    def setUp(self):
        """Provide a fresh, empty counter for every test."""
        self.counter = NgramCounter()

    def test_empty_string(self):
        """An empty string yields no bigram table and empty unigrams."""
        test = NgramCounter("")
        self.assertNotIn(2, test)
        self.assertEqual(test[1], FreqDist())

    def test_empty_list(self):
        """An empty list yields no bigram table and empty unigrams."""
        test = NgramCounter([])
        self.assertNotIn(2, test)
        self.assertEqual(test[1], FreqDist())

    def test_None(self):
        """Passing None yields no bigram table and empty unigrams."""
        test = NgramCounter(None)
        self.assertNotIn(2, test)
        self.assertEqual(test[1], FreqDist())

    def test_train_on_unigrams(self):
        """A sentence of 1-tuples populates only the unigram counts."""
        words = list("abcd")
        counter = NgramCounter([[(w,) for w in words]])

        self.assertFalse(bool(counter[3]))
        self.assertFalse(bool(counter[2]))
        six.assertCountEqual(self, words, counter[1].keys())

    def test_train_on_illegal_sentences(self):
        """Sentences whose items are not tuples are expected to raise TypeError."""
        # Items are bare strings.
        str_sent = ["Check", "this", "out", "!"]
        # Items are lists rather than tuples.
        list_sent = [["Check", "this"], ["this", "out"], ["out", "!"]]

        with self.assertRaises(TypeError):
            NgramCounter([str_sent])

        with self.assertRaises(TypeError):
            NgramCounter([list_sent])

    def test_train_on_bigrams(self):
        """A sentence of 2-tuples populates only the bigram counts."""
        bigram_sent = [("a", "b"), ("c", "d")]
        counter = NgramCounter([bigram_sent])

        self.assertFalse(bool(counter[3]))
        # Bigrams must not spill into the unigram counts ...
        self.assertFalse(bool(counter[1]))
        # ... and each bigram must be filed under its one-word context.
        six.assertCountEqual(self, [("a",), ("c",)], counter[2].keys())

    def test_train_on_mix(self):
        """Ngrams of mixed length are each filed under their own order."""
        mixed_sent = [("a", "b"), ("c", "d"), ("e", "f", "g"), ("h",)]
        counter = NgramCounter([mixed_sent])

        unigrams = ["h"]
        bigram_contexts = [("a",), ("c",)]
        trigram_contexts = [("e", "f")]

        six.assertCountEqual(self, unigrams, counter[1].keys())
        six.assertCountEqual(self, bigram_contexts, counter[2].keys())
        six.assertCountEqual(self, trigram_contexts, counter[3].keys())