35 lines
1.0 KiB
Python
35 lines
1.0 KiB
Python
# -*- coding: utf-8 -*-
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from pmdarima.arima import utils as arima_utils
|
|
from pmdarima.compat.pytest import pytest_warning_messages, pytest_error_str
|
|
|
|
|
|
def test_issue_341():
|
|
seas_diffed = np.array([124., -114., -163., -83.])
|
|
|
|
with pytest.raises(ValueError) as ve:
|
|
arima_utils.ndiffs(seas_diffed, test='adf')
|
|
|
|
assert "raised from LinAlgError" in pytest_error_str(ve)
|
|
|
|
|
|
def test_issue_351():
|
|
y = np.array([
|
|
1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 1, 6, 2, 1, 0,
|
|
2, 0, 1, 0, 0, 3, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 3, 0, 0, 6,
|
|
0, 0, 0, 0, 0, 1, 3, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0
|
|
])
|
|
|
|
with pytest.warns(UserWarning) as w_list:
|
|
D = arima_utils.nsdiffs(y, m=52, max_D=2, test='ocsb')
|
|
|
|
assert D == 1
|
|
|
|
warnings_messages = pytest_warning_messages(w_list)
|
|
assert len(warnings_messages) == 1
|
|
assert 'shorter than m' in warnings_messages[0]
|