some new features
This commit is contained in:
@ -0,0 +1,142 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from pmdarima.arima.auto import StepwiseContext, auto_arima
|
||||
from pmdarima.arima._context import ContextStore, ContextType
|
||||
from pmdarima.arima import _context as context_lib
|
||||
from pmdarima.datasets import load_lynx, load_wineind
|
||||
|
||||
from unittest import mock
|
||||
import threading
|
||||
import collections
|
||||
import pytest
|
||||
import warnings
|
||||
|
||||
lynx = load_lynx()
|
||||
wineind = load_wineind()
|
||||
|
||||
|
||||
# test StepwiseContext parameter validation
|
||||
@pytest.mark.parametrize(
|
||||
'max_steps,max_dur', [
|
||||
pytest.param(-1, None),
|
||||
pytest.param(0, None),
|
||||
pytest.param(1001, None),
|
||||
pytest.param(1100, None),
|
||||
pytest.param(None, -1),
|
||||
pytest.param(None, 0),
|
||||
])
|
||||
def test_stepwise_context_args(max_steps, max_dur):
|
||||
with pytest.raises(ValueError):
|
||||
StepwiseContext(max_steps=max_steps, max_dur=max_dur)
|
||||
|
||||
|
||||
# test auto_arima stepwise run with StepwiseContext
|
||||
def test_auto_arima_with_stepwise_context():
|
||||
samp = lynx[:8]
|
||||
with StepwiseContext(max_steps=3, max_dur=30):
|
||||
with pytest.warns(UserWarning) as uw:
|
||||
auto_arima(samp, suppress_warnings=False, stepwise=True,
|
||||
error_action='ignore')
|
||||
|
||||
# assert that max_steps were taken
|
||||
assert any(str(w.message)
|
||||
.startswith('stepwise search has reached the '
|
||||
'maximum number of tries') for w in uw)
|
||||
|
||||
|
||||
# test effective context info in nested context scenario
|
||||
def test_nested_context():
|
||||
ctx1_data = {'max_dur': 30}
|
||||
ctx2_data = {'max_steps': 5}
|
||||
ctx1 = StepwiseContext(**ctx1_data)
|
||||
ctx2 = StepwiseContext(**ctx2_data)
|
||||
|
||||
with ctx1, ctx2:
|
||||
effective_ctx_data = ContextStore.get_or_empty(
|
||||
ContextType.STEPWISE)
|
||||
expected_ctx_data = ctx1_data.copy()
|
||||
expected_ctx_data.update(ctx2_data)
|
||||
|
||||
assert all(effective_ctx_data[key] == expected_ctx_data[key]
|
||||
for key in expected_ctx_data.keys())
|
||||
|
||||
assert all(effective_ctx_data[key] == expected_ctx_data[key]
|
||||
for key in effective_ctx_data.keys())
|
||||
|
||||
|
||||
# Test a context honors the max duration
|
||||
def test_max_dur():
|
||||
# set arbitrarily low to guarantee will always pass after one iter
|
||||
with StepwiseContext(max_dur=.5), \
|
||||
pytest.warns(UserWarning) as uw:
|
||||
|
||||
auto_arima(lynx, stepwise=True)
|
||||
# assert that max_dur was reached
|
||||
assert any(str(w.message)
|
||||
.startswith('early termination') for w in uw)
|
||||
|
||||
|
||||
# Test that a context after the first will not inherit the first's attrs
|
||||
def test_subsequent_contexts():
|
||||
# Force a very fast fit
|
||||
with StepwiseContext(max_dur=.5), \
|
||||
pytest.warns(UserWarning):
|
||||
auto_arima(lynx, stepwise=True)
|
||||
|
||||
# Out of scope, should be EMPTY
|
||||
ctx = ContextStore.get_or_empty(ContextType.STEPWISE)
|
||||
assert ctx.get_type() is ContextType.EMPTY
|
||||
|
||||
# Now show that we DON'T hit early termination by time here
|
||||
with StepwiseContext(max_steps=100), \
|
||||
warnings.catch_warnings(record=True) as uw:
|
||||
|
||||
ctx = ContextStore.get_or_empty(ContextType.STEPWISE)
|
||||
assert ctx.get_type() is ContextType.STEPWISE
|
||||
assert ctx.max_dur is None
|
||||
|
||||
auto_arima(lynx, stepwise=True)
|
||||
# assert that max_dur was NOT reached
|
||||
if uw:
|
||||
assert not any(str(w.message)
|
||||
.startswith('early termination') for w in uw)
|
||||
|
||||
|
||||
# test param validation of ContextStore's add, get and remove members
|
||||
def test_add_get_remove_context_args():
|
||||
with pytest.raises(ValueError):
|
||||
ContextStore._add_context(None)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
ContextStore._remove_context(None)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
ContextStore.get_context(None)
|
||||
|
||||
|
||||
def test_context_store_accessible_across_threads():
|
||||
# Make sure it's completely empty by patching it
|
||||
d = {}
|
||||
with mock.patch('pmdarima.arima._context._ctx.store', d):
|
||||
|
||||
# pushes onto the Context Store
|
||||
def push(n):
|
||||
# n is the number of times this has been executed before. If > 0,
|
||||
# assert there is a context there
|
||||
if n > 0:
|
||||
assert len(context_lib._ctx.store[ContextType.STEPWISE]) == n
|
||||
else:
|
||||
context_lib._ctx.store[ContextType.STEPWISE] = \
|
||||
collections.deque()
|
||||
|
||||
new_ctx = StepwiseContext()
|
||||
context_lib._ctx.store[ContextType.STEPWISE].append(new_ctx)
|
||||
assert len(context_lib._ctx.store[ContextType.STEPWISE]) == n + 1
|
||||
|
||||
for i in range(5):
|
||||
t = threading.Thread(target=push, args=(i,))
|
||||
t.start()
|
||||
t.join(1) # it shouldn't take even close to this time
|
||||
|
||||
# Assert the mock has lifted
|
||||
assert context_lib._ctx.store is not d
|
||||
Reference in New Issue
Block a user