reconnect moved files to git repo
This commit is contained in:
154
venv/lib/python3.11/site-packages/pmdarima/arima/_validation.py
Normal file
154
venv/lib/python3.11/site-packages/pmdarima/arima/_validation.py
Normal file
@ -0,0 +1,154 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Arg validation for auto-arima calls. This allows us to test validation more
|
||||
directly without having to fit numerous combinations of models.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import warnings
|
||||
from sklearn import metrics
|
||||
|
||||
from pmdarima.warnings import ModelFitWarning
|
||||
|
||||
# The valid information criteria
|
||||
VALID_CRITERIA = {'aic', 'aicc', 'bic', 'hqic', 'oob'}
|
||||
|
||||
|
||||
def auto_intercept(with_intercept, default):
|
||||
"""A more concise way to handle the default behavior of with_intercept"""
|
||||
if with_intercept == "auto":
|
||||
return default
|
||||
return with_intercept
|
||||
|
||||
|
||||
def check_information_criterion(information_criterion, out_of_sample_size):
|
||||
"""Check whether the information criterion is valid"""
|
||||
if information_criterion not in VALID_CRITERIA:
|
||||
raise ValueError('auto_arima not defined for information_criteria=%s. '
|
||||
'Valid information criteria include: %r'
|
||||
% (information_criterion, VALID_CRITERIA))
|
||||
|
||||
# check on information criterion and out_of_sample size
|
||||
if information_criterion == 'oob' and out_of_sample_size == 0:
|
||||
information_criterion = 'aic'
|
||||
warnings.warn('information_criterion cannot be \'oob\' with '
|
||||
'out_of_sample_size = 0. '
|
||||
'Falling back to information criterion = aic.')
|
||||
|
||||
return information_criterion
|
||||
|
||||
|
||||
def check_kwargs(kwargs):
|
||||
"""Return kwargs or an empty dict.
|
||||
|
||||
We often pass named kwargs (like `sarimax_kwargs`) as None by default. This
|
||||
is to avoid a mutable default, which can bite you in unexpected ways. This
|
||||
will return a kwarg-compatible value.
|
||||
"""
|
||||
if kwargs:
|
||||
return kwargs
|
||||
return {}
|
||||
|
||||
|
||||
def check_m(m, seasonal):
|
||||
"""Check the value of M (seasonal periodicity)"""
|
||||
if (m < 1 and seasonal) or m < 0:
|
||||
raise ValueError('m must be a positive integer (> 0)')
|
||||
|
||||
if not seasonal:
|
||||
# default m is 1, so if it's the default, don't warn
|
||||
if m > 1:
|
||||
warnings.warn("m (%i) set for non-seasonal fit. Setting to 0" % m)
|
||||
m = 0
|
||||
|
||||
return m
|
||||
|
||||
|
||||
def check_n_jobs(stepwise, n_jobs):
|
||||
"""Potentially update the n_jobs parameter
|
||||
|
||||
We can't run in parallel with the stepwise algorithm. This checks
|
||||
``n_jobs`` w.r.t. stepwise and will warn.
|
||||
"""
|
||||
if stepwise and n_jobs != 1:
|
||||
n_jobs = 1
|
||||
warnings.warn('stepwise model cannot be fit in parallel (n_jobs=%i). '
|
||||
'Falling back to stepwise parameter search.' % n_jobs)
|
||||
return n_jobs
|
||||
|
||||
|
||||
def check_start_max_values(st, mx, argname):
|
||||
"""Ensure starting points and ending points are valid"""
|
||||
if mx is None:
|
||||
mx = np.inf
|
||||
if st is None:
|
||||
raise ValueError("start_%s cannot be None" % argname)
|
||||
if st < 0:
|
||||
raise ValueError("start_%s must be positive" % argname)
|
||||
if mx < st:
|
||||
raise ValueError("max_%s must be >= start_%s" % (argname, argname))
|
||||
return st, mx
|
||||
|
||||
|
||||
def check_trace(trace):
|
||||
"""Check the value of trace"""
|
||||
if trace is None:
|
||||
return 0
|
||||
if isinstance(trace, (int, bool)):
|
||||
return int(trace)
|
||||
# otherwise just be truthy with it
|
||||
if trace:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
def get_scoring_metric(metric):
|
||||
"""Get a scoring metric by name, or passthrough a callable
|
||||
|
||||
Parameters
|
||||
----------
|
||||
metric : str or callable
|
||||
A name of a scoring metric, or a custom callable function. If it is a
|
||||
callable, it must adhere to the signature::
|
||||
|
||||
def func(y_true, y_pred)
|
||||
|
||||
Note that the ARIMA model selection seeks to MINIMIZE the score, and it
|
||||
is up to the user to ensure that scoring methods that return maximizing
|
||||
criteria (i.e., ``r2_score``) are wrapped in a function that will
|
||||
return the negative value of the score.
|
||||
"""
|
||||
if isinstance(metric, str):
|
||||
|
||||
# XXX: legacy support, remap mse/mae to their long versions
|
||||
if metric == "mse":
|
||||
return metrics.mean_squared_error
|
||||
if metric == "mae":
|
||||
return metrics.mean_absolute_error
|
||||
|
||||
try:
|
||||
return getattr(metrics, metric)
|
||||
except AttributeError:
|
||||
raise ValueError("'%s' is not a valid scoring method." % metric)
|
||||
|
||||
if not callable(metric):
|
||||
raise TypeError("`metric` must be a valid scoring method, or a "
|
||||
"callable, but got type=%s" % type(metric))
|
||||
|
||||
# TODO: warn for potentially invalid signature?
|
||||
return metric
|
||||
|
||||
|
||||
def warn_for_D(d, D):
|
||||
"""Warn for large values of D"""
|
||||
if D >= 2:
|
||||
warnings.warn("Having more than one seasonal differences is "
|
||||
"not recommended. Please consider using only one "
|
||||
"seasonal difference.", ModelFitWarning)
|
||||
# if D is -1, this will be off, so we include the OR
|
||||
# TODO: @FutureTayTay.. how can D be -1?
|
||||
elif D + d > 2 or d > 2:
|
||||
warnings.warn("Having 3 or more differencing operations is not "
|
||||
"recommended. Please consider reducing the total "
|
||||
"number of differences.", ModelFitWarning)
|
||||
Reference in New Issue
Block a user