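# NOTE(review): the lines below appear to be repository web-page residue, not module content.
# You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
#
# 202 lines
# 5.3 KiB
# Python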

import pytest
from pytest import raises as assert_raises
import numpy as np
import collections
from scipy.cluster.hierarchy import DisjointSet
import string
def generate_random_token():
    """Yield an endless, reproducible stream of hashable tokens.

    The token pool mixes NumPy integers, NumPy floats, single ASCII letters
    and ``None``.  Each draw yields either a single pool member or a tuple
    of two pool members.  The generator is seeded with a fixed value, so
    every fresh generator produces the same sequence.

    Yields
    ------
    token : object or tuple
        A single pool member, or a 2-tuple of pool members.
    """
    k = len(string.ascii_letters)
    tokens = list(np.arange(k, dtype=int))
    tokens += list(np.arange(k, dtype=float))
    tokens += list(string.ascii_letters)
    tokens += [None] * k
    # Convert once, with an explicit object dtype: passing the raw list to
    # ``rng.choice`` would re-convert it on every draw and leave the dtype
    # of this deliberately heterogeneous pool to NumPy's inference.
    tokens = np.array(tokens, dtype=object)
    rng = np.random.RandomState(seed=0)

    while True:
        # ``randint`` excludes the upper bound, so size is 1 or 2.
        size = rng.randint(1, 3)
        element = rng.choice(tokens, size)
        if size == 1:
            yield element[0]
        else:
            yield tuple(element)
def get_elements(n):
    """Return a list of ``n`` distinct tokens in first-seen order.

    Tokens are pulled from ``generate_random_token()`` until ``n`` distinct
    ones have been collected; duplicates are skipped.

    Parameters
    ----------
    n : int
        Number of distinct elements requested.

    Returns
    -------
    list
        The distinct tokens, in the order they were first produced.  Empty
        when ``n <= 0``.
    """
    if n <= 0:
        # The loop below inserts a token before it checks the size, so
        # without this guard a request for zero elements would return one.
        return []

    # OrderedDict is deterministic without difficulty of comparing numpy ints
    elements = collections.OrderedDict()
    for element in generate_random_token():
        if element not in elements:
            elements[element] = len(elements)
            if len(elements) >= n:
                break
    return list(elements.keys())
def test_init():
    """Constructing from a list gives one subset per element, in input order."""
    count = 10
    items = get_elements(count)
    dset = DisjointSet(items)

    assert dset.n_subsets == count
    assert list(dset) == items
def test_len():
    """``len`` reports the element count and grows by one after ``add``."""
    count = 10
    dset = DisjointSet(get_elements(count))
    assert len(dset) == count

    dset.add("dummy")
    assert len(dset) == count + 1
@pytest.mark.parametrize("n", [10, 100])
def test_contains(n):
    """Every inserted element is a member; an unknown key is not."""
    items = get_elements(n)
    dset = DisjointSet(items)

    for item in items:
        assert item in dset

    assert "dummy" not in dset
@pytest.mark.parametrize("n", [10, 100])
def test_add(n):
    """Adding elements one by one matches constructing from the full list."""
    items = get_elements(n)
    prefilled = DisjointSet(items)
    incremental = DisjointSet()

    for expected_len, item in enumerate(items, start=1):
        incremental.add(item)
        assert len(incremental) == expected_len

        # test idempotency by adding element again
        incremental.add(item)
        assert len(incremental) == expected_len

    assert list(prefilled) == list(incremental)
def test_element_not_present():
    """Lookup, ``merge`` and ``connected`` raise ``KeyError`` for an unknown key."""
    items = get_elements(n=10)
    dset = DisjointSet(items)
    known = items[0]

    with assert_raises(KeyError):
        dset["dummy"]

    with assert_raises(KeyError):
        dset.merge(known, "dummy")

    with assert_raises(KeyError):
        dset.connected(known, "dummy")
@pytest.mark.parametrize("direction", ["forwards", "backwards"])
@pytest.mark.parametrize("n", [10, 100])
def test_linear_union_sequence(n, direction):
    """Merge neighbouring elements in a chain and check the final root.

    Walking forwards, everything ends up rooted at the first element;
    walking backwards, at the second-to-last element.
    """
    items = get_elements(n)
    dset = DisjointSet(items)
    assert items == list(dset)

    order = range(n - 1)
    if direction == "backwards":
        order = reversed(order)

    for step, idx in enumerate(order):
        left, right = items[idx], items[idx + 1]
        assert not dset.connected(left, right)
        assert dset.merge(left, right)
        assert dset.connected(left, right)
        # Each successful merge removes exactly one subset.
        assert dset.n_subsets == n - 1 - step

    roots = [dset[item] for item in items]
    expected_root = items[0] if direction == "forwards" else items[-2]
    assert all(expected_root == r for r in roots)

    # Everything is already joined, so one more merge reports no change.
    assert not dset.merge(items[0], items[-1])
@pytest.mark.parametrize("n", [10, 100])
def test_self_unions(n):
    """Merging an element with itself is a no-op that leaves the set unchanged."""
    items = get_elements(n)
    dset = DisjointSet(items)

    for item in items:
        assert dset.connected(item, item)
        assert not dset.merge(item, item)
        assert dset.connected(item, item)
    assert dset.n_subsets == len(items)

    assert items == list(dset)
    # Every element must still be its own root.
    assert items == [dset[item] for item in items]
@pytest.mark.parametrize("order", ["ab", "ba"])
@pytest.mark.parametrize("n", [10, 100])
def test_equal_size_ordering(n, order):
    """Merging two singletons roots the pair at the lower-indexed element.

    The expected root is the same whichever argument order ``merge`` gets.
    """
    items = get_elements(n)
    dset = DisjointSet(items)

    rng = np.random.RandomState(seed=0)
    shuffled = np.arange(n)
    rng.shuffle(shuffled)

    # Consume the shuffled indices two at a time; each pair is disjoint
    # from all others, so both sides are still singletons when merged.
    for pos in range(0, len(shuffled), 2):
        idx_a, idx_b = shuffled[pos], shuffled[pos + 1]
        a, b = items[idx_a], items[idx_b]

        merged = dset.merge(a, b) if order == "ab" else dset.merge(b, a)
        assert merged

        expected = items[min(idx_a, idx_b)]
        assert dset[a] == expected
        assert dset[b] == expected
@pytest.mark.parametrize("kmax", [5, 10])
def test_binary_tree(kmax):
    """Merge ``2**kmax`` elements level by level, doubling the block size.

    After each level, every element's root must be the first element of
    its ``2 * k``-wide block.
    """
    n = 2**kmax
    items = get_elements(n)
    dset = DisjointSet(items)
    rng = np.random.RandomState(seed=0)
    positions = np.arange(n)

    for k in 2**np.arange(kmax):
        # Join each adjacent pair of k-wide blocks through one random
        # member from each block.
        for start in range(0, n, 2 * k):
            r1, r2 = rng.randint(0, k, size=2)
            a, b = items[start + r1], items[start + k + r2]
            assert not dset.connected(a, b)
            assert dset.merge(a, b)
            assert dset.connected(a, b)

        assert items == list(dset)
        roots = [dset[item] for item in items]
        # Index of the first element of the 2k-wide block each position is in.
        block_starts = positions - positions % (2 * k)
        expected = [items[i] for i in block_starts]
        assert roots == expected
@pytest.mark.parametrize("n", [10, 100])
def test_subsets(n):
    """Check ``subset`` and ``subsets`` against a brute-force grouping by root.

    The check runs before each of ``n`` random merges, and the two merged
    elements must report the same subset afterwards.
    """
    items = get_elements(n)
    dset = DisjointSet(items)

    rng = np.random.RandomState(seed=0)
    for i, j in rng.randint(0, n, (n, 2)):
        x, y = items[i], items[j]

        # Roots are compared through singleton sets rather than a bare
        # ``==`` -- presumably to sidestep elementwise comparison between
        # NumPy scalars and tuples; confirm before simplifying.
        root_of_x = {dset[x]}
        expected = {item for item in dset if {dset[item]} == root_of_x}
        assert expected == dset.subset(x)

        # Group all elements by root, in order of first appearance.
        groups = {}
        for item in dset:
            groups.setdefault(dset[item], set()).add(item)
        assert list(groups.values()) == dset.subsets()

        dset.merge(x, y)
        assert dset.subset(x) == dset.subset(y)