You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
202 lines
5.3 KiB
Python
202 lines
5.3 KiB
Python
import pytest
|
|
from pytest import raises as assert_raises
|
|
import numpy as np
|
|
import collections
|
|
from scipy.cluster.hierarchy import DisjointSet
|
|
import string
|
|
|
|
|
|
def generate_random_token():
    """Yield an endless, reproducible stream of hashable test tokens.

    The token pool mixes NumPy integers, NumPy floats, single ASCII letters
    and ``None``.  Each yielded item is either one token drawn from the pool
    or a 2-tuple of tokens.  The generator is seeded with a fixed value, so
    every fresh generator produces the same sequence.
    """
    k = len(string.ascii_letters)
    tokens = list(np.arange(k, dtype=int))
    tokens += list(np.arange(k, dtype=float))
    tokens += list(string.ascii_letters)
    tokens += [None] * k
    # Convert once, explicitly as objects: this keeps every token's original
    # type and avoids re-converting the list on each ``rng.choice`` call.
    tokens = np.array(tokens, dtype=object)
    rng = np.random.RandomState(seed=0)

    while True:
        # randint's upper bound is exclusive, so size is 1 or 2.
        size = rng.randint(1, 3)
        element = rng.choice(tokens, size)
        if size == 1:
            yield element[0]
        else:
            yield tuple(element)
|
|
|
|
|
|
def get_elements(n):
    """Return the first ``n`` distinct tokens from ``generate_random_token``.

    Tokens are returned in the order they are first produced.  A
    non-positive ``n`` yields an empty list.
    """
    if n <= 0:
        # Without this guard the loop below would still collect one token
        # before checking the size limit.
        return []

    # OrderedDict is deterministic without difficulty of comparing numpy ints
    elements = collections.OrderedDict()
    for element in generate_random_token():
        if element not in elements:
            elements[element] = len(elements)
            if len(elements) >= n:
                break
    return list(elements.keys())
|
|
|
|
|
|
def test_init():
    """Constructing from a list gives one subset per element, in order."""
    n = 10
    elements = get_elements(n)
    dis = DisjointSet(elements)
    assert dis.n_subsets == n
    assert list(dis) == elements
|
|
|
|
|
|
def test_len():
    """``len`` reports the element count and grows when an element is added."""
    n = 10
    elements = get_elements(n)
    dis = DisjointSet(elements)
    assert len(dis) == n

    dis.add("dummy")
    assert len(dis) == n + 1
|
|
|
|
|
|
@pytest.mark.parametrize("n", [10, 100])
def test_contains(n):
    """Membership is true for every inserted element and false otherwise."""
    elements = get_elements(n)
    dis = DisjointSet(elements)
    for x in elements:
        assert x in dis

    assert "dummy" not in dis
|
|
|
|
|
|
@pytest.mark.parametrize("n", [10, 100])
def test_add(n):
    """Adding elements one by one matches constructing from the full list."""
    elements = get_elements(n)
    dis1 = DisjointSet(elements)

    dis2 = DisjointSet()
    for i, x in enumerate(elements):
        dis2.add(x)
        assert len(dis2) == i + 1

        # test idempotency by adding element again
        dis2.add(x)
        assert len(dis2) == i + 1

    assert list(dis1) == list(dis2)
|
|
|
|
|
|
def test_element_not_present():
    """Lookup, merge and connected raise ``KeyError`` for an unknown element."""
    elements = get_elements(n=10)
    dis = DisjointSet(elements)

    with assert_raises(KeyError):
        dis["dummy"]

    with assert_raises(KeyError):
        dis.merge(elements[0], "dummy")

    with assert_raises(KeyError):
        dis.connected(elements[0], "dummy")
|
|
|
|
|
|
@pytest.mark.parametrize("direction", ["forwards", "backwards"])
@pytest.mark.parametrize("n", [10, 100])
def test_linear_union_sequence(n, direction):
    """Merge neighbouring elements one pair at a time, in either direction.

    Each merge must join two previously unconnected elements and reduce the
    subset count by one; afterwards every element shares a single root.
    """
    elements = get_elements(n)
    dis = DisjointSet(elements)
    assert elements == list(dis)

    indices = list(range(n - 1))
    if direction == "backwards":
        indices = indices[::-1]

    for it, i in enumerate(indices):
        assert not dis.connected(elements[i], elements[i + 1])
        assert dis.merge(elements[i], elements[i + 1])
        assert dis.connected(elements[i], elements[i + 1])
        assert dis.n_subsets == n - 1 - it

    roots = [dis[i] for i in elements]
    if direction == "forwards":
        assert all(elements[0] == r for r in roots)
    else:
        # The first backwards merge joins elements[-2] and elements[-1];
        # elements[-2] is expected to stay the root of everything merged
        # into that subset afterwards.
        assert all(elements[-2] == r for r in roots)
    # Everything is already connected, so a further merge is a no-op.
    assert not dis.merge(elements[0], elements[-1])
|
|
|
|
|
|
@pytest.mark.parametrize("n", [10, 100])
def test_self_unions(n):
    """Merging an element with itself changes nothing."""
    elements = get_elements(n)
    dis = DisjointSet(elements)

    for x in elements:
        assert dis.connected(x, x)
        assert not dis.merge(x, x)
        assert dis.connected(x, x)
    assert dis.n_subsets == len(elements)

    assert elements == list(dis)
    # Every element must still be its own root.
    roots = [dis[x] for x in elements]
    assert elements == roots
|
|
|
|
|
|
@pytest.mark.parametrize("order", ["ab", "ba"])
@pytest.mark.parametrize("n", [10, 100])
def test_equal_size_ordering(n, order):
    """Merging two singletons keeps the earlier-inserted element as root.

    The expected root is the element with the smaller insertion index,
    regardless of the argument order passed to ``merge``.
    """
    elements = get_elements(n)
    dis = DisjointSet(elements)

    rng = np.random.RandomState(seed=0)
    indices = np.arange(n)
    rng.shuffle(indices)

    # Pair up the shuffled indices; n is even for all parametrized values.
    for i in range(0, len(indices), 2):
        a, b = elements[indices[i]], elements[indices[i + 1]]
        if order == "ab":
            assert dis.merge(a, b)
        else:
            assert dis.merge(b, a)

        expected = elements[min(indices[i], indices[i + 1])]
        assert dis[a] == expected
        assert dis[b] == expected
|
|
|
|
|
|
@pytest.mark.parametrize("kmax", [5, 10])
def test_binary_tree(kmax):
    """Merge adjacent blocks of doubling size and check the resulting roots.

    At level ``k`` each pair of neighbouring blocks of ``k`` elements is
    merged through one randomly chosen member of each block.  After the
    level, every element's root must be the first element of its block of
    ``2 * k``.
    """
    n = 2**kmax
    elements = get_elements(n)
    dis = DisjointSet(elements)
    rng = np.random.RandomState(seed=0)

    for k in 2**np.arange(kmax):
        for i in range(0, n, 2 * k):
            # Pick one representative from each of the two adjacent blocks.
            r1, r2 = rng.randint(0, k, size=2)
            a, b = elements[i + r1], elements[i + k + r2]
            assert not dis.connected(a, b)
            assert dis.merge(a, b)
            assert dis.connected(a, b)

        assert elements == list(dis)
        roots = [dis[i] for i in elements]
        # Index of the first element of each element's block of size 2 * k.
        expected_indices = np.arange(n) - np.arange(n) % (2 * k)
        expected = [elements[i] for i in expected_indices]
        assert roots == expected
|
|
|
|
|
|
@pytest.mark.parametrize("n", [10, 100])
def test_subsets(n):
    """``subset`` and ``subsets`` agree with groupings rebuilt from roots.

    Before each random merge, the subset of ``x`` and the full list of
    subsets are recomputed from root lookups and compared with the values
    reported by the structure.
    """
    elements = get_elements(n)
    dis = DisjointSet(elements)

    rng = np.random.RandomState(seed=0)
    for i, j in rng.randint(0, n, (n, 2)):
        x = elements[i]
        y = elements[j]

        # Roots are compared wrapped in one-element sets.
        # NOTE(review): presumably to get hash-based equality - confirm.
        expected = {element for element in dis if {dis[element]} == {dis[x]}}
        assert expected == dis.subset(x)

        # Group every element under its root, in order of first appearance.
        expected = {dis[element]: set() for element in dis}
        for element in dis:
            expected[dis[element]].add(element)
        expected = list(expected.values())
        assert expected == dis.subsets()

        dis.merge(x, y)
        assert dis.subset(x) == dis.subset(y)
|