You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
253 lines
8.3 KiB
Python
253 lines
8.3 KiB
Python
# -*- coding: utf-8 -*-
|
|
# Natural Language Toolkit: Language Models
|
|
#
|
|
# Copyright (C) 2001-2019 NLTK Project
|
|
# Authors: Ilia Kurenkov <ilia.kurenkov@gmail.com>
|
|
# URL: <http://nltk.org/>
|
|
# For license information, see LICENSE.TXT
|
|
"""Language Model Interface."""
|
|
from __future__ import division, unicode_literals
|
|
|
|
import random
|
|
from abc import ABCMeta, abstractmethod
|
|
from bisect import bisect
|
|
|
|
from six import add_metaclass
|
|
|
|
from nltk.lm.counter import NgramCounter
|
|
from nltk.lm.util import log_base2
|
|
from nltk.lm.vocabulary import Vocabulary
|
|
|
|
try:
    # Prefer the standard-library implementation when it exists.
    from itertools import accumulate
except ImportError:
    import operator

    def accumulate(iterable, func=operator.add):
        """Yield the running results of folding ``func`` over ``iterable``.

        accumulate([1, 2, 3, 4, 5]) yields 1, 3, 6, 10, 15 and
        accumulate([1, 2, 3, 4, 5], operator.mul) yields 1, 2, 6, 24, 120.
        An empty input yields nothing.
        """
        missing = object()
        items = iter(iterable)
        running = next(items, missing)
        if running is missing:
            # Empty input: nothing to accumulate.
            return
        yield running
        for item in items:
            running = func(running, item)
            yield running
|
|
|
|
|
|
@add_metaclass(ABCMeta)
class Smoothing(object):
    """Ngram Smoothing Interface

    Implements Chen & Goodman 1995's idea that all smoothing algorithms have
    certain features in common. This should ideally allow smoothing algorithms to
    work both with Backoff and Interpolation.
    """

    def __init__(self, vocabulary, counter):
        """
        :param vocabulary: The Ngram vocabulary object.
        :type vocabulary: nltk.lm.vocabulary.Vocabulary
        :param counter: The counts of the vocabulary items.
        :type counter: nltk.lm.counter.NgramCounter
        """
        self.vocab = vocabulary
        self.counts = counter

    @abstractmethod
    def unigram_score(self, word):
        """Score ``word`` without any context.

        Abstract: concrete smoothing classes must override this.
        """
        raise NotImplementedError()

    @abstractmethod
    def alpha_gamma(self, word, context):
        """Return the smoothing quantities for ``word`` given ``context``.

        Abstract: concrete smoothing classes must override this.
        NOTE(review): presumably the factors used to mix this order's estimate
        with a lower-order one (backoff/interpolation); confirm the exact
        return shape against the concrete subclasses.
        """
        raise NotImplementedError()
|
|
|
|
|
|
def _mean(items):
    """Compute the arithmetic mean of a sized sequence of numbers.

    An empty sequence raises ZeroDivisionError.
    """
    total = sum(items)
    count = len(items)
    return total / count
|
|
|
|
|
|
def _random_generator(seed_or_generator):
    """Coerce a seed or an existing generator into a ``random.Random``.

    A ``random.Random`` instance (including subclasses) is handed back
    unchanged, so callers can share one generator's state. Anything else is
    passed to ``random.Random`` as its seed.
    """
    if not isinstance(seed_or_generator, random.Random):
        seed_or_generator = random.Random(seed_or_generator)
    return seed_or_generator
|
|
|
|
|
|
def _weighted_choice(population, weights, random_generator=None):
    """Like random.choice, but with weights.

    Heavily inspired by python 3.6 `random.choices`.

    :param population: Non-empty sequence of candidates to choose from.
    :param weights: Non-negative weights, one per element of ``population``.
    :param random_generator: Object with a ``random()`` method returning a
        float in [0, 1), e.g. a ``random.Random`` instance. If None, the
        module-level ``random.random`` function is used.
    :return: One element of ``population``.
    :raises ValueError: If ``population`` is empty, if the number of weights
        does not match the population, or if the weights do not sum to a
        positive number.
    """
    if not population:
        raise ValueError("Can't choose from empty population")
    if len(population) != len(weights):
        raise ValueError("The number of weights does not match the population")
    cum_weights = list(accumulate(weights))
    total = cum_weights[-1]
    # Written as ``not total > 0`` so a NaN total is rejected as well.
    if not total > 0:
        raise ValueError("Total of weights must be greater than zero")
    # The default used to be dereferenced unconditionally (AttributeError).
    draw = random.random if random_generator is None else random_generator.random
    threshold = draw()
    # For float weights ``total * threshold`` can round up to ``total``, which
    # would make bisect return len(population). Capping the search at the
    # last index mirrors what random.choices does.
    last_index = len(cum_weights) - 1
    return population[bisect(cum_weights, total * threshold, 0, last_index)]
|
|
|
|
|
|
@add_metaclass(ABCMeta)
class LanguageModel(object):
    """ABC for Language Models.

    Cannot be directly instantiated itself.

    """

    def __init__(self, order, vocabulary=None, counter=None):
        """Creates new LanguageModel.

        :param int order: The model's ngram order. `generate` conditions on at
            most the last ``order - 1`` words of its seed.
        :param vocabulary: If provided, this vocabulary will be used instead
            of creating a new one when training.
        :type vocabulary: `nltk.lm.Vocabulary` or None
        :param counter: If provided, use this object to count ngrams.
        :type counter: `nltk.lm.NgramCounter` or None

        """
        self.order = order
        self.vocab = Vocabulary() if vocabulary is None else vocabulary
        self.counts = NgramCounter() if counter is None else counter

    def fit(self, text, vocabulary_text=None):
        """Trains the model on a text.

        :param text: Training text as a sequence of sentences, each sentence
            given as a sequence of ngram tuples.
        :param vocabulary_text: Text used to populate the vocabulary when the
            model's vocabulary is empty (falsy); ignored otherwise.
        :raises ValueError: If the vocabulary is empty and no
            ``vocabulary_text`` is given.

        """
        if not self.vocab:
            if vocabulary_text is None:
                raise ValueError(
                    "Cannot fit without a vocabulary or text to " "create it from."
                )
            self.vocab.update(vocabulary_text)
        # Mask each sentence through the vocabulary before counting its ngrams.
        self.counts.update(self.vocab.lookup(sent) for sent in text)

    def score(self, word, context=None):
        """Masks out of vocab (OOV) words and computes their model score.

        For model-specific logic of calculating scores, see the `unmasked_score`
        method.

        :param str word: Word to score.
        :param context: Preceding words; a falsy value means no context.
        :type context: tuple(str) or None
        """
        return self.unmasked_score(
            self.vocab.lookup(word), self.vocab.lookup(context) if context else None
        )

    @abstractmethod
    def unmasked_score(self, word, context=None):
        """Score a word given some optional context.

        Concrete models are expected to provide an implementation.
        Note that this method does not mask its arguments with the OOV label.
        Use the `score` method for that.

        :param str word: Word for which we want the score
        :param tuple(str) context: Context the word is in.
            If `None`, compute unigram score.
        :type context: tuple(str) or None
        :rtype: float

        """
        raise NotImplementedError()

    def logscore(self, word, context=None):
        """Evaluate the log score of this word in this context.

        The arguments are the same as for `score` and `unmasked_score`.

        :return: ``log_base2`` of ``score(word, context)``.
        """
        return log_base2(self.score(word, context))

    def context_counts(self, context):
        """Helper method for retrieving counts for a given context.

        Assumes context has been checked and oov words in it masked.

        :type context: tuple(str) or None
        :return: The counts stored for ``context`` among the
            ``len(context) + 1``-grams, or the unigram counts when
            ``context`` is empty or None.

        """
        return (
            self.counts[len(context) + 1][context] if context else self.counts.unigrams
        )

    def entropy(self, text_ngrams):
        """Calculate cross-entropy of model for given evaluation text.

        Each ngram is scored as its last word given the preceding words.
        An empty ``text_ngrams`` raises ZeroDivisionError (mean of nothing).

        :param Iterable(tuple(str)) text_ngrams: A sequence of ngram tuples.
        :rtype: float

        """
        return -1 * _mean(
            [self.logscore(ngram[-1], ngram[:-1]) for ngram in text_ngrams]
        )

    def perplexity(self, text_ngrams):
        """Calculates the perplexity of the given text.

        This is simply 2 ** cross-entropy for the text, so the arguments are the same.

        """
        return pow(2.0, self.entropy(text_ngrams))

    def generate(self, num_words=1, text_seed=None, random_seed=None):
        """Generate words from the model.

        :param int num_words: How many words to generate. By default 1.
        :param text_seed: Generation can be conditioned on preceding context.
            At most its last ``order - 1`` words are used.
        :param random_seed: A random seed or an instance of `random.Random`. If provided,
            makes the random sampling part of generation reproducible.
        :return: One (str) word or a list of words generated from model.

        Examples:

        >>> from nltk.lm import MLE
        >>> lm = MLE(2)
        >>> lm.fit([[("a", "b"), ("b", "c")]], vocabulary_text=['a', 'b', 'c'])
        >>> lm.fit([[("a",), ("b",), ("c",)]])
        >>> lm.generate(random_seed=3)
        'a'
        >>> lm.generate(text_seed=['a'])
        'b'

        """
        text_seed = [] if text_seed is None else list(text_seed)
        random_generator = _random_generator(random_seed)
        # base recursion case
        if num_words == 1:
            # Condition on at most the last ``order - 1`` words. Clamping the
            # start at 0 keeps shorter seeds whole. For a unigram model
            # (order 1) this gives an empty context; the old slice
            # ``text_seed[-self.order + 1:]`` was ``text_seed[0:]`` there,
            # i.e. the entire seed.
            start = max(0, len(text_seed) - (self.order - 1))
            context = text_seed[start:]
            samples = self.context_counts(self.vocab.lookup(context))
            # Back off by dropping the oldest context word until the context
            # has been seen (or is empty, which means unigram counts).
            while context and not samples:
                context = context[1:] if len(context) > 1 else []
                samples = self.context_counts(self.vocab.lookup(context))
            # sorting achieves two things:
            # - reproducible randomness when sampling
            # - turning Mapping into Sequence which _weighted_choice expects
            samples = sorted(samples)
            return _weighted_choice(
                samples, tuple(self.score(w, context) for w in samples), random_generator
            )
        # build up text one word at a time
        generated = []
        for _ in range(num_words):
            generated.append(
                self.generate(
                    num_words=1,
                    text_seed=text_seed + generated,
                    random_seed=random_generator,
                )
            )
        return generated
|