some new features
This commit is contained in:
@ -0,0 +1,187 @@
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from sklearn.base import clone
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.preprocessing import (
|
||||
MaxAbsScaler,
|
||||
MinMaxScaler,
|
||||
PowerTransformer,
|
||||
QuantileTransformer,
|
||||
RobustScaler,
|
||||
StandardScaler,
|
||||
maxabs_scale,
|
||||
minmax_scale,
|
||||
power_transform,
|
||||
quantile_transform,
|
||||
robust_scale,
|
||||
scale,
|
||||
)
|
||||
from sklearn.utils._testing import assert_allclose, assert_array_equal
|
||||
from sklearn.utils.fixes import (
|
||||
BSR_CONTAINERS,
|
||||
COO_CONTAINERS,
|
||||
CSC_CONTAINERS,
|
||||
CSR_CONTAINERS,
|
||||
DIA_CONTAINERS,
|
||||
DOK_CONTAINERS,
|
||||
LIL_CONTAINERS,
|
||||
)
|
||||
|
||||
iris = load_iris()
|
||||
|
||||
|
||||
def _get_valid_samples_by_column(X, col):
|
||||
"""Get non NaN samples in column of X"""
|
||||
return X[:, [col]][~np.isnan(X[:, col])]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"est, func, support_sparse, strictly_positive, omit_kwargs",
|
||||
[
|
||||
(MaxAbsScaler(), maxabs_scale, True, False, []),
|
||||
(MinMaxScaler(), minmax_scale, False, False, ["clip"]),
|
||||
(StandardScaler(), scale, False, False, []),
|
||||
(StandardScaler(with_mean=False), scale, True, False, []),
|
||||
(PowerTransformer("yeo-johnson"), power_transform, False, False, []),
|
||||
(PowerTransformer("box-cox"), power_transform, False, True, []),
|
||||
(QuantileTransformer(n_quantiles=10), quantile_transform, True, False, []),
|
||||
(RobustScaler(), robust_scale, False, False, []),
|
||||
(RobustScaler(with_centering=False), robust_scale, True, False, []),
|
||||
],
|
||||
)
|
||||
def test_missing_value_handling(
|
||||
est, func, support_sparse, strictly_positive, omit_kwargs
|
||||
):
|
||||
# check that the preprocessing method let pass nan
|
||||
rng = np.random.RandomState(42)
|
||||
X = iris.data.copy()
|
||||
n_missing = 50
|
||||
X[
|
||||
rng.randint(X.shape[0], size=n_missing), rng.randint(X.shape[1], size=n_missing)
|
||||
] = np.nan
|
||||
if strictly_positive:
|
||||
X += np.nanmin(X) + 0.1
|
||||
X_train, X_test = train_test_split(X, random_state=1)
|
||||
# sanity check
|
||||
assert not np.all(np.isnan(X_train), axis=0).any()
|
||||
assert np.any(np.isnan(X_train), axis=0).all()
|
||||
assert np.any(np.isnan(X_test), axis=0).all()
|
||||
X_test[:, 0] = np.nan # make sure this boundary case is tested
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("error", RuntimeWarning)
|
||||
Xt = est.fit(X_train).transform(X_test)
|
||||
# ensure no warnings are raised
|
||||
# missing values should still be missing, and only them
|
||||
assert_array_equal(np.isnan(Xt), np.isnan(X_test))
|
||||
|
||||
# check that the function leads to the same results as the class
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("error", RuntimeWarning)
|
||||
Xt_class = est.transform(X_train)
|
||||
kwargs = est.get_params()
|
||||
# remove the parameters which should be omitted because they
|
||||
# are not defined in the counterpart function of the preprocessing class
|
||||
for kwarg in omit_kwargs:
|
||||
_ = kwargs.pop(kwarg)
|
||||
Xt_func = func(X_train, **kwargs)
|
||||
assert_array_equal(np.isnan(Xt_func), np.isnan(Xt_class))
|
||||
assert_allclose(Xt_func[~np.isnan(Xt_func)], Xt_class[~np.isnan(Xt_class)])
|
||||
|
||||
# check that the inverse transform keep NaN
|
||||
Xt_inv = est.inverse_transform(Xt)
|
||||
assert_array_equal(np.isnan(Xt_inv), np.isnan(X_test))
|
||||
# FIXME: we can introduce equal_nan=True in recent version of numpy.
|
||||
# For the moment which just check that non-NaN values are almost equal.
|
||||
assert_allclose(Xt_inv[~np.isnan(Xt_inv)], X_test[~np.isnan(X_test)])
|
||||
|
||||
for i in range(X.shape[1]):
|
||||
# train only on non-NaN
|
||||
est.fit(_get_valid_samples_by_column(X_train, i))
|
||||
# check transforming with NaN works even when training without NaN
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("error", RuntimeWarning)
|
||||
Xt_col = est.transform(X_test[:, [i]])
|
||||
assert_allclose(Xt_col, Xt[:, [i]])
|
||||
# check non-NaN is handled as before - the 1st column is all nan
|
||||
if not np.isnan(X_test[:, i]).all():
|
||||
Xt_col_nonan = est.transform(_get_valid_samples_by_column(X_test, i))
|
||||
assert_array_equal(Xt_col_nonan, Xt_col[~np.isnan(Xt_col.squeeze())])
|
||||
|
||||
if support_sparse:
|
||||
est_dense = clone(est)
|
||||
est_sparse = clone(est)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("error", RuntimeWarning)
|
||||
Xt_dense = est_dense.fit(X_train).transform(X_test)
|
||||
Xt_inv_dense = est_dense.inverse_transform(Xt_dense)
|
||||
|
||||
for sparse_container in (
|
||||
BSR_CONTAINERS
|
||||
+ COO_CONTAINERS
|
||||
+ CSC_CONTAINERS
|
||||
+ CSR_CONTAINERS
|
||||
+ DIA_CONTAINERS
|
||||
+ DOK_CONTAINERS
|
||||
+ LIL_CONTAINERS
|
||||
):
|
||||
# check that the dense and sparse inputs lead to the same results
|
||||
# precompute the matrix to avoid catching side warnings
|
||||
X_train_sp = sparse_container(X_train)
|
||||
X_test_sp = sparse_container(X_test)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", PendingDeprecationWarning)
|
||||
warnings.simplefilter("error", RuntimeWarning)
|
||||
Xt_sp = est_sparse.fit(X_train_sp).transform(X_test_sp)
|
||||
|
||||
assert_allclose(Xt_sp.toarray(), Xt_dense)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", PendingDeprecationWarning)
|
||||
warnings.simplefilter("error", RuntimeWarning)
|
||||
Xt_inv_sp = est_sparse.inverse_transform(Xt_sp)
|
||||
|
||||
assert_allclose(Xt_inv_sp.toarray(), Xt_inv_dense)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"est, func",
|
||||
[
|
||||
(MaxAbsScaler(), maxabs_scale),
|
||||
(MinMaxScaler(), minmax_scale),
|
||||
(StandardScaler(), scale),
|
||||
(StandardScaler(with_mean=False), scale),
|
||||
(PowerTransformer("yeo-johnson"), power_transform),
|
||||
(
|
||||
PowerTransformer("box-cox"),
|
||||
power_transform,
|
||||
),
|
||||
(QuantileTransformer(n_quantiles=3), quantile_transform),
|
||||
(RobustScaler(), robust_scale),
|
||||
(RobustScaler(with_centering=False), robust_scale),
|
||||
],
|
||||
)
|
||||
def test_missing_value_pandas_na_support(est, func):
|
||||
# Test pandas IntegerArray with pd.NA
|
||||
pd = pytest.importorskip("pandas")
|
||||
|
||||
X = np.array(
|
||||
[
|
||||
[1, 2, 3, np.nan, np.nan, 4, 5, 1],
|
||||
[np.nan, np.nan, 8, 4, 6, np.nan, np.nan, 8],
|
||||
[1, 2, 3, 4, 5, 6, 7, 8],
|
||||
]
|
||||
).T
|
||||
|
||||
# Creates dataframe with IntegerArrays with pd.NA
|
||||
X_df = pd.DataFrame(X, dtype="Int16", columns=["a", "b", "c"])
|
||||
X_df["c"] = X_df["c"].astype("int")
|
||||
|
||||
X_trans = est.fit_transform(X)
|
||||
X_df_trans = est.fit_transform(X_df)
|
||||
|
||||
assert_allclose(X_trans, X_df_trans)
|
||||
Reference in New Issue
Block a user