some new features
This commit is contained in:
@ -0,0 +1,229 @@
|
||||
"""This test for the LFW require medium-size data downloading and processing
|
||||
|
||||
If the data has not been already downloaded by running the examples,
|
||||
the tests won't run (skipped).
|
||||
|
||||
If the test are run, the first execution will be long (typically a bit
|
||||
more than a couple of minutes) but as the dataset loader is leveraging
|
||||
joblib, successive runs will be fast (less than 200ms).
|
||||
"""
|
||||
|
||||
import random
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from sklearn.datasets import fetch_lfw_pairs, fetch_lfw_people
|
||||
from sklearn.datasets.tests.test_common import check_return_X_y
|
||||
from sklearn.utils._testing import assert_array_equal
|
||||
|
||||
FAKE_NAMES = [
|
||||
"Abdelatif_Smith",
|
||||
"Abhati_Kepler",
|
||||
"Camara_Alvaro",
|
||||
"Chen_Dupont",
|
||||
"John_Lee",
|
||||
"Lin_Bauman",
|
||||
"Onur_Lopez",
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def mock_empty_data_home(tmp_path_factory):
|
||||
data_dir = tmp_path_factory.mktemp("scikit_learn_empty_test")
|
||||
|
||||
yield data_dir
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def mock_data_home(tmp_path_factory):
|
||||
"""Test fixture run once and common to all tests of this module"""
|
||||
Image = pytest.importorskip("PIL.Image")
|
||||
|
||||
data_dir = tmp_path_factory.mktemp("scikit_learn_lfw_test")
|
||||
lfw_home = data_dir / "lfw_home"
|
||||
lfw_home.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
random_state = random.Random(42)
|
||||
np_rng = np.random.RandomState(42)
|
||||
|
||||
# generate some random jpeg files for each person
|
||||
counts = {}
|
||||
for name in FAKE_NAMES:
|
||||
folder_name = lfw_home / "lfw_funneled" / name
|
||||
folder_name.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
n_faces = np_rng.randint(1, 5)
|
||||
counts[name] = n_faces
|
||||
for i in range(n_faces):
|
||||
file_path = folder_name / (name + "_%04d.jpg" % i)
|
||||
uniface = np_rng.randint(0, 255, size=(250, 250, 3))
|
||||
img = Image.fromarray(uniface.astype(np.uint8))
|
||||
img.save(file_path)
|
||||
|
||||
# add some random file pollution to test robustness
|
||||
(lfw_home / "lfw_funneled" / ".test.swp").write_bytes(
|
||||
b"Text file to be ignored by the dataset loader."
|
||||
)
|
||||
|
||||
# generate some pairing metadata files using the same format as LFW
|
||||
with open(lfw_home / "pairsDevTrain.txt", "wb") as f:
|
||||
f.write(b"10\n")
|
||||
more_than_two = [name for name, count in counts.items() if count >= 2]
|
||||
for i in range(5):
|
||||
name = random_state.choice(more_than_two)
|
||||
first, second = random_state.sample(range(counts[name]), 2)
|
||||
f.write(("%s\t%d\t%d\n" % (name, first, second)).encode())
|
||||
|
||||
for i in range(5):
|
||||
first_name, second_name = random_state.sample(FAKE_NAMES, 2)
|
||||
first_index = np_rng.choice(np.arange(counts[first_name]))
|
||||
second_index = np_rng.choice(np.arange(counts[second_name]))
|
||||
f.write(
|
||||
(
|
||||
"%s\t%d\t%s\t%d\n"
|
||||
% (first_name, first_index, second_name, second_index)
|
||||
).encode()
|
||||
)
|
||||
|
||||
(lfw_home / "pairsDevTest.txt").write_bytes(
|
||||
b"Fake place holder that won't be tested"
|
||||
)
|
||||
(lfw_home / "pairs.txt").write_bytes(b"Fake place holder that won't be tested")
|
||||
|
||||
yield data_dir
|
||||
|
||||
|
||||
def test_load_empty_lfw_people(mock_empty_data_home):
|
||||
with pytest.raises(OSError):
|
||||
fetch_lfw_people(data_home=mock_empty_data_home, download_if_missing=False)
|
||||
|
||||
|
||||
def test_load_fake_lfw_people(mock_data_home):
|
||||
lfw_people = fetch_lfw_people(
|
||||
data_home=mock_data_home, min_faces_per_person=3, download_if_missing=False
|
||||
)
|
||||
|
||||
# The data is croped around the center as a rectangular bounding box
|
||||
# around the face. Colors are converted to gray levels:
|
||||
assert lfw_people.images.shape == (10, 62, 47)
|
||||
assert lfw_people.data.shape == (10, 2914)
|
||||
|
||||
# the target is array of person integer ids
|
||||
assert_array_equal(lfw_people.target, [2, 0, 1, 0, 2, 0, 2, 1, 1, 2])
|
||||
|
||||
# names of the persons can be found using the target_names array
|
||||
expected_classes = ["Abdelatif Smith", "Abhati Kepler", "Onur Lopez"]
|
||||
assert_array_equal(lfw_people.target_names, expected_classes)
|
||||
|
||||
# It is possible to ask for the original data without any croping or color
|
||||
# conversion and not limit on the number of picture per person
|
||||
lfw_people = fetch_lfw_people(
|
||||
data_home=mock_data_home,
|
||||
resize=None,
|
||||
slice_=None,
|
||||
color=True,
|
||||
download_if_missing=False,
|
||||
)
|
||||
assert lfw_people.images.shape == (17, 250, 250, 3)
|
||||
assert lfw_people.DESCR.startswith(".. _labeled_faces_in_the_wild_dataset:")
|
||||
|
||||
# the ids and class names are the same as previously
|
||||
assert_array_equal(
|
||||
lfw_people.target, [0, 0, 1, 6, 5, 6, 3, 6, 0, 3, 6, 1, 2, 4, 5, 1, 2]
|
||||
)
|
||||
assert_array_equal(
|
||||
lfw_people.target_names,
|
||||
[
|
||||
"Abdelatif Smith",
|
||||
"Abhati Kepler",
|
||||
"Camara Alvaro",
|
||||
"Chen Dupont",
|
||||
"John Lee",
|
||||
"Lin Bauman",
|
||||
"Onur Lopez",
|
||||
],
|
||||
)
|
||||
|
||||
# test return_X_y option
|
||||
fetch_func = partial(
|
||||
fetch_lfw_people,
|
||||
data_home=mock_data_home,
|
||||
resize=None,
|
||||
slice_=None,
|
||||
color=True,
|
||||
download_if_missing=False,
|
||||
)
|
||||
check_return_X_y(lfw_people, fetch_func)
|
||||
|
||||
|
||||
def test_load_fake_lfw_people_too_restrictive(mock_data_home):
|
||||
with pytest.raises(ValueError):
|
||||
fetch_lfw_people(
|
||||
data_home=mock_data_home,
|
||||
min_faces_per_person=100,
|
||||
download_if_missing=False,
|
||||
)
|
||||
|
||||
|
||||
def test_load_empty_lfw_pairs(mock_empty_data_home):
|
||||
with pytest.raises(OSError):
|
||||
fetch_lfw_pairs(data_home=mock_empty_data_home, download_if_missing=False)
|
||||
|
||||
|
||||
def test_load_fake_lfw_pairs(mock_data_home):
|
||||
lfw_pairs_train = fetch_lfw_pairs(
|
||||
data_home=mock_data_home, download_if_missing=False
|
||||
)
|
||||
|
||||
# The data is croped around the center as a rectangular bounding box
|
||||
# around the face. Colors are converted to gray levels:
|
||||
assert lfw_pairs_train.pairs.shape == (10, 2, 62, 47)
|
||||
|
||||
# the target is whether the person is the same or not
|
||||
assert_array_equal(lfw_pairs_train.target, [1, 1, 1, 1, 1, 0, 0, 0, 0, 0])
|
||||
|
||||
# names of the persons can be found using the target_names array
|
||||
expected_classes = ["Different persons", "Same person"]
|
||||
assert_array_equal(lfw_pairs_train.target_names, expected_classes)
|
||||
|
||||
# It is possible to ask for the original data without any croping or color
|
||||
# conversion
|
||||
lfw_pairs_train = fetch_lfw_pairs(
|
||||
data_home=mock_data_home,
|
||||
resize=None,
|
||||
slice_=None,
|
||||
color=True,
|
||||
download_if_missing=False,
|
||||
)
|
||||
assert lfw_pairs_train.pairs.shape == (10, 2, 250, 250, 3)
|
||||
|
||||
# the ids and class names are the same as previously
|
||||
assert_array_equal(lfw_pairs_train.target, [1, 1, 1, 1, 1, 0, 0, 0, 0, 0])
|
||||
assert_array_equal(lfw_pairs_train.target_names, expected_classes)
|
||||
|
||||
assert lfw_pairs_train.DESCR.startswith(".. _labeled_faces_in_the_wild_dataset:")
|
||||
|
||||
|
||||
def test_fetch_lfw_people_internal_cropping(mock_data_home):
|
||||
"""Check that we properly crop the images.
|
||||
|
||||
Non-regression test for:
|
||||
https://github.com/scikit-learn/scikit-learn/issues/24942
|
||||
"""
|
||||
# If cropping was not done properly and we don't resize the images, the images would
|
||||
# have their original size (250x250) and the image would not fit in the NumPy array
|
||||
# pre-allocated based on `slice_` parameter.
|
||||
slice_ = (slice(70, 195), slice(78, 172))
|
||||
lfw = fetch_lfw_people(
|
||||
data_home=mock_data_home,
|
||||
min_faces_per_person=3,
|
||||
download_if_missing=False,
|
||||
resize=None,
|
||||
slice_=slice_,
|
||||
)
|
||||
assert lfw.images[0].shape == (
|
||||
slice_[0].stop - slice_[0].start,
|
||||
slice_[1].stop - slice_[1].start,
|
||||
)
|
||||
Reference in New Issue
Block a user