# -*- coding: utf-8 -*-
# Natural Language Toolkit
#
# Copyright (C) 2001-2019 NLTK Project
# Author: Ilia Kurenkov <ilia.kurenkov@gmail.com>
# URL: <http://nltk.org/>
# For license information, see LICENSE.TXT
"""Language Model Vocabulary"""

from __future__ import unicode_literals

import sys
from collections import Counter, Iterable
from itertools import chain

from nltk import compat

try:
    # Python >= 3.4
    from functools import singledispatch
except ImportError:
    # Python < 3.4
    from singledispatch import singledispatch
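

# Lookup is implemented with single dispatch on the type of the first argument:
# the generic function below rejects unsupported types, while the handlers
# registered further down handle iterables and single strings. The explicit
# registration for the string type takes precedence over the Iterable
# registration, even though strings are themselves iterable.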
@singledispatch
def _dispatched_lookup(words, vocab):
    raise TypeError(
        "Unsupported type for looking up in vocabulary: {0}".format(type(words))
    )


@_dispatched_lookup.register(Iterable)
def _(words, vocab):
    """Look up a sequence of words in the vocabulary.

    Returns a tuple of the looked up words.

    """
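    # Each element is dispatched again, so nested sequences are looked up
    # recursively and strings inside the sequence hit the string handler.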
    return tuple(_dispatched_lookup(w, vocab) for w in words)


try:
    # Python 2 unicode + str type
    basestring
except NameError:
    # Python 3 unicode + str type
    basestring = str


@_dispatched_lookup.register(basestring)
def _string_lookup(word, vocab):
    """Looks up one word in the vocabulary."""
    return word if word in vocab else vocab.unk_label


@compat.python_2_unicode_compatible
class Vocabulary(object):
    """Stores language model vocabulary.

    Satisfies two common language modeling requirements for a vocabulary:

    - When checking membership and calculating its size, filters items
      by comparing their counts to a cutoff value.
    - Adds a special "unknown" token which unseen words are mapped to.

    >>> words = ['a', 'c', '-', 'd', 'c', 'a', 'b', 'r', 'a', 'c', 'd']
    >>> from nltk.lm import Vocabulary
    >>> vocab = Vocabulary(words, unk_cutoff=2)

    Tokens with counts greater than or equal to the cutoff value will
    be considered part of the vocabulary.

    >>> vocab['c']
    3
    >>> 'c' in vocab
    True
    >>> vocab['d']
    2
    >>> 'd' in vocab
    True

    Tokens with frequency counts less than the cutoff value will be considered not
    part of the vocabulary even though their entries in the count dictionary are
    preserved.

    >>> vocab['b']
    1
    >>> 'b' in vocab
    False
    >>> vocab['aliens']
    0
    >>> 'aliens' in vocab
    False

    Keeping the count entries for seen words allows us to change the cutoff value
    without having to recalculate the counts.

    >>> vocab2 = Vocabulary(vocab.counts, unk_cutoff=1)
    >>> "b" in vocab2
    True

    The cutoff value influences not only membership checking but also the result of
    getting the size of the vocabulary using the built-in `len`.
    Note that while the number of keys in the vocabulary's counter stays the same,
    the items in the vocabulary differ depending on the cutoff.
    We use `sorted` to demonstrate because it keeps the order consistent.

    >>> sorted(vocab2.counts)
    ['-', 'a', 'b', 'c', 'd', 'r']
    >>> sorted(vocab2)
    ['-', '<UNK>', 'a', 'b', 'c', 'd', 'r']
    >>> sorted(vocab.counts)
    ['-', 'a', 'b', 'c', 'd', 'r']
    >>> sorted(vocab)
    ['<UNK>', 'a', 'c', 'd']

    In addition to items it gets populated with, the vocabulary stores a special
    token that stands in for so-called "unknown" items. By default it's "<UNK>".

    >>> "<UNK>" in vocab
    True

    We can look up words in a vocabulary using its `lookup` method.
    "Unseen" words (with counts less than cutoff) are looked up as the unknown label.
    If given one word (a string) as an input, this method will return a string.

    >>> vocab.lookup("a")
    'a'
    >>> vocab.lookup("aliens")
    '<UNK>'

    If given a sequence, it will return a tuple of the looked up words.

    >>> vocab.lookup(["p", 'a', 'r', 'd', 'b', 'c'])
    ('<UNK>', 'a', '<UNK>', 'd', '<UNK>', 'c')

    It's possible to update the counts after the vocabulary has been created.
    The interface follows that of `collections.Counter`.

    >>> vocab['b']
    1
    >>> vocab.update(["b", "b", "c"])
    >>> vocab['b']
    3

    """

    def __init__(self, counts=None, unk_cutoff=1, unk_label="<UNK>"):
        """Create a new Vocabulary.

        :param counts: Optional iterable or `collections.Counter` instance to
                       pre-seed the Vocabulary. If it is an iterable, counts
                       are calculated from it.
        :param int unk_cutoff: Words that occur less frequently than this value
                               are not considered part of the vocabulary.
        :param unk_label: Label for marking words not part of vocabulary.

        """
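        # A Counter passed in is reused as-is; any other iterable is counted
        # from scratch into a fresh Counter.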
        if isinstance(counts, Counter):
            self.counts = counts
        else:
            self.counts = Counter()
            if isinstance(counts, Iterable):
                self.counts.update(counts)
        self.unk_label = unk_label
        if unk_cutoff < 1:
            raise ValueError(
                "Cutoff value cannot be less than 1. Got: {0}".format(unk_cutoff)
            )
        self._cutoff = unk_cutoff

    @property
    def cutoff(self):
        """Cutoff value.

        Items with count below this value are not considered part of the vocabulary.

        """
        return self._cutoff

    def update(self, *counter_args, **counter_kwargs):
        """Update vocabulary counts.

        Wraps the `collections.Counter.update` method.

        """
        self.counts.update(*counter_args, **counter_kwargs)

    def lookup(self, words):
        """Look up one or more words in the vocabulary.

        If passed one word as a string, will return that word or `self.unk_label`.
        Otherwise will assume it was passed a sequence of words, will try to look
        each of them up and return a tuple of the looked up words.

        :param words: Word(s) to look up.
        :type words: Iterable(str) or str
        :rtype: tuple(str) or str
        :raises: TypeError for types other than strings or iterables

        >>> from nltk.lm import Vocabulary
        >>> vocab = Vocabulary(["a", "b", "c", "a", "b"], unk_cutoff=2)
        >>> vocab.lookup("a")
        'a'
        >>> vocab.lookup("aliens")
        '<UNK>'
        >>> vocab.lookup(["a", "b", "c", ["x", "b"]])
        ('a', 'b', '<UNK>', ('<UNK>', 'b'))

        """
        return _dispatched_lookup(words, self)

    def __getitem__(self, item):
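        # The unknown label is treated as having a count equal to the cutoff,
        # so it always passes the membership check in __contains__.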
        return self._cutoff if item == self.unk_label else self.counts[item]

    def __contains__(self, item):
        """Only consider items with counts greater than or equal to the cutoff
        as being in the vocabulary."""
        return self[item] >= self.cutoff

    def __iter__(self):
        """Building on the membership check, defines how to iterate over the
        vocabulary."""
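        # Yield items that meet the cutoff, then the unknown label, but only if
        # the vocabulary is not empty.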
        return chain(
            (item for item in self.counts if item in self),
            [self.unk_label] if self.counts else [],
        )

    def __len__(self):
        """Computing the size of the vocabulary reflects the cutoff."""
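        # Count the items yielded by __iter__, i.e. items that meet the cutoff
        # plus the unknown label.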
        return sum(1 for _ in self)

    def __eq__(self, other):
        return (
            self.unk_label == other.unk_label
            and self.cutoff == other.cutoff
            and self.counts == other.counts
        )

    if sys.version_info[0] == 2:
        # see https://stackoverflow.com/a/35781654/4501212
        def __ne__(self, other):
            equal = self.__eq__(other)
            return equal if equal is NotImplemented else not equal

    def __str__(self):
        return "<{0} with cutoff={1} unk_label='{2}' and {3} items>".format(
            self.__class__.__name__, self.cutoff, self.unk_label, len(self)
        )
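

# Minimal usage sketch (illustrative only): build a vocabulary from raw tokens,
# then demonstrate lookup, membership checking and the cutoff-aware size.
if __name__ == "__main__":
    tokens = ["a", "c", "-", "d", "c", "a", "b", "r", "a", "c", "d"]
    demo_vocab = Vocabulary(tokens, unk_cutoff=2)
    print(demo_vocab)                          # <Vocabulary with cutoff=2 ...>
    print(demo_vocab.lookup(["a", "b", "z"]))  # ('a', '<UNK>', '<UNK>')
    print("b" in demo_vocab, len(demo_vocab))  # False 4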