143 lines
4.7 KiB
Python
143 lines
4.7 KiB
Python
# -*- coding: utf-8 -*-
|
|
|
|
from pmdarima.arima.auto import StepwiseContext, auto_arima
|
|
from pmdarima.arima._context import ContextStore, ContextType
|
|
from pmdarima.arima import _context as context_lib
|
|
from pmdarima.datasets import load_lynx, load_wineind
|
|
|
|
from unittest import mock
|
|
import threading
|
|
import collections
|
|
import pytest
|
|
import warnings
|
|
|
|
# Sample datasets loaded once at import time and shared by the tests below.
# NOTE(review): ``wineind`` is not referenced in the visible part of this
# module -- confirm before removing.
lynx, wineind = load_lynx(), load_wineind()
|
|
|
|
|
|
# test StepwiseContext parameter validation
|
|
@pytest.mark.parametrize(
    'max_steps,max_dur', [
        # max_steps values expected to be rejected
        (-1, None),
        (0, None),
        (1001, None),
        (1100, None),
        # max_dur values expected to be rejected
        (None, -1),
        (None, 0),
    ])
def test_stepwise_context_args(max_steps, max_dur):
    """Each invalid argument combination must raise a ValueError."""
    with pytest.raises(ValueError):
        StepwiseContext(max_steps=max_steps, max_dur=max_dur)
|
|
|
|
|
|
# test auto_arima stepwise run with StepwiseContext
|
|
def test_auto_arima_with_stepwise_context():
    """A stepwise auto_arima run under a 3-step context warns at the cap."""
    expected_prefix = ('stepwise search has reached the '
                       'maximum number of tries')
    samp = lynx[:8]

    with StepwiseContext(max_steps=3, max_dur=30), \
            pytest.warns(UserWarning) as uw:
        auto_arima(samp, suppress_warnings=False, stepwise=True,
                   error_action='ignore')

        # assert that max_steps were taken
        messages = [str(w.message) for w in uw]
        assert any(msg.startswith(expected_prefix) for msg in messages)
|
|
|
|
|
|
# test effective context info in nested context scenario
|
|
def test_nested_context():
    """Nested StepwiseContexts expose the combined settings of both."""
    ctx1_data = {'max_dur': 30}
    ctx2_data = {'max_steps': 5}
    ctx1 = StepwiseContext(**ctx1_data)
    ctx2 = StepwiseContext(**ctx2_data)

    # ctx2's entries are layered on top of ctx1's in the expected mapping
    expected_ctx_data = dict(ctx1_data, **ctx2_data)

    with ctx1, ctx2:
        effective_ctx_data = ContextStore.get_or_empty(
            ContextType.STEPWISE)

        # every expected key is present with the expected value...
        for key in expected_ctx_data.keys():
            assert effective_ctx_data[key] == expected_ctx_data[key]

        # ...and every key of the effective context matches the expected one
        for key in effective_ctx_data.keys():
            assert effective_ctx_data[key] == expected_ctx_data[key]
|
|
|
|
|
|
# Test a context honors the max duration
|
|
def test_max_dur():
    """A very small ``max_dur`` must yield an 'early termination' warning."""
    # set arbitrarily low to guarantee will always pass after one iter
    with StepwiseContext(max_dur=.5):
        with pytest.warns(UserWarning) as uw:
            auto_arima(lynx, stepwise=True)

            # assert that max_dur was reached
            messages = [str(w.message) for w in uw]
            assert any(m.startswith('early termination') for m in messages)
|
|
|
|
|
|
# Test that a context after the first will not inherit the first's attrs
|
|
def test_subsequent_contexts():
    """A context opened after another has exited inherits nothing from it.

    First runs a fit under a ``max_dur`` context, then checks that the store
    reports an EMPTY context once that block has exited, and finally that a
    fresh ``max_steps``-only context has no ``max_dur`` and does not emit the
    'early termination' warning.
    """
    # Force a very fast fit
    with StepwiseContext(max_dur=.5), \
            pytest.warns(UserWarning):
        auto_arima(lynx, stepwise=True)

    # Out of scope, should be EMPTY
    ctx = ContextStore.get_or_empty(ContextType.STEPWISE)
    assert ctx.get_type() is ContextType.EMPTY

    # Now show that we DON'T hit early termination by time here
    with StepwiseContext(max_steps=100), \
            warnings.catch_warnings(record=True) as uw:
        # Record every warning regardless of the filters active when the
        # test runs; otherwise a filtered-out warning would never be recorded
        # and the negative assertion below would pass without checking
        # anything.
        warnings.simplefilter('always')

        ctx = ContextStore.get_or_empty(ContextType.STEPWISE)
        assert ctx.get_type() is ContextType.STEPWISE
        assert ctx.max_dur is None

        auto_arima(lynx, stepwise=True)

        # assert that max_dur was NOT reached (any() over an empty record is
        # already False, so no emptiness guard is needed)
        assert not any(str(w.message).startswith('early termination')
                       for w in uw)
|
|
|
|
|
|
# test param validation of ContextStore's add, get and remove members
|
|
def test_add_get_remove_context_args():
    """Each ContextStore accessor rejects ``None`` with a ValueError."""
    guarded_calls = (
        ContextStore._add_context,
        ContextStore._remove_context,
        ContextStore.get_context,
    )
    for call in guarded_calls:
        with pytest.raises(ValueError):
            call(None)
|
|
|
|
|
|
def test_context_store_accessible_across_threads():
    """Contexts pushed from one thread are visible to the threads after it.

    Five worker threads run one after another; each checks that the patched
    store already holds the contexts pushed by its predecessors, then pushes
    one more. An exception raised inside a ``threading.Thread`` target never
    reaches the thread that started it, so worker failures are collected and
    re-raised here -- otherwise the worker assertions could never fail the
    test.
    """
    # Make sure it's completely empty by patching it
    d = {}
    failures = []  # exceptions captured inside the worker threads

    with mock.patch('pmdarima.arima._context._ctx.store', d):

        # pushes onto the Context Store
        def push(n):
            # n is the number of times this has been executed before. If > 0,
            # assert there is a context there
            try:
                if n > 0:
                    assert len(
                        context_lib._ctx.store[ContextType.STEPWISE]) == n
                else:
                    context_lib._ctx.store[ContextType.STEPWISE] = \
                        collections.deque()

                new_ctx = StepwiseContext()
                context_lib._ctx.store[ContextType.STEPWISE].append(new_ctx)
                assert len(
                    context_lib._ctx.store[ContextType.STEPWISE]) == n + 1
            except Exception as exc:  # thread boundary: re-raised below
                failures.append(exc)

        for i in range(5):
            t = threading.Thread(target=push, args=(i,))
            t.start()
            t.join(1)  # it shouldn't take even close to this time

            # join() returns silently on timeout, so check explicitly
            assert not t.is_alive(), 'push(%d) did not finish in time' % i

            # surface the worker's failure in the thread pytest is watching
            if failures:
                raise failures[0]

    # Assert the mock has lifted
    assert context_lib._ctx.store is not d
|