You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

369 lines
12 KiB
Python

"""Indexing mixin for sparse matrix classes.
"""
import numpy as np
from .sputils import isintlike
# Scalar types that the indexing code treats as "a single integer index".
# This module uses ``raise ... from`` and is therefore Python 3 only, where
# the ``long`` builtin does not exist; the former ``try``/``except NameError``
# fallback always ended up with exactly this tuple.
INT_TYPES = (int, np.integer)
def _broadcast_arrays(a, b):
    """
    Broadcast `a` and `b` against each other, keeping their writeable flags.

    Behaves like ``np.broadcast_arrays(a, b)`` except that each result's
    ``writeable`` flag is copied from the corresponding input.
    NumPy >= 1.17.0 transitions ``broadcast_arrays`` to return read-only
    arrays; the flag is set explicitly to avoid the associated warnings and
    to keep the old writeability rules that the Cython code assumes.

    Both arguments must expose ``.flags.writeable`` (i.e. be ndarrays).
    Returns the pair ``(x, y)`` of broadcast arrays.
    """
    broadcasted = np.broadcast_arrays(a, b)
    # Propagate writeability input -> output, first `a` then `b`.
    for view, source in zip(broadcasted, (a, b)):
        view.flags.writeable = source.flags.writeable
    return broadcasted[0], broadcasted[1]
class IndexMixin:
    """
    This class provides common dispatching and validation logic for indexing.

    ``__getitem__`` and ``__setitem__`` normalise the key into a
    ``(row, col)`` pair and forward to the ``_get_*`` / ``_set_*`` hooks
    declared at the end of the class.  Apart from ``_set_arrayXarray_sparse``
    (which densifies and delegates), those hooks raise ``NotImplementedError``
    here.  The methods below also read ``self.shape`` and ``self.dtype`` and
    call ``self.copy()``, none of which this class defines itself.
    """

    def __getitem__(self, key):
        """Return the selection described by `key`.

        After validation each of `row` and `col` is an integer, a slice or
        an index array; the combination of the two decides which ``_get_*``
        hook is called.

        Raises
        ------
        IndexError
            If the selection would have more than two dimensions, or for any
            reason reported by ``_validate_indices``.
        """
        row, col = self._validate_indices(key)

        # Dispatch to specialized methods.
        if isinstance(row, INT_TYPES):
            if isinstance(col, INT_TYPES):
                return self._get_intXint(row, col)
            elif isinstance(col, slice):
                return self._get_intXslice(row, col)
            elif col.ndim == 1:
                return self._get_intXarray(row, col)
            raise IndexError('index results in >2 dimensions')
        elif isinstance(row, slice):
            if isinstance(col, INT_TYPES):
                return self._get_sliceXint(row, col)
            elif isinstance(col, slice):
                # A[:, :] is just a copy of the whole matrix.
                if row == slice(None) and row == col:
                    return self.copy()
                return self._get_sliceXslice(row, col)
            elif col.ndim == 1:
                return self._get_sliceXarray(row, col)
            raise IndexError('index results in >2 dimensions')
        elif row.ndim == 1:
            if isinstance(col, INT_TYPES):
                return self._get_arrayXint(row, col)
            elif isinstance(col, slice):
                return self._get_arrayXslice(row, col)
            # array x array falls through to the fancy-indexing code below.
        else:  # row.ndim == 2
            if isinstance(col, INT_TYPES):
                return self._get_arrayXint(row, col)
            elif isinstance(col, slice):
                raise IndexError('index results in >2 dimensions')
            elif row.shape[1] == 1 and (col.ndim == 1 or col.shape[0] == 1):
                # special case for outer indexing
                return self._get_columnXarray(row[:, 0], col.ravel())

        # The only remaining case is inner (fancy) indexing
        row, col = _broadcast_arrays(row, col)
        if row.shape != col.shape:
            raise IndexError('number of row and column indices differ')
        if row.size == 0:
            # Empty selection: build an empty instance of the same class.
            return self.__class__(np.atleast_2d(row).shape, dtype=self.dtype)
        return self._get_arrayXarray(row, col)

    def __setitem__(self, key, x):
        """Assign `x` to the positions described by `key`.

        A pair of integer indices goes to ``_set_intXint``.  Every other key
        is expanded into two broadcast index arrays ``i`` and ``j`` and
        passed to ``_set_arrayXarray`` (dense `x`) or
        ``_set_arrayXarray_sparse`` (sparse `x`).

        Raises
        ------
        ValueError
            If a non-scalar is assigned to a single item, or a sparse `x`
            has a shape that neither matches nor broadcasts to the index.
        IndexError
            If the row and column index shapes differ, or for any reason
            reported by ``_validate_indices``.
        """
        row, col = self._validate_indices(key)

        if isinstance(row, INT_TYPES) and isinstance(col, INT_TYPES):
            x = np.asarray(x, dtype=self.dtype)
            if x.size != 1:
                raise ValueError('Trying to assign a sequence to an item')
            self._set_intXint(row, col, x.flat[0])
            return

        # Slices become explicit index vectors: rows as a column vector,
        # columns as a row vector, so that the two broadcast to a grid.
        if isinstance(row, slice):
            row = np.arange(*row.indices(self.shape[0]))[:, None]
        else:
            row = np.atleast_1d(row)

        if isinstance(col, slice):
            col = np.arange(*col.indices(self.shape[1]))[None, :]
            if row.ndim == 1:
                row = row[:, None]
        else:
            col = np.atleast_1d(col)

        i, j = _broadcast_arrays(row, col)
        if i.shape != j.shape:
            raise IndexError('number of row and column indices differ')

        # Imported here rather than at module level.
        # NOTE(review): presumably to avoid a circular import - confirm.
        from .base import isspmatrix
        if isspmatrix(x):
            if i.ndim == 1:
                # Inner indexing, so treat them like row vectors.
                i = i[None]
                j = j[None]
            # A length-1 axis of x may be broadcast along the index axis.
            broadcast_row = x.shape[0] == 1 and i.shape[0] != 1
            broadcast_col = x.shape[1] == 1 and i.shape[1] != 1
            if not ((broadcast_row or x.shape[0] == i.shape[0]) and
                    (broadcast_col or x.shape[1] == i.shape[1])):
                raise ValueError('shape mismatch in assignment')
            if x.shape[0] == 0 or x.shape[1] == 0:
                return
            x = x.tocoo(copy=True)
            x.sum_duplicates()
            self._set_arrayXarray_sparse(i, j, x)
        else:
            # Make x and i into the same shape
            x = np.asarray(x, dtype=self.dtype)
            if x.squeeze().shape != i.squeeze().shape:
                x = np.broadcast_to(x, i.shape)
            if x.size == 0:
                return
            x = x.reshape(i.shape)
            self._set_arrayXarray(i, j, x)

    def _validate_indices(self, key):
        """Split `key` into a validated ``(row, col)`` pair.

        Entries accepted by ``isintlike`` are converted to ``int``,
        bounds-checked against ``self.shape`` and wrapped to be
        non-negative.  Slices pass through unchanged.  Anything else is
        handed to ``_asindices``.

        Raises
        ------
        IndexError
            If an integer index is out of range, or as raised by
            ``_unpack_index`` / ``_asindices``.
        """
        M, N = self.shape
        row, col = _unpack_index(key)

        if isintlike(row):
            row = int(row)
            if row < -M or row >= M:
                raise IndexError('row index (%d) out of range' % row)
            if row < 0:
                row += M
        elif not isinstance(row, slice):
            row = self._asindices(row, M)

        if isintlike(col):
            col = int(col)
            if col < -N or col >= N:
                raise IndexError('column index (%d) out of range' % col)
            if col < 0:
                col += N
        elif not isinstance(col, slice):
            col = self._asindices(col, N)

        return row, col

    def _asindices(self, idx, length):
        """Convert `idx` to a valid index for an axis with a given length.

        Subclasses that need special validation can override this method.

        Returns
        -------
        ndarray
            A 1-D or 2-D array in which negative entries have been wrapped
            by adding `length`.  The caller's array is never modified.

        Raises
        ------
        IndexError
            If `idx` cannot be converted to an array, is not 1-D or 2-D, or
            contains an entry outside ``[-length, length)``.
        """
        try:
            x = np.asarray(idx)
        except (ValueError, TypeError, MemoryError) as e:
            raise IndexError('invalid index') from e

        # 0-d input is rejected as well, so state the accepted ranks exactly.
        if x.ndim not in (1, 2):
            raise IndexError('Index dimension must be 1 or 2')

        if x.size == 0:
            return x

        # Check bounds
        max_indx = x.max()
        if max_indx >= length:
            raise IndexError('index (%d) out of range' % max_indx)

        min_indx = x.min()
        if min_indx < 0:
            if min_indx < -length:
                raise IndexError('index (%d) out of range' % min_indx)
            # Copy before the in-place wrap if x is the caller's own array
            # or a view onto someone else's data.
            if x is idx or not x.flags.owndata:
                x = x.copy()
            x[x < 0] += length
        return x

    def getrow(self, i):
        """Return a copy of row i of the matrix, as a (1 x n) row vector.

        Negative `i` counts from the last row.  Raises IndexError if `i` is
        out of range.
        """
        M, _ = self.shape
        i = int(i)
        if i < -M or i >= M:
            raise IndexError('index (%d) out of range' % i)
        if i < 0:
            i += M
        return self._get_intXslice(i, slice(None))

    def getcol(self, i):
        """Return a copy of column i of the matrix, as a (m x 1) column vector.

        Negative `i` counts from the last column.  Raises IndexError if `i`
        is out of range.
        """
        _, N = self.shape
        i = int(i)
        if i < -N or i >= N:
            raise IndexError('index (%d) out of range' % i)
        if i < 0:
            i += N
        return self._get_sliceXint(slice(None), i)

    # ------------------------------------------------------------------
    # Hooks called by __getitem__ / __setitem__.  The name encodes the kind
    # of the row index and the kind of the column index.
    # ------------------------------------------------------------------

    def _get_intXint(self, row, col):
        raise NotImplementedError()

    def _get_intXarray(self, row, col):
        raise NotImplementedError()

    def _get_intXslice(self, row, col):
        raise NotImplementedError()

    def _get_sliceXint(self, row, col):
        raise NotImplementedError()

    def _get_sliceXslice(self, row, col):
        raise NotImplementedError()

    def _get_sliceXarray(self, row, col):
        raise NotImplementedError()

    def _get_arrayXint(self, row, col):
        raise NotImplementedError()

    def _get_arrayXslice(self, row, col):
        raise NotImplementedError()

    def _get_columnXarray(self, row, col):
        raise NotImplementedError()

    def _get_arrayXarray(self, row, col):
        raise NotImplementedError()

    def _set_intXint(self, row, col, x):
        raise NotImplementedError()

    def _set_arrayXarray(self, row, col, x):
        raise NotImplementedError()

    def _set_arrayXarray_sparse(self, row, col, x):
        """Assign sparse `x` at index arrays `row`, `col`.

        Default implementation: convert `x` to a dense array of
        ``self.dtype``, broadcast it against `row` and delegate to
        ``_set_arrayXarray``.
        """
        # Fall back to densifying x
        x = np.asarray(x.toarray(), dtype=self.dtype)
        x, _ = _broadcast_arrays(x, row)
        self._set_arrayXarray(row, col, x)
def _unpack_index(index):
    """ Parse index. Always return a tuple of the form (row, col).
    Valid type for row/col is integer, slice, or array of integers.

    A single 2-D boolean mask (ndarray, sparse matrix or nested sequence) is
    converted with ``nonzero()``.  A single boolean index with fewer than two
    dimensions selects rows.  Boolean entries of a ``(row, col)`` tuple are
    converted to integer position arrays.

    Raises
    ------
    IndexError
        If a tuple index does not have one or two entries, if a lone boolean
        index has more than two dimensions, or if a sparse matrix is used as
        a row or column index.
    """
    # First, check if indexing with single boolean matrix.
    from .base import spmatrix, isspmatrix
    if (isinstance(index, (spmatrix, np.ndarray)) and
            index.ndim == 2 and index.dtype.kind == 'b'):
        return index.nonzero()

    # Parse any ellipses.
    index = _check_ellipsis(index)

    # Next, parse the tuple or object
    if isinstance(index, tuple):
        if len(index) == 2:
            row, col = index
        elif len(index) == 1:
            row, col = index[0], slice(None)
        else:
            raise IndexError('invalid number of indices')
    else:
        idx = _compatible_boolean_index(index)
        if idx is None:
            row, col = index, slice(None)
        elif idx.ndim < 2:
            return _boolean_index_to_array(idx), slice(None)
        elif idx.ndim == 2:
            return idx.nonzero()
        else:
            # Without this branch `row`/`col` stay unbound and the code
            # below fails with UnboundLocalError instead of an IndexError.
            raise IndexError('invalid index shape')

    # Next, check for validity and transform the index as needed.
    if isspmatrix(row) or isspmatrix(col):
        # Supporting sparse boolean indexing with both row and col does
        # not work because spmatrix.ndim is always 2.
        raise IndexError(
            'Indexing with sparse matrices is not supported '
            'except boolean indexing where matrix and index '
            'are equal shapes.')

    bool_row = _compatible_boolean_index(row)
    bool_col = _compatible_boolean_index(col)
    if bool_row is not None:
        row = _boolean_index_to_array(bool_row)
    if bool_col is not None:
        col = _boolean_index_to_array(bool_col)
    return row, col
def _check_ellipsis(index):
    """Process indices with Ellipsis. Returns modified index.

    A bare ``Ellipsis`` becomes ``(slice(None), slice(None))``.  Non-tuple
    indices and tuples without an Ellipsis are returned unchanged.  In a
    tuple, the first Ellipsis is replaced by as many full slices as are
    needed to reach two entries, and any further Ellipsis is dropped.
    """
    everything = slice(None)

    if index is Ellipsis:
        return (everything, everything)
    if not isinstance(index, tuple):
        return index

    # TODO: Deprecate this multiple-ellipsis handling,
    #       as numpy no longer supports it.

    # Find first ellipsis (identity test only; entries may be arrays).
    first_ellipsis = next(
        (pos for pos, item in enumerate(index) if item is Ellipsis), None)
    if first_ellipsis is None:
        return index

    # Try to expand it using shortcuts for common cases
    n_entries = len(index)
    if n_entries == 1:
        return (everything, everything)
    if n_entries == 2:
        if first_ellipsis != 0:
            return (index[0], everything)
        if index[1] is Ellipsis:
            return (everything, everything)
        return (everything, index[1])

    # Expand it using a general-purpose algorithm
    tail = tuple(item for item in index[first_ellipsis + 1:]
                 if item is not Ellipsis)
    n_fill = max(0, 2 - (first_ellipsis + len(tail)))
    return index[:first_ellipsis] + (everything,) * n_fill + tail
def _maybe_bool_ndarray(idx):
    """Return `idx` as an array if its dtype is boolean, otherwise None.

    The conversion uses ``np.asanyarray``, so ndarray subclasses are
    passed through unchanged.
    """
    candidate = np.asanyarray(idx)
    return candidate if candidate.dtype.kind == 'b' else None
def _first_element_bool(idx, max_dim=2):
    """Returns True if first element of the incompatible
    array type is boolean.

    Descends into the first element of `idx` at most `max_dim` levels deep.
    Both Python ``bool`` and NumPy ``np.bool_`` scalars count as boolean, so
    a list of NumPy booleans is recognised just like a list of Python ones.
    Returns None (not False) when no boolean is found, when `idx` is not
    iterable, or when `max_dim` is less than 1.
    """
    for _ in range(max_dim):
        try:
            first = next(iter(idx), None)
        except TypeError:
            # Not iterable: nothing further to inspect.
            return None
        if isinstance(first, (bool, np.bool_)):
            return True
        # Look one nesting level deeper on the next pass.
        idx = first
    return None
def _compatible_boolean_index(idx):
    """Returns a boolean index array that can be converted to
    integer array. Returns None if no such array exists.
    """
    # An `ndim` attribute marks a compatible array type; for anything else
    # only the first element is inspected before attempting a conversion.
    worth_converting = hasattr(idx, 'ndim') or _first_element_bool(idx)
    if not worth_converting:
        return None
    return _maybe_bool_ndarray(idx)
def _boolean_index_to_array(idx):
    """Convert a boolean mask into the array of positions where it is True.

    Raises IndexError if `idx` has more than one dimension.
    """
    if idx.ndim > 1:
        raise IndexError('invalid index shape')
    # np.where with one argument returns a tuple with one array per axis.
    (positions,) = np.where(idx)
    return positions