You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

381 lines
12 KiB
Python

5 years ago
# Natural Language Toolkit: Interface to Weka Classifiers
#
# Copyright (C) 2001-2020 NLTK Project
5 years ago
# Author: Edward Loper <edloper@gmail.com>
# URL: <http://nltk.org/>
# For license information, see LICENSE.TXT
"""
Classifiers that make use of the external 'Weka' package.
"""
5 years ago
import time
import tempfile
import os
import subprocess
import re
import zipfile
from sys import stdin
from nltk.probability import DictionaryProbDist
from nltk.internals import java, config_java
from nltk.classify.api import ClassifierI
# Directories probed, in this order, for a ``weka.jar`` file when no
# classpath has been configured explicitly (see ``config_weka``).
_weka_search = [
    ".",
    "/usr/share/weka",
    "/usr/local/share/weka",
    "/usr/lib/weka",
    "/usr/local/lib/weka",
]

# Location of ``weka.jar``; stays None until ``config_weka`` sets it.
_weka_classpath = None
def config_weka(classpath=None):
    """
    Locate ``weka.jar`` and record its path in the module-level
    ``_weka_classpath`` used when Weka is invoked.

    :param classpath: Explicit classpath to use.  When given, it replaces
        any previously configured or discovered value.
    :raise LookupError: If no classpath is configured and ``weka.jar``
        cannot be found in ``$WEKAHOME`` or any directory listed in
        ``_weka_search``.
    """
    global _weka_classpath

    # Make sure java's configured first.
    config_java()

    if classpath is not None:
        _weka_classpath = classpath

    if _weka_classpath is None:
        # Search a copy: inserting WEKAHOME into the module-level list would
        # make it grow by one entry on every unsuccessful call.
        searchpath = list(_weka_search)
        if "WEKAHOME" in os.environ:
            searchpath.insert(0, os.environ["WEKAHOME"])

        for path in searchpath:
            candidate = os.path.join(path, "weka.jar")
            if not os.path.exists(candidate):
                continue

            _weka_classpath = candidate
            version = _check_weka_version(candidate)
            # The version may arrive as raw bytes read from the jar; turn it
            # into text so the message does not show a bytes repr.
            if isinstance(version, bytes):
                version = version.decode("utf-8", "replace").strip()
            if version:
                print("[Found Weka: %s (version %s)]" % (candidate, version))
            else:
                print("[Found Weka: %s]" % candidate)
            # Stop at the first hit so that WEKAHOME (searched first) takes
            # precedence over the system-wide default locations.
            break

    if _weka_classpath is None:
        raise LookupError(
            "Unable to find weka.jar!  Use config_weka() "
            "or set the WEKAHOME environment variable. "
            "For more information about Weka, please see "
            "http://www.cs.waikato.ac.nz/ml/weka/"
        )
def _check_weka_version(jar):
    """
    Return the Weka version recorded inside a jar file, if any.

    :param jar: Path to the jar (zip) archive to inspect.
    :return: The stripped text of ``weka/core/version.txt`` inside the
        archive, or None when the archive cannot be opened, has no such
        member, or the member is empty.
    """
    try:
        archive = zipfile.ZipFile(jar)
    except Exception:
        # Best-effort probe: a missing, unreadable or corrupt jar simply
        # means "version unknown".  SystemExit and KeyboardInterrupt are
        # not Exception subclasses, so they still propagate.
        return None

    with archive:
        try:
            raw = archive.read("weka/core/version.txt")
        except KeyError:
            # The archive has no version file.
            return None

    # ZipFile.read returns bytes (with the file's trailing newline); callers
    # want printable text.
    return raw.decode("utf-8", "replace").strip() or None
class WekaClassifier(ClassifierI):
    """
    Classifier backed by a model file built with the external Weka package.

    Feature sets are written to a temporary ARFF file using the formatter
    supplied at construction time, Weka is run through ``java``, and its
    textual prediction output is parsed back into labels or probability
    distributions.
    """

    def __init__(self, formatter, model_filename):
        """
        :param formatter: Object used to serialize feature sets as ARFF;
            its ``write`` and ``labels`` methods are used by this class.
        :param model_filename: Path of the Weka model file, passed to Weka
            with the ``-l`` option when classifying.
        """
        self._formatter = formatter
        self._model = model_filename

    def prob_classify_many(self, featuresets):
        """Return one probability distribution per feature set."""
        return self._classify_many(featuresets, ["-p", "0", "-distribution"])

    def classify_many(self, featuresets):
        """Return one predicted label per feature set."""
        return self._classify_many(featuresets, ["-p", "0"])

    @staticmethod
    def _to_text(data, errors="strict"):
        """Decode process output given as bytes; pass text through as-is."""
        if isinstance(data, bytes):
            # stdin may be missing or have no encoding (e.g. when it is not
            # attached to a terminal), so fall back to UTF-8.
            encoding = getattr(stdin, "encoding", None) or "utf-8"
            return data.decode(encoding, errors)
        return data

    def _classify_many(self, featuresets, options):
        """
        Run Weka on ``featuresets`` and return the parsed predictions.

        :param featuresets: The feature sets to classify.
        :param options: List of extra Weka command-line options selecting
            the output format.
        :raise ValueError: If Weka writes to stderr without producing any
            output, or its output format is not recognized.
        """
        # Make sure we can find java & weka.
        config_weka()

        # The context manager removes the directory and everything in it on
        # exit, including when an exception is raised.
        with tempfile.TemporaryDirectory() as temp_dir:
            # Write the test data file.
            test_filename = os.path.join(temp_dir, "test.arff")
            self._formatter.write(test_filename, featuresets)

            # Call weka to classify the data.
            # NOTE(review): NaiveBayes is always the class invoked here,
            # whatever classifier train() used; presumably Weka takes the
            # actual classifier from the "-l" model file -- confirm.
            cmd = [
                "weka.classifiers.bayes.NaiveBayes",
                "-l",
                self._model,
                "-T",
                test_filename,
            ] + options
            (stdout, stderr) = java(
                cmd,
                classpath=_weka_classpath,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )

            # Check if something went wrong:
            if stderr and not stdout:
                # Decode first: testing a str for membership in bytes raises
                # TypeError on Python 3.
                message = self._to_text(stderr, errors="replace")
                if "Illegal options: -distribution" in message:
                    raise ValueError(
                        "The installed version of weka does "
                        "not support probability distribution "
                        "output."
                    )
                raise ValueError("Weka failed to generate output:\n%s" % message)

            # Parse weka's output.
            return self.parse_weka_output(self._to_text(stdout).split("\n"))

    def parse_weka_distribution(self, s):
        """
        Convert one distribution field of Weka's output into a
        ``DictionaryProbDist`` keyed by the formatter's labels.

        :param s: Probabilities separated by ``,`` and/or ``*``.
        """
        probs = [float(v) for v in re.split("[*,]+", s) if v.strip()]
        # Probabilities are paired with labels positionally.
        probs = dict(zip(self._formatter.labels(), probs))
        return DictionaryProbDist(probs)

    def parse_weka_output(self, lines):
        """
        Parse the lines of Weka's prediction output.

        :param lines: Weka's standard output, split into lines.
        :return: A list of labels, or a list of probability distributions
            when the output header announces a ``distribution`` column.
        :raise ValueError: If there is nothing to parse or the format is
            not recognized.
        """
        # Strip unwanted text from stdout
        for i, line in enumerate(lines):
            if line.strip().startswith("inst#"):
                lines = lines[i:]
                break

        if not lines:
            # Without this guard the header lookup below raises IndexError.
            raise ValueError("Weka produced no output to parse.")

        header = lines[0].split()
        if header == ["inst#", "actual", "predicted", "error", "prediction"]:
            # The "predicted" column looks like "<index>:<label>"; split only
            # on the first colon so labels containing ':' stay intact.
            return [
                line.split()[2].split(":", 1)[1] for line in lines[1:] if line.strip()
            ]
        elif header == ["inst#", "actual", "predicted", "error", "distribution"]:
            return [
                self.parse_weka_distribution(line.split()[-1])
                for line in lines[1:]
                if line.strip()
            ]
        # is this safe:?
        elif re.match(r"^0 \w+ [01]\.[0-9]* \?\s*$", lines[0]):
            return [line.split()[1] for line in lines if line.strip()]
        else:
            # Show the start of the output to help diagnose the failure.
            for line in lines[:10]:
                print(line)
            raise ValueError(
                "Unhandled output format -- your version "
                "of weka may not be supported.\n"
                "  Header: %s" % lines[0]
            )

    # [xx] full list of classifiers (some may be abstract?):
    # ADTree, AODE, BayesNet, ComplementNaiveBayes, ConjunctiveRule,
    # DecisionStump, DecisionTable, HyperPipes, IB1, IBk, Id3, J48,
    # JRip, KStar, LBR, LeastMedSq, LinearRegression, LMT, Logistic,
    # LogisticBase, M5Base, MultilayerPerceptron,
    # MultipleClassifiersCombiner, NaiveBayes, NaiveBayesMultinomial,
    # NaiveBayesSimple, NBTree, NNge, OneR, PaceRegression, PART,
    # PreConstructedLinearModel, Prism, RandomForest,
    # RandomizableClassifier, RandomTree, RBFNetwork, REPTree, Ridor,
    # RuleNode, SimpleLinearRegression, SimpleLogistic,
    # SingleClassifierEnhancer, SMO, SMOreg, UserClassifier, VFI,
    # VotedPerceptron, Winnow, ZeroR

    # Short names accepted by train(), mapped to Weka's Java class names.
    _CLASSIFIER_CLASS = {
        "naivebayes": "weka.classifiers.bayes.NaiveBayes",
        "C4.5": "weka.classifiers.trees.J48",
        "log_regression": "weka.classifiers.functions.Logistic",
        "svm": "weka.classifiers.functions.SMO",
        "kstar": "weka.classifiers.lazy.KStar",
        "ripper": "weka.classifiers.rules.JRip",
    }

    @classmethod
    def train(
        cls,
        model_filename,
        featuresets,
        classifier="naivebayes",
        options=None,
        quiet=True,
    ):
        """
        Train a Weka model and return a classifier that uses it.

        :param model_filename: Path where Weka stores the trained model
            (passed with ``-d``).
        :param featuresets: Labeled feature sets to train on.
        :param classifier: Either a key of ``_CLASSIFIER_CLASS`` or one of
            the Java class names it maps to.
        :param options: Extra command-line options appended to the Weka
            invocation; None (the default) means no extra options.
        :param quiet: If true, Weka's standard output is captured instead
            of being shown.
        :raise ValueError: If ``classifier`` is not recognized.
        """
        # Make sure we can find java & weka.
        config_weka()

        # Build an ARFF formatter.
        formatter = ARFF_Formatter.from_train(featuresets)

        # Resolve the Java class up front so an unknown classifier name fails
        # before any training data is written to disk.
        if classifier in cls._CLASSIFIER_CLASS:
            javaclass = cls._CLASSIFIER_CLASS[classifier]
        elif classifier in cls._CLASSIFIER_CLASS.values():
            javaclass = classifier
        else:
            raise ValueError("Unknown classifier %s" % classifier)

        # The context manager removes the directory and everything in it on
        # exit, including when an exception is raised.
        with tempfile.TemporaryDirectory() as temp_dir:
            # Write the training data file.
            train_filename = os.path.join(temp_dir, "train.arff")
            formatter.write(train_filename, featuresets)

            # Train the weka model.
            cmd = [javaclass, "-d", model_filename, "-t", train_filename]
            if options:
                cmd += list(options)

            stdout = subprocess.PIPE if quiet else None
            java(cmd, classpath=_weka_classpath, stdout=stdout)

            # Return the new classifier.
            return WekaClassifier(formatter, model_filename)
class ARFF_Formatter:
    """
    Converts featuresets and labeled featuresets to ARFF-formatted
    strings, appropriate for input into Weka.

    Features and classes can be specified manually in the constructor, or may
    be determined from data using ``from_train``.
    """

    def __init__(self, labels, features):
        """
        :param labels: A list of all class labels that can be generated.
        :param features: A list of feature specifications, where
            each feature specification is a tuple (fname, ftype);
            and ftype is an ARFF type string such as NUMERIC or
            STRING.
        """
        self._labels = labels
        self._features = features

    def format(self, tokens):
        """Returns a string representation of ARFF output for the given data."""
        return self.header_section() + self.data_section(tokens)

    def labels(self):
        """Returns the list of classes."""
        return list(self._labels)

    def write(self, outfile, tokens):
        """
        Writes ARFF data to a file for the given data.

        :param outfile: Either a path to open for writing, or an object
            with a ``write`` method.  In both cases the file is closed
            once the data has been written.
        :param tokens: The data to format (see ``data_section``).
        """
        # Format before opening anything, so a formatting error does not
        # leave behind an empty or truncated file.
        text = self.format(tokens)
        if hasattr(outfile, "write"):
            try:
                outfile.write(text)
            finally:
                outfile.close()
        else:
            with open(outfile, "w") as stream:
                stream.write(text)

    @staticmethod
    def from_train(tokens):
        """
        Constructs an ARFF_Formatter instance with class labels and feature
        types determined from the given data. Handles boolean, numeric and
        string (note: not nominal) types.

        :param tokens: An iterable of (featureset, label) pairs; it is
            traversed once, so a generator is acceptable.
        :raise ValueError: If a feature value has an unsupported type, or
            a feature is seen with two different types.
        """
        labels = set()
        features = {}
        for tok, label in tokens:
            labels.add(label)
            for fname, fval in tok.items():
                # bool must be tested before int: bool is a subclass of int.
                if isinstance(fval, bool):
                    ftype = "{True, False}"
                elif isinstance(fval, (int, float)):
                    ftype = "NUMERIC"
                elif isinstance(fval, str):
                    ftype = "STRING"
                elif fval is None:
                    continue  # can't tell the type.
                else:
                    raise ValueError("Unsupported value type %r" % (type(fval),))

                if features.get(fname, ftype) != ftype:
                    raise ValueError("Inconsistent type for %s" % fname)
                features[fname] = ftype

        # Sort so the label order (which fixes the order of the classes in
        # the ARFF header and in ``labels()``) does not depend on set
        # iteration order.
        return ARFF_Formatter(sorted(labels), sorted(features.items()))

    def header_section(self):
        """Returns an ARFF header as a string."""
        parts = [
            # Header comment.
            "% Weka ARFF file\n",
            "% Generated automatically by NLTK\n",
            "%% %s\n\n" % time.ctime(),
            # Relation name
            "@RELATION rel\n\n",
        ]

        # Input attribute specifications
        parts.extend(
            "@ATTRIBUTE %-30r %s\n" % (fname, ftype) for fname, ftype in self._features
        )

        # Label attribute specification
        parts.append("@ATTRIBUTE %-30r {%s}\n" % ("-label-", ",".join(self._labels)))

        return "".join(parts)

    def data_section(self, tokens, labeled=None):
        """
        Returns the ARFF data section for the given data.

        :param tokens: a list of featuresets (dicts) or labelled featuresets
            which are tuples (featureset, label).
        :param labeled: Indicates whether the given tokens are labeled
            or not.  If None, then the tokens will be assumed to be
            labeled if the first token's value is a tuple or list.
        """
        # Check if the tokens are labeled or unlabeled.  If unlabeled,
        # then use 'None'
        if labeled is None:
            labeled = bool(tokens) and isinstance(tokens[0], (tuple, list))
        if not labeled:
            tokens = [(tok, None) for tok in tokens]

        # Data section: one comma-separated row per token, with the label
        # as the last column.
        rows = ["\n@DATA\n"]
        for tok, label in tokens:
            values = [self._fmt_arff_val(tok.get(fname)) for fname, _ in self._features]
            values.append(self._fmt_arff_val(label))
            rows.append(",".join(values) + "\n")

        return "".join(rows)

    def _fmt_arff_val(self, fval):
        """Format one value for the data section; None becomes ``?``."""
        if fval is None:
            return "?"
        if isinstance(fval, (bool, int)):
            return str(fval)
        # Floats and all remaining values (e.g. strings) use their repr, so
        # strings come out quoted.
        return repr(fval)
5 years ago
if __name__ == "__main__":
    # Demo: train a C4.5 model on the names corpus demo and evaluate it.
    from nltk.classify.util import binary_names_demo_features, names_demo

    def make_classifier(featuresets):
        """Train a Weka C4.5 classifier, storing the model in /tmp."""
        return WekaClassifier.train(
            "/tmp/name.model",
            featuresets,
            classifier="C4.5",
        )

    classifier = names_demo(make_classifier, binary_names_demo_features)