# Example: training and evaluating a Naive Bayes spam classifier
# with the Pattern library (pattern.vector).
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import division
from builtins import str, bytes, dict, int
import os
import sys
sys.path.insert(0, os.path.join("..", ".."))
from pattern.vector import Document, Model, NB
from pattern.db import Datasheet
# Naive Bayes is one of the oldest classifiers,
# but it is still popular because it is fast for models
# that have many documents and many features.
# It is outperformed by KNN and SVM, but useful as a baseline for tests.
# We'll test it with a corpus of spam e-mail messages,
# included in the test suite, stored as a CSV-file.
# The corpus contains mostly technical e-mail from developer mailing lists.
data = os.path.join(os.path.dirname(__file__), "..", "..", "test", "corpora", "spam-apache.csv")
data = Datasheet.load(data)
documents = []
# Each CSV row is (score, message); a score > 0 marks a real message,
# so the document type becomes True (real) or False (spam).
for score, message in data:
    document = Document(message, type=int(score) > 0)
    documents.append(document)
m = Model(documents)
print("number of documents:", len(m))
print("number of words:", len(m.vector))
print("number of words (average):", sum(len(d.features) for d in m.documents) / float(len(m)))
print()
# Train Naive Bayes on all documents.
# Each document has a type: True for actual e-mail, False for spam.
# This results in a "binary" classifier that either answers True or False
# for unknown documents.
classifier = NB()
for document in m:
    classifier.train(document)
# We can now ask it questions about unknown e-mails:
print(classifier.classify("win money"))  # False: most likely spam.
print(classifier.classify("fix bug"))  # True: most likely a real message.
print()
print(classifier.classify("customer"))  # False: people don't talk like this on developer lists...
print(classifier.classify("guys"))  # True: because most likely everyone knows everyone.
print()
# To test the accuracy of a classifier,
# we typically use 10-fold cross validation.
# This means that 10 individual tests are performed,
# each with 90% of the corpus as training data and 10% as testing data.
from pattern.vector import k_fold_cv
scores = k_fold_cv(NB, documents=m, folds=10)
print(scores)
# This yields 5 scores: (Accuracy, Precision, Recall, F-score, standard deviation).
# Accuracy in itself is not very useful,
# since some spam may have been regarded as real messages (false positives),
# and some real messages may have been regarded as spam (false negatives).
# Precision = how accurately false positives are discarded,
# Recall = how accurately false negatives are discarded.
# F-score = harmonic mean of precision and recall.
# stdev = folds' variation from average F-score.