""" Disjoint set data structure """ import collections class DisjointSet: """ Disjoint set data structure for incremental connectivity queries. .. versionadded: 1.6.0 Attributes ---------- n_subsets : int The number of subsets. Methods ------- add merge connected subset subsets __getitem__ Notes ----- This class implements the disjoint set [1]_, also known as the *union-find* or *merge-find* data structure. The *find* operation (implemented in `__getitem__`) implements the *path halving* variant. The *merge* method implements the *merge by size* variant. References ---------- .. [1] https://en.wikipedia.org/wiki/Disjoint-set_data_structure Examples -------- >>> from scipy.cluster.hierarchy import DisjointSet Initialize a disjoint set: >>> disjoint_set = DisjointSet([1, 2, 3, 'a', 'b']) Merge some subsets: >>> disjoint_set.merge(1, 2) True >>> disjoint_set.merge(3, 'a') True >>> disjoint_set.merge('a', 'b') True >>> disjoint_set.merge('b', 'b') False Find root elements: >>> disjoint_set[2] 1 >>> disjoint_set['b'] 3 Test connectivity: >>> disjoint_set.connected(1, 2) True >>> disjoint_set.connected(1, 'b') False List elements in disjoint set: >>> list(disjoint_set) [1, 2, 3, 'a', 'b'] Get the subset containing 'a': >>> disjoint_set.subset('a') {'a', 3, 'b'} Get all subsets in the disjoint set: >>> disjoint_set.subsets() [{1, 2}, {'a', 3, 'b'}] """ def __init__(self, elements=None): self.n_subsets = 0 self._sizes = {} self._parents = {} # _nbrs is a circular linked list which links connected elements. self._nbrs = {} # _indices tracks the element insertion order - OrderedDict is used to # ensure correct ordering in `__iter__`. self._indices = collections.OrderedDict() if elements is not None: for x in elements: self.add(x) def __iter__(self): """Returns an iterator of the elements in the disjoint set. Elements are ordered by insertion order. """ return iter(self._indices) def __len__(self): return len(self._indices) def __contains__(self, x): return x in self._indices def __getitem__(self, x): """Find the root element of `x`. Parameters ---------- x : hashable object Input element. Returns ------- root : hashable object Root element of `x`. """ if x not in self._indices: raise KeyError(x) # find by "path halving" parents = self._parents while self._indices[x] != self._indices[parents[x]]: parents[x] = parents[parents[x]] x = parents[x] return x def add(self, x): """Add element `x` to disjoint set """ if x in self._indices: return self._sizes[x] = 1 self._parents[x] = x self._nbrs[x] = x self._indices[x] = len(self._indices) self.n_subsets += 1 def merge(self, x, y): """Merge the subsets of `x` and `y`. The smaller subset (the child) is merged into the larger subset (the parent). If the subsets are of equal size, the root element which was first inserted into the disjoint set is selected as the parent. Parameters ---------- x, y : hashable object Elements to merge. Returns ------- merged : bool True if `x` and `y` were in disjoint sets, False otherwise. """ xr = self[x] yr = self[y] if self._indices[xr] == self._indices[yr]: return False sizes = self._sizes if (sizes[xr], self._indices[yr]) < (sizes[yr], self._indices[xr]): xr, yr = yr, xr self._parents[yr] = xr self._sizes[xr] += self._sizes[yr] self._nbrs[xr], self._nbrs[yr] = self._nbrs[yr], self._nbrs[xr] self.n_subsets -= 1 return True def connected(self, x, y): """Test whether `x` and `y` are in the same subset. Parameters ---------- x, y : hashable object Elements to test. Returns ------- result : bool True if `x` and `y` are in the same set, False otherwise. """ return self._indices[self[x]] == self._indices[self[y]] def subset(self, x): """Get the subset containing `x`. Parameters ---------- x : hashable object Input element. Returns ------- result : set Subset containing `x`. """ if x not in self._indices: raise KeyError(x) result = [x] nxt = self._nbrs[x] while self._indices[nxt] != self._indices[x]: result.append(nxt) nxt = self._nbrs[nxt] return set(result) def subsets(self): """Get all the subsets in the disjoint set. Returns ------- result : set Subsets in the disjoint set. """ result = [] visited = set() for x in self: if x not in visited: xset = self.subset(x) visited.update(xset) result.append(xset) return result