You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

230 lines
7.0 KiB
Python

5 years ago
"""
Plotting of string "category" data: ``plot(['d', 'f', 'a'], [1, 2, 3])`` will
plot three points with x-axis values of 'd', 'f', 'a'.
See :doc:`/gallery/lines_bars_and_markers/categorical_variables` for an
example.
The module uses Matplotlib's `matplotlib.units` mechanism to convert from
strings to integers and provides a tick locator, a tick formatter, and the
`.UnitData` class that creates and stores the string-to-integer mapping.
"""
from collections import OrderedDict
import dateutil.parser
import itertools
import logging
import numpy as np
from matplotlib import cbook, ticker, units
_log = logging.getLogger(__name__)
class StrCategoryConverter(units.ConversionInterface):
    """Unit converter that places string categories at numeric positions."""

    @staticmethod
    def convert(value, unit, axis):
        """
        Translate the entries of *value* into floats using the mapping kept
        by *unit*.

        Input consisting solely of number-like, non-string entries is passed
        through as floats without touching *unit*; anything else is first
        registered with *unit* and then looked up in its mapping.

        Parameters
        ----------
        value : str or iterable
            Value or list of values to be converted.
        unit : `.UnitData`
            An object mapping strings to integers.
        axis : `~matplotlib.axis.Axis`
            The axis on which the converted value is plotted.

            .. note:: *axis* is unused.

        Returns
        -------
        mapped_value : ndarray[float]
            Always at least 1-D; a scalar *value* yields a length-1 array.

        Raises
        ------
        ValueError
            If *unit* is None.
        """
        if unit is None:
            raise ValueError(
                'Missing category information for StrCategoryConverter; '
                'this might be caused by unintendedly mixing categorical and '
                'numeric data')

        # An object array keeps numeric entries as numbers instead of
        # coercing them to strings.
        items = np.atleast_1d(np.array(value, dtype=object))

        # Look for the first entry that is not a plain number; if there is
        # none (including the empty case), hand the data back as floats.
        for item in items:
            if (not units.ConversionInterface.is_numlike(item)
                    or isinstance(item, (str, bytes))):
                break
        else:
            return np.asarray(items, dtype=float)

        # Updating registers unseen categories and type-checks the entries.
        unit.update(items)
        lookup = np.vectorize(unit._mapping.__getitem__, otypes=[float])
        return lookup(items)

    @staticmethod
    def axisinfo(unit, axis):
        """
        Provide the default tick locator and formatter for the axis.

        Parameters
        ----------
        unit : `.UnitData`
            Object holding the string-to-integer mapping.
        axis : `~matplotlib.axis.Axis`
            Axis for which information is being set.

            .. note:: *axis* is unused.

        Returns
        -------
        axisinfo : `~matplotlib.units.AxisInfo`
            Information to support default tick labeling.
        """
        # The locator and the formatter are handed the mapping dict itself
        # (not a copy), so later additions to it are visible to both.
        locator = StrCategoryLocator(unit._mapping)
        formatter = StrCategoryFormatter(unit._mapping)
        return units.AxisInfo(majloc=locator, majfmt=formatter)

    @staticmethod
    def default_units(data, axis):
        """
        Attach category units to *axis*, or extend the ones already set.

        Parameters
        ----------
        data : str or iterable of str
            Values about to be plotted on *axis*.
        axis : `~matplotlib.axis.Axis`
            Axis on which the data is plotted.

        Returns
        -------
        class : `.UnitData`
            The units found on *axis* after the call (``axis.units``).
        """
        # the conversion call stack is default_units -> axis_info -> convert
        if axis.units is not None:
            axis.units.update(data)
        else:
            axis.set_units(UnitData(data))
        return axis.units
class StrCategoryLocator(ticker.Locator):
    """Tick at every integer mapping of the string data."""

    def __init__(self, units_mapping):
        """
        Parameters
        ----------
        units_mapping : Dict[str, int]
            Mapping from category to tick position. The object is stored
            as given (no copy is made).
        """
        self._units = units_mapping

    def __call__(self):
        """Return the tick positions: every value currently in the mapping."""
        positions = self._units.values()
        return list(positions)

    def tick_values(self, vmin, vmax):
        """
        Return the same positions as calling the locator.

        *vmin* and *vmax* are ignored.
        """
        return self()
class StrCategoryFormatter(ticker.Formatter):
    """String representation of the data at every tick."""

    def __init__(self, units_mapping):
        """
        Parameters
        ----------
        units_mapping : Dict[str, int]
            Mapping from category to tick position. The object is stored
            as given (no copy is made).
        """
        self._units = units_mapping

    def __call__(self, x, pos=None):
        """
        Return the category label string for tick val *x*.

        The position *pos* is ignored.
        """
        labels = self.format_ticks([x])
        return labels[0]

    def format_ticks(self, values):
        """
        Return one label per entry of *values*.

        Each value is rounded with `round` and looked up among the mapped
        positions; values that match no category give an empty string.
        """
        # Invert the category -> position mapping into position -> label.
        label_at = {}
        for category, position in self._units.items():
            label_at[position] = self._text(category)
        return [label_at.get(round(tick), '') for tick in values]

    @staticmethod
    def _text(value):
        """Return *value* as a str, decoding bytes as utf-8."""
        if isinstance(value, bytes):
            return value.decode(encoding='utf-8')
        if isinstance(value, str):
            return value
        return str(value)
class UnitData:
    """Insertion-ordered mapping from category values to integer ids."""

    def __init__(self, data=None):
        """
        Create mapping between unique categorical values and integer ids.

        Parameters
        ----------
        data : iterable, optional
            Sequence of string values. If None, the mapping starts empty.
        """
        self._mapping = OrderedDict()
        # Source of the next unused id; ids are handed out from 0 upwards.
        self._counter = itertools.count()
        if data is not None:
            self.update(data)

    @staticmethod
    def _str_is_convertible(val):
        """
        Helper method to check whether a string can be parsed as float or date.

        Returns True if ``float(val)`` succeeds or, failing that, if
        ``dateutil.parser.parse(val)`` succeeds; False otherwise.
        """
        try:
            float(val)
        except ValueError:
            try:
                dateutil.parser.parse(val)
            except (ValueError, TypeError):
                # TypeError if dateutil >= 2.8.1 else ValueError
                return False
        return True

    def update(self, data):
        """
        Map new values to integer identifiers.

        Values already present keep their identifier; unseen values get the
        next free one, in order of first appearance.

        Parameters
        ----------
        data : iterable
            Sequence of string values.

        Raises
        ------
        TypeError
            If the value in data is not a string, unicode, bytes type.
        """
        data = np.atleast_1d(np.array(data, dtype=object))
        # check if convertible to number:
        convertible = True
        for val in OrderedDict.fromkeys(data):
            # OrderedDict just iterates over unique values in data.
            cbook._check_isinstance((str, bytes), value=val)
            if convertible:
                # this will only be called so long as convertible is True.
                convertible = self._str_is_convertible(val)
            if val not in self._mapping:
                self._mapping[val] = next(self._counter)
        # With empty data the loop never runs and *convertible* keeps its
        # initial True, so also require at least one value: otherwise the
        # message below would be logged for data that contains no strings.
        if data.size and convertible:
            _log.info('Using categorical units to plot a list of strings '
                      'that are all parsable as floats or dates. If these '
                      'strings should be plotted as numbers, cast to the '
                      'appropriate data type before plotting.')
# Register the converter with Matplotlib's unit framework: every string-like
# type (builtin and NumPy) gets its own StrCategoryConverter instance.
for _str_like in (str, np.str_, bytes, np.bytes_):
    units.registry[_str_like] = StrCategoryConverter()
# Drop the loop variable so it does not linger as a module attribute.
del _str_like