reconnect moved files to git repo
This commit is contained in:
@ -0,0 +1,184 @@
|
||||
"""Modified Olivetti faces dataset.
|
||||
|
||||
The original database was available from (now defunct)
|
||||
|
||||
https://www.cl.cam.ac.uk/research/dtg/attarchive/facedatabase.html
|
||||
|
||||
The version retrieved here comes in MATLAB format from the personal
|
||||
web page of Sam Roweis:
|
||||
|
||||
https://cs.nyu.edu/~roweis/
|
||||
"""
|
||||
|
||||
# Copyright (c) 2011 David Warde-Farley <wardefar at iro dot umontreal dot ca>
|
||||
# License: BSD 3 clause
|
||||
|
||||
from numbers import Integral, Real
|
||||
from os import PathLike, makedirs, remove
|
||||
from os.path import exists
|
||||
|
||||
import joblib
|
||||
import numpy as np
|
||||
from scipy.io import loadmat
|
||||
|
||||
from ..utils import Bunch, check_random_state
|
||||
from ..utils._param_validation import Interval, validate_params
|
||||
from . import get_data_home
|
||||
from ._base import RemoteFileMetadata, _fetch_remote, _pkl_filepath, load_descr
|
||||
|
||||
# Metadata describing the remote archive that ``fetch_olivetti_faces`` hands
# to ``_fetch_remote``.
#
# The original data can be found at:
# https://cs.nyu.edu/~roweis/data/olivettifaces.mat
FACES = RemoteFileMetadata(
    # Name of the raw MATLAB file.
    filename="olivettifaces.mat",
    # Location the file is downloaded from.
    url="https://ndownloader.figshare.com/files/5976027",
    # Expected checksum of the file (64 hex digits; presumably SHA-256 --
    # TODO confirm against ``_fetch_remote``).
    checksum="b612fb967f2dc77c9c62d3e1266e0c73d5fca46a4b8906c18e454d41af987794",
)
|
||||
|
||||
|
||||
@validate_params(
    {
        "data_home": [str, PathLike, None],
        "shuffle": ["boolean"],
        "random_state": ["random_state"],
        "download_if_missing": ["boolean"],
        "return_X_y": ["boolean"],
        "n_retries": [Interval(Integral, 1, None, closed="left")],
        "delay": [Interval(Real, 0.0, None, closed="neither")],
    },
    prefer_skip_nested_validation=True,
)
def fetch_olivetti_faces(
    *,
    data_home=None,
    shuffle=False,
    random_state=0,
    download_if_missing=True,
    return_X_y=False,
    n_retries=3,
    delay=1.0,
):
    """Load the Olivetti faces data-set from AT&T (classification).

    Download it if necessary.

    =================   =====================
    Classes                                40
    Samples total                         400
    Dimensionality                       4096
    Features            real, between 0 and 1
    =================   =====================

    Read more in the :ref:`User Guide <olivetti_faces_dataset>`.

    Parameters
    ----------
    data_home : str or path-like, default=None
        Specify another download and cache folder for the datasets. By default
        all scikit-learn data is stored in '~/scikit_learn_data' subfolders.

    shuffle : bool, default=False
        If True the order of the dataset is shuffled to avoid having
        images of the same person grouped.

    random_state : int, RandomState instance or None, default=0
        Determines random number generation for dataset shuffling. Pass an int
        for reproducible output across multiple function calls.
        See :term:`Glossary <random_state>`.

    download_if_missing : bool, default=True
        If False, raise an OSError if the data is not locally available
        instead of trying to download the data from the source site.

    return_X_y : bool, default=False
        If True, returns `(data, target)` instead of a `Bunch` object. See
        below for more information about the `data` and `target` object.

        .. versionadded:: 0.22

    n_retries : int, default=3
        Number of retries when HTTP errors are encountered.

        .. versionadded:: 1.5

    delay : float, default=1.0
        Number of seconds between retries.

        .. versionadded:: 1.5

    Returns
    -------
    data : :class:`~sklearn.utils.Bunch`
        Dictionary-like object, with the following attributes.

        data : ndarray, shape (400, 4096)
            Each row corresponds to a ravelled
            face image of original size 64 x 64 pixels.
        images : ndarray, shape (400, 64, 64)
            Each row is a face image
            corresponding to one of the 40 subjects of the dataset.
        target : ndarray, shape (400,)
            Labels associated to each face image.
            These labels range from 0 to 39 and correspond to the
            Subject IDs.
        DESCR : str
            Description of the modified Olivetti Faces Dataset.

    (data, target) : tuple if `return_X_y=True`
        Tuple with the `data` and `target` objects described above.

        .. versionadded:: 0.22

    Raises
    ------
    OSError
        If the cached data is not found and `download_if_missing` is False.

    Examples
    --------
    >>> from sklearn.datasets import fetch_olivetti_faces
    >>> olivetti_faces = fetch_olivetti_faces()
    >>> olivetti_faces.data.shape
    (400, 4096)
    >>> olivetti_faces.target.shape
    (400,)
    >>> olivetti_faces.images.shape
    (400, 64, 64)
    """
    data_home = get_data_home(data_home=data_home)
    # ``exist_ok=True`` avoids the check-then-create race of
    # ``if not exists(...): makedirs(...)`` when several callers start at once.
    makedirs(data_home, exist_ok=True)

    filepath = _pkl_filepath(data_home, "olivetti.pkz")
    if not exists(filepath):
        if not download_if_missing:
            raise OSError("Data not found and `download_if_missing` is False")

        print("downloading Olivetti faces from %s to %s" % (FACES.url, data_home))
        mat_path = _fetch_remote(
            FACES, dirname=data_home, n_retries=n_retries, delay=delay
        )
        try:
            mfile = loadmat(file_name=mat_path)
        finally:
            # delete raw .mat data, even when it could not be parsed, so a
            # failed load does not leave the download behind in data_home
            remove(mat_path)

        # Transposed copy of the "faces" entry; this is what gets cached
        # (presumably one sample per row after the transpose -- the reshape to
        # (400, 64, 64) below relies on 400 * 4096 values).
        faces = mfile["faces"].T.copy()
        joblib.dump(faces, filepath, compress=6)
        del mfile
    else:
        faces = joblib.load(filepath)

    # We want floating point data, but float32 is enough (there is only
    # one byte of precision in the original uint8s anyway)
    faces = np.float32(faces)
    # Rescale to the [0, 1] range.
    faces = faces - faces.min()
    faces /= faces.max()
    # Restore the 64 x 64 images and swap the two pixel axes
    # (NOTE(review): presumably compensates for MATLAB's column-major
    # layout -- confirm).
    faces = faces.reshape((400, 64, 64)).transpose(0, 2, 1)
    # 10 images per class, 400 images total, each class is contiguous.
    target = np.repeat(np.arange(40), 10)
    if shuffle:
        random_state = check_random_state(random_state)
        order = random_state.permutation(len(faces))
        # Apply the same permutation to images and labels to keep them aligned.
        faces = faces[order]
        target = target[order]
    faces_vectorized = faces.reshape(len(faces), -1)

    fdescr = load_descr("olivetti_faces.rst")

    if return_X_y:
        return faces_vectorized, target

    return Bunch(data=faces_vectorized, images=faces, target=target, DESCR=fdescr)
|
||||
Reference in New Issue
Block a user