reconnect moved files to git repo
This commit is contained in:
@ -0,0 +1,80 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from scipy.sparse.csgraph import connected_components
|
||||
|
||||
from sklearn.metrics.pairwise import pairwise_distances
|
||||
from sklearn.neighbors import kneighbors_graph
|
||||
from sklearn.utils.graph import _fix_connected_components
|
||||
|
||||
|
||||
def test_fix_connected_components():
|
||||
# Test that _fix_connected_components reduces the number of component to 1.
|
||||
X = np.array([0, 1, 2, 5, 6, 7])[:, None]
|
||||
graph = kneighbors_graph(X, n_neighbors=2, mode="distance")
|
||||
|
||||
n_connected_components, labels = connected_components(graph)
|
||||
assert n_connected_components > 1
|
||||
|
||||
graph = _fix_connected_components(X, graph, n_connected_components, labels)
|
||||
|
||||
n_connected_components, labels = connected_components(graph)
|
||||
assert n_connected_components == 1
|
||||
|
||||
|
||||
def test_fix_connected_components_precomputed():
|
||||
# Test that _fix_connected_components accepts precomputed distance matrix.
|
||||
X = np.array([0, 1, 2, 5, 6, 7])[:, None]
|
||||
graph = kneighbors_graph(X, n_neighbors=2, mode="distance")
|
||||
|
||||
n_connected_components, labels = connected_components(graph)
|
||||
assert n_connected_components > 1
|
||||
|
||||
distances = pairwise_distances(X)
|
||||
graph = _fix_connected_components(
|
||||
distances, graph, n_connected_components, labels, metric="precomputed"
|
||||
)
|
||||
|
||||
n_connected_components, labels = connected_components(graph)
|
||||
assert n_connected_components == 1
|
||||
|
||||
# but it does not work with precomputed neighbors graph
|
||||
with pytest.raises(RuntimeError, match="does not work with a sparse"):
|
||||
_fix_connected_components(
|
||||
graph, graph, n_connected_components, labels, metric="precomputed"
|
||||
)
|
||||
|
||||
|
||||
def test_fix_connected_components_wrong_mode():
|
||||
# Test that the an error is raised if the mode string is incorrect.
|
||||
X = np.array([0, 1, 2, 5, 6, 7])[:, None]
|
||||
graph = kneighbors_graph(X, n_neighbors=2, mode="distance")
|
||||
n_connected_components, labels = connected_components(graph)
|
||||
|
||||
with pytest.raises(ValueError, match="Unknown mode"):
|
||||
graph = _fix_connected_components(
|
||||
X, graph, n_connected_components, labels, mode="foo"
|
||||
)
|
||||
|
||||
|
||||
def test_fix_connected_components_connectivity_mode():
|
||||
# Test that the connectivity mode fill new connections with ones.
|
||||
X = np.array([0, 1, 6, 7])[:, None]
|
||||
graph = kneighbors_graph(X, n_neighbors=1, mode="connectivity")
|
||||
n_connected_components, labels = connected_components(graph)
|
||||
graph = _fix_connected_components(
|
||||
X, graph, n_connected_components, labels, mode="connectivity"
|
||||
)
|
||||
assert np.all(graph.data == 1)
|
||||
|
||||
|
||||
def test_fix_connected_components_distance_mode():
|
||||
# Test that the distance mode does not fill new connections with ones.
|
||||
X = np.array([0, 1, 6, 7])[:, None]
|
||||
graph = kneighbors_graph(X, n_neighbors=1, mode="distance")
|
||||
assert np.all(graph.data == 1)
|
||||
|
||||
n_connected_components, labels = connected_components(graph)
|
||||
graph = _fix_connected_components(
|
||||
X, graph, n_connected_components, labels, mode="distance"
|
||||
)
|
||||
assert not np.all(graph.data == 1)
|
||||
Reference in New Issue
Block a user