Files
2025-08-01 04:33:03 -04:00

155 lines
5.0 KiB
Python

# -*- coding: utf-8 -*-
"""
Arg validation for auto-arima calls. This allows us to test validation more
directly without having to fit numerous combinations of models.
"""
import numpy as np
import warnings
from sklearn import metrics
from pmdarima.warnings import ModelFitWarning
# Names accepted for ``information_criterion``. 'oob' is additionally only
# usable with a non-zero out-of-sample size (see check_information_criterion).
VALID_CRITERIA = set(("aic", "aicc", "bic", "hqic", "oob"))
def auto_intercept(with_intercept, default):
    """Resolve the ``"auto"`` sentinel for ``with_intercept``.

    Returns ``default`` when ``with_intercept`` is the string ``"auto"``;
    any other value is passed back unchanged.
    """
    return default if with_intercept == "auto" else with_intercept
def check_information_criterion(information_criterion, out_of_sample_size):
    """Validate ``information_criterion`` and resolve the 'oob' fallback.

    Parameters
    ----------
    information_criterion : str
        Requested criterion; must be a member of ``VALID_CRITERIA``.
    out_of_sample_size
        Size of the out-of-sample holdout; only compared against 0 here.

    Returns
    -------
    str
        The requested criterion, or ``'aic'`` when 'oob' was requested with
        ``out_of_sample_size == 0`` (a warning is emitted in that case).

    Raises
    ------
    ValueError
        If ``information_criterion`` is not in ``VALID_CRITERIA``.
    """
    if information_criterion not in VALID_CRITERIA:
        raise ValueError('auto_arima not defined for information_criteria=%s. '
                         'Valid information criteria include: %r'
                         % (information_criterion, VALID_CRITERIA))

    # 'oob' scoring needs held-out samples; with none, degrade to AIC.
    oob_without_holdout = (information_criterion == 'oob'
                           and out_of_sample_size == 0)
    if not oob_without_holdout:
        return information_criterion

    warnings.warn("information_criterion cannot be 'oob' with "
                  "out_of_sample_size = 0. "
                  "Falling back to information criterion = aic.")
    return 'aic'
def check_kwargs(kwargs):
    """Normalize an optional kwargs bundle.

    Keyword-argument bundles (``sarimax_kwargs`` and the like) default to
    ``None`` instead of ``{}`` to sidestep the shared-mutable-default pitfall.
    This helper passes a truthy value through untouched and replaces any falsy
    value (``None``, an empty dict, ...) with a fresh empty dict.
    """
    return kwargs or {}
def check_m(m, seasonal):
    """Validate the seasonal periodicity ``m`` and normalize it.

    Seasonal fits require ``m >= 1`` and get ``m`` back unchanged. Non-seasonal
    fits accept any non-negative ``m`` but always get 0 back, with a warning
    when ``m`` exceeds 1.

    Raises
    ------
    ValueError
        If ``m`` is negative, or if ``seasonal`` is truthy and ``m < 1``.
    """
    if m < 0 or (seasonal and m < 1):
        raise ValueError('m must be a positive integer (> 0)')

    if seasonal:
        return m

    # m == 1 is the default, so only an explicitly larger m is worth flagging.
    if m > 1:
        warnings.warn("m (%i) set for non-seasonal fit. Setting to 0" % m)
    return 0
def check_n_jobs(stepwise, n_jobs):
    """Potentially update the n_jobs parameter

    We can't run in parallel with the stepwise algorithm. This checks
    ``n_jobs`` w.r.t. stepwise and will warn.

    Parameters
    ----------
    stepwise : bool
        Whether the stepwise search will be used.
    n_jobs
        The requested degree of parallelism.

    Returns
    -------
    int or same type as ``n_jobs``
        1 if ``stepwise`` is truthy and ``n_jobs != 1``; otherwise ``n_jobs``
        unchanged.
    """
    if stepwise and n_jobs != 1:
        # Warn *before* overriding so the message reports the value the
        # caller actually asked for (formatting after the reassignment always
        # printed "n_jobs=1"). ``%s`` keeps non-int values such as None from
        # breaking the formatting.
        warnings.warn('stepwise model cannot be fit in parallel (n_jobs=%s). '
                      'Falling back to stepwise parameter search.'
                      % (n_jobs,))
        n_jobs = 1
    return n_jobs
def check_start_max_values(st, mx, argname):
    """Validate a ``start_<argname>`` / ``max_<argname>`` search bound pair.

    ``argname`` is only used to build the error messages. A ``None`` upper
    bound is treated as unbounded (``np.inf``).

    Returns
    -------
    tuple
        ``(st, upper)``, where ``upper`` is ``mx`` or ``np.inf`` if ``mx`` is
        None.

    Raises
    ------
    ValueError
        If ``st`` is None or negative, or the upper bound is below ``st``.
    """
    if st is None:
        raise ValueError("start_%s cannot be None" % argname)
    if st < 0:
        raise ValueError("start_%s must be positive" % argname)

    upper = np.inf if mx is None else mx
    if upper < st:
        raise ValueError("max_%s must be >= start_%s" % (argname, argname))
    return st, upper
def check_trace(trace):
    """Coerce ``trace`` into an integer level.

    Ints and bools are converted with ``int()``, so values above 1 are kept
    (presumably as higher verbosity). Anything else, including ``None``,
    becomes 1 or 0 according to its truthiness.
    """
    if isinstance(trace, (int, bool)):
        return int(trace)
    return 1 if trace else 0
def get_scoring_metric(metric):
    """Get a scoring metric by name, or passthrough a callable

    Parameters
    ----------
    metric : str or callable
        A name of a scoring metric, or a custom callable function. If it is a
        callable, it must adhere to the signature::

            def func(y_true, y_pred)

        Note that the ARIMA model selection seeks to MINIMIZE the score, and it
        is up to the user to ensure that scoring methods that return maximizing
        criteria (i.e., ``r2_score``) are wrapped in a function that will
        return the negative value of the score.

    Returns
    -------
    callable
        The resolved ``sklearn.metrics`` function, or ``metric`` itself when a
        callable was passed.

    Raises
    ------
    ValueError
        If ``metric`` is a string that does not name a callable attribute of
        ``sklearn.metrics``.
    TypeError
        If ``metric`` is neither a string nor a callable.
    """
    if isinstance(metric, str):
        # XXX: legacy support, remap mse/mae to their long versions
        if metric == "mse":
            return metrics.mean_squared_error
        if metric == "mae":
            return metrics.mean_absolute_error

        # ``getattr`` also resolves non-function attributes of the module
        # (submodules, dunders like ``__doc__``); those cannot score, so
        # reject them here instead of failing later at call time.
        scorer = getattr(metrics, metric, None)
        if not callable(scorer):
            raise ValueError("'%s' is not a valid scoring method." % metric)
        return scorer

    if not callable(metric):
        raise TypeError("`metric` must be a valid scoring method, or a "
                        "callable, but got type=%s" % type(metric))
    # TODO: warn for potentially invalid signature?
    return metric
def warn_for_D(d, D):
    """Warn about unusually large differencing orders.

    Parameters
    ----------
    d
        Non-seasonal differencing order.
    D
        Seasonal differencing order.

    Emits at most one ``ModelFitWarning``. The seasonal warning wins when
    ``D >= 2``. Otherwise the total-differencing warning fires when
    ``D + d > 2`` or ``d > 2``.
    """
    if D >= 2:
        message = ("Having more than one seasonal differences is "
                   "not recommended. Please consider using only one "
                   "seasonal difference.")
    elif D + d > 2 or d > 2:
        # The separate ``d > 2`` test keeps the warning if D is ever negative
        # (e.g. a -1 placeholder), where ``D + d`` would understate the total.
        # Whether callers can actually pass a negative D isn't visible here.
        message = ("Having 3 or more differencing operations is not "
                   "recommended. Please consider reducing the total "
                   "number of differences.")
    else:
        return
    warnings.warn(message, ModelFitWarning)