155 lines
5.0 KiB
Python
155 lines
5.0 KiB
Python
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
|
Arg validation for auto-arima calls. This allows us to test validation more
|
|
directly without having to fit numerous combinations of models.
|
|
"""
|
|
|
|
import numpy as np
|
|
import warnings
|
|
from sklearn import metrics
|
|
|
|
from pmdarima.warnings import ModelFitWarning
|
|
|
|
# The valid information criteria
|
|
# Membership set consulted by ``check_information_criterion``. Note that
# 'oob' is accepted here, but that function swaps it for 'aic' (with a
# warning) when ``out_of_sample_size == 0``.
VALID_CRITERIA = {'aic', 'aicc', 'bic', 'hqic', 'oob'}
|
|
|
|
|
|
def auto_intercept(with_intercept, default):
    """Resolve ``with_intercept`` against its ``"auto"`` sentinel.

    Parameters
    ----------
    with_intercept : object
        The caller-supplied value. The string ``"auto"`` means "use
        ``default``"; anything else is passed through untouched.

    default : object
        The value to return when ``with_intercept == "auto"``.

    Returns
    -------
    object
        ``default`` if ``with_intercept`` equals ``"auto"``, otherwise
        ``with_intercept`` itself.
    """
    return default if with_intercept == "auto" else with_intercept
|
|
|
|
|
|
def check_information_criterion(information_criterion, out_of_sample_size):
    """Validate the information criterion, falling back to AIC if needed.

    Parameters
    ----------
    information_criterion : str
        The requested criterion. Must be a member of ``VALID_CRITERIA``.

    out_of_sample_size : int
        The number of held-out samples. ``'oob'`` cannot be used when this
        equals 0.

    Returns
    -------
    str
        ``information_criterion`` unchanged, or ``'aic'`` when ``'oob'`` was
        requested together with ``out_of_sample_size == 0`` (a warning is
        emitted in that case).

    Raises
    ------
    ValueError
        If ``information_criterion`` is not in ``VALID_CRITERIA``.
    """
    if information_criterion in VALID_CRITERIA:
        # 'oob' needs held-out samples to score against
        oob_without_holdout = (
            information_criterion == 'oob' and out_of_sample_size == 0
        )
        if not oob_without_holdout:
            return information_criterion

        warnings.warn("information_criterion cannot be 'oob' with "
                      "out_of_sample_size = 0. "
                      "Falling back to information criterion = aic.")
        return 'aic'

    raise ValueError('auto_arima not defined for information_criteria=%s. '
                     'Valid information criteria include: %r'
                     % (information_criterion, VALID_CRITERIA))
|
|
|
|
|
|
def check_kwargs(kwargs):
    """Coerce an optional kwargs argument into something ``**``-unpackable.

    Named kwargs arguments (like ``sarimax_kwargs``) default to None rather
    than ``{}`` so that no mutable default is shared between calls. This
    turns such a value back into a dict.

    Parameters
    ----------
    kwargs : dict or None
        The value to normalize.

    Returns
    -------
    dict
        ``kwargs`` itself when it is truthy; otherwise a new empty dict
        (this covers None as well as an empty mapping).
    """
    return kwargs or {}
|
|
|
|
|
|
def check_m(m, seasonal):
    """Validate the seasonal periodicity ``m``.

    Parameters
    ----------
    m : int
        The seasonal periodicity.

    seasonal : bool
        Whether a seasonal model is being fit.

    Returns
    -------
    int
        ``m`` unchanged for a seasonal fit, or 0 for a non-seasonal fit.

    Raises
    ------
    ValueError
        If ``m`` is negative, or if ``m < 1`` while ``seasonal`` is truthy.
    """
    if m < 0 or (seasonal and m < 1):
        raise ValueError('m must be a positive integer (> 0)')

    if seasonal:
        return m

    # Non-seasonal: m is always reset to 0. The default m is 1, so only
    # values above that are worth a warning.
    if m > 1:
        warnings.warn("m (%i) set for non-seasonal fit. Setting to 0" % m)
    return 0
|
|
|
|
|
|
def check_n_jobs(stepwise, n_jobs):
    """Potentially update the n_jobs parameter

    We can't run in parallel with the stepwise algorithm. This checks
    ``n_jobs`` w.r.t. stepwise and will warn.

    Parameters
    ----------
    stepwise : bool
        Whether the stepwise search is being used.

    n_jobs : int
        The requested number of parallel jobs.

    Returns
    -------
    int
        1 when ``stepwise`` is truthy and ``n_jobs != 1`` (a warning is
        emitted); otherwise ``n_jobs`` unchanged.
    """
    if stepwise and n_jobs != 1:
        # Format the message from the value the caller actually passed.
        # Previously n_jobs was overwritten first, so the warning always
        # reported "n_jobs=1". ``%s`` (rather than ``%i``) keeps non-int
        # values such as None from raising while formatting.
        warnings.warn('stepwise model cannot be fit in parallel (n_jobs=%s). '
                      'Falling back to stepwise parameter search.'
                      % (n_jobs,))
        return 1
    return n_jobs
|
|
|
|
|
|
def check_start_max_values(st, mx, argname):
    """Validate a (start, max) pair of search bounds.

    Parameters
    ----------
    st : int
        The starting value. Must not be None and must not be negative
        (0 is accepted).

    mx : int or None
        The maximum value. None is treated as unbounded (``np.inf``).

    argname : str
        The name of the order term, interpolated into the error messages
        as ``start_<argname>`` / ``max_<argname>``.

    Returns
    -------
    tuple
        ``(st, mx)``, with ``mx`` replaced by ``np.inf`` if it was None.

    Raises
    ------
    ValueError
        If ``st`` is None, ``st`` is negative, or ``mx`` is less than ``st``.
    """
    if st is None:
        raise ValueError("start_%s cannot be None" % argname)
    if st < 0:
        raise ValueError("start_%s must be positive" % argname)

    # a missing upper bound means "no upper bound"
    upper = np.inf if mx is None else mx
    if upper < st:
        raise ValueError("max_%s must be >= start_%s" % (argname, argname))

    return st, upper
|
|
|
|
|
|
def check_trace(trace):
    """Normalize the ``trace`` argument to an integer.

    Parameters
    ----------
    trace : object
        The caller-supplied trace value.

    Returns
    -------
    int
        0 for None; ``int(trace)`` for ints and bools; otherwise 1 if
        ``trace`` is truthy and 0 if it is falsy.
    """
    if trace is None:
        return 0

    # bool is a subclass of int, so this one check covers both
    if isinstance(trace, int):
        return int(trace)

    # anything else is only judged on its truthiness
    return 1 if trace else 0
|
|
|
|
|
|
def get_scoring_metric(metric):
    """Resolve a scoring metric from its name, or pass a callable through.

    Parameters
    ----------
    metric : str or callable
        Either the name of an attribute of ``sklearn.metrics`` (the legacy
        aliases ``"mse"`` and ``"mae"`` map to ``mean_squared_error`` and
        ``mean_absolute_error``), or a custom callable. If it is a
        callable, it must adhere to the signature::

            def func(y_true, y_pred)

        Note that the ARIMA model selection seeks to MINIMIZE the score, and
        it is up to the user to ensure that scoring methods that return
        maximizing criteria (i.e., ``r2_score``) are wrapped in a function
        that will return the negative value of the score.

    Returns
    -------
    callable
        The attribute looked up on ``sklearn.metrics`` for a string, or
        ``metric`` itself when it is already callable.

    Raises
    ------
    ValueError
        If ``metric`` is a string that is not an attribute of
        ``sklearn.metrics``.

    TypeError
        If ``metric`` is neither a string nor a callable.
    """
    if not isinstance(metric, str):
        if callable(metric):
            # TODO: warn for potentially invalid signature?
            return metric
        raise TypeError("`metric` must be a valid scoring method, or a "
                        "callable, but got type=%s" % type(metric))

    # XXX: legacy support, remap mse/mae to their long versions
    if metric == "mse":
        return metrics.mean_squared_error
    elif metric == "mae":
        return metrics.mean_absolute_error

    try:
        scorer = getattr(metrics, metric)
    except AttributeError:
        raise ValueError("'%s' is not a valid scoring method." % metric)
    return scorer
|
|
|
|
|
|
def warn_for_D(d, D):
    """Warn when the requested amount of differencing is large.

    Parameters
    ----------
    d : int
        The order of non-seasonal differencing.

    D : int
        The order of seasonal differencing.

    Notes
    -----
    Emits a ``ModelFitWarning`` when ``D >= 2``; otherwise emits one when
    ``D + d > 2`` or ``d > 2``. Nothing is returned.
    """
    if D >= 2:
        message = ("Having more than one seasonal differences is "
                   "not recommended. Please consider using only one "
                   "seasonal difference.")
    # if D is -1, the sum will be off, so ``d`` is also checked on its own
    # TODO: @FutureTayTay.. how can D be -1?
    elif D + d > 2 or d > 2:
        message = ("Having 3 or more differencing operations is not "
                   "recommended. Please consider reducing the total "
                   "number of differences.")
    else:
        return

    warnings.warn(message, ModelFitWarning)
|