You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

231 lines
5.5 KiB
Python

"""
Disjoint set data structure
"""
import collections
class DisjointSet:
""" Disjoint set data structure for incremental connectivity queries.
.. versionadded: 1.6.0
Attributes
----------
n_subsets : int
The number of subsets.
Methods
-------
add
merge
connected
subset
subsets
__getitem__
Notes
-----
This class implements the disjoint set [1]_, also known as the *union-find*
or *merge-find* data structure. The *find* operation (implemented in
`__getitem__`) implements the *path halving* variant. The *merge* method
implements the *merge by size* variant.
References
----------
.. [1] https://en.wikipedia.org/wiki/Disjoint-set_data_structure
Examples
--------
>>> from scipy.cluster.hierarchy import DisjointSet
Initialize a disjoint set:
>>> disjoint_set = DisjointSet([1, 2, 3, 'a', 'b'])
Merge some subsets:
>>> disjoint_set.merge(1, 2)
True
>>> disjoint_set.merge(3, 'a')
True
>>> disjoint_set.merge('a', 'b')
True
>>> disjoint_set.merge('b', 'b')
False
Find root elements:
>>> disjoint_set[2]
1
>>> disjoint_set['b']
3
Test connectivity:
>>> disjoint_set.connected(1, 2)
True
>>> disjoint_set.connected(1, 'b')
False
List elements in disjoint set:
>>> list(disjoint_set)
[1, 2, 3, 'a', 'b']
Get the subset containing 'a':
>>> disjoint_set.subset('a')
{'a', 3, 'b'}
Get all subsets in the disjoint set:
>>> disjoint_set.subsets()
[{1, 2}, {'a', 3, 'b'}]
"""
def __init__(self, elements=None):
self.n_subsets = 0
self._sizes = {}
self._parents = {}
# _nbrs is a circular linked list which links connected elements.
self._nbrs = {}
# _indices tracks the element insertion order - OrderedDict is used to
# ensure correct ordering in `__iter__`.
self._indices = collections.OrderedDict()
if elements is not None:
for x in elements:
self.add(x)
def __iter__(self):
"""Returns an iterator of the elements in the disjoint set.
Elements are ordered by insertion order.
"""
return iter(self._indices)
def __len__(self):
return len(self._indices)
def __contains__(self, x):
return x in self._indices
def __getitem__(self, x):
"""Find the root element of `x`.
Parameters
----------
x : hashable object
Input element.
Returns
-------
root : hashable object
Root element of `x`.
"""
if x not in self._indices:
raise KeyError(x)
# find by "path halving"
parents = self._parents
while self._indices[x] != self._indices[parents[x]]:
parents[x] = parents[parents[x]]
x = parents[x]
return x
def add(self, x):
"""Add element `x` to disjoint set
"""
if x in self._indices:
return
self._sizes[x] = 1
self._parents[x] = x
self._nbrs[x] = x
self._indices[x] = len(self._indices)
self.n_subsets += 1
def merge(self, x, y):
"""Merge the subsets of `x` and `y`.
The smaller subset (the child) is merged into the larger subset (the
parent). If the subsets are of equal size, the root element which was
first inserted into the disjoint set is selected as the parent.
Parameters
----------
x, y : hashable object
Elements to merge.
Returns
-------
merged : bool
True if `x` and `y` were in disjoint sets, False otherwise.
"""
xr = self[x]
yr = self[y]
if self._indices[xr] == self._indices[yr]:
return False
sizes = self._sizes
if (sizes[xr], self._indices[yr]) < (sizes[yr], self._indices[xr]):
xr, yr = yr, xr
self._parents[yr] = xr
self._sizes[xr] += self._sizes[yr]
self._nbrs[xr], self._nbrs[yr] = self._nbrs[yr], self._nbrs[xr]
self.n_subsets -= 1
return True
def connected(self, x, y):
"""Test whether `x` and `y` are in the same subset.
Parameters
----------
x, y : hashable object
Elements to test.
Returns
-------
result : bool
True if `x` and `y` are in the same set, False otherwise.
"""
return self._indices[self[x]] == self._indices[self[y]]
def subset(self, x):
"""Get the subset containing `x`.
Parameters
----------
x : hashable object
Input element.
Returns
-------
result : set
Subset containing `x`.
"""
if x not in self._indices:
raise KeyError(x)
result = [x]
nxt = self._nbrs[x]
while self._indices[nxt] != self._indices[x]:
result.append(nxt)
nxt = self._nbrs[nxt]
return set(result)
def subsets(self):
"""Get all the subsets in the disjoint set.
Returns
-------
result : set
Subsets in the disjoint set.
"""
result = []
visited = set()
for x in self:
if x not in visited:
xset = self.subset(x)
visited.update(xset)
result.append(xset)
return result