Files
Time-Series-Analysis/venv/lib/python3.11/site-packages/pmdarima/compat/sklearn.py
2025-08-01 04:33:03 -04:00

99 lines
2.7 KiB
Python

# -*- coding: utf-8 -*-
#
# Author: Charles Drotar <drotarcharles@gmail.com>
#
# Patch backend for sklearn
from packaging.version import Version
import sklearn
from sklearn.exceptions import NotFittedError
__all__ = [
'check_is_fitted',
'if_delegate_has_method',
'safe_indexing',
]
def check_is_fitted(estimator, attributes):
"""Ensure the model has been fitted
Typically called at the beginning of an operation on a model that requires
having been fit. Raises a ``NotFittedError`` if the model has not been
fit.
This is an adaptation of scikit-learn's ``check_is_fitted``, which has been
changed recently in a way that is no longer compatible with our package.
Parameters
----------
estimator : estimator instance,
The estimator that will be checked to see if it is fitted.
attributes : str or iterable
The attributes to check for
"""
if isinstance(attributes, str):
attributes = [attributes]
if not hasattr(attributes, "__iter__"):
raise TypeError("attributes must be a string or iterable")
for attr in attributes:
if hasattr(estimator, attr):
return
raise NotFittedError("Model has not been fit!")
def safe_indexing(X, indices):
"""Slice an array or dataframe. This is deprecated in sklearn"""
if hasattr(X, 'iloc'):
return X.iloc[indices]
# numpy:
# TODO: this does not currently support axis 1
if hasattr(X, 'ndim') and X.ndim == 2:
return X[indices, :]
# list or 1d array
return X[indices]
def _estimator_has(attr):
"""Checks if the model has a given attribute.
Meant to be used along with `sklearn.utils.metaestimators.available_if`
Parameters
----------
attr : str
The attribute to check the calling object for
Returns
-------
fn : callable
A function that will either raise an `AttributeError` if the attribute
does not exist, or True if it does.
"""
def check(self):
# raise original `AttributeError` if `attr` does not exist
getattr(self, attr)
return True
return check
def if_delegate_has_method(attr):
"""Compat method to replace `sklearn.utils.metaestimators.if_delegate_has`
Older versions (< 1.0.0) of sklearn support it, but newer versions use
`available_if` instead.
References
----------
.. [1] https://git.io/JzKiv
.. [2] https://git.io/JzKiJ
"""
if Version(sklearn.__version__) < Version("1.0.0"):
from sklearn.utils.metaestimators import if_delegate_has_method
return if_delegate_has_method(attr)
else:
from sklearn.utils.metaestimators import available_if
return available_if(_estimator_has(attr))