Files
2025-08-01 04:33:03 -04:00

35 lines
1.0 KiB
Python

# -*- coding: utf-8 -*-
import numpy as np
import pytest
from pmdarima.arima import utils as arima_utils
from pmdarima.compat.pytest import pytest_warning_messages, pytest_error_str
def test_issue_341():
seas_diffed = np.array([124., -114., -163., -83.])
with pytest.raises(ValueError) as ve:
arima_utils.ndiffs(seas_diffed, test='adf')
assert "raised from LinAlgError" in pytest_error_str(ve)
def test_issue_351():
y = np.array([
1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 1, 6, 2, 1, 0,
2, 0, 1, 0, 0, 3, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 3, 0, 0, 6,
0, 0, 0, 0, 0, 1, 3, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0
])
with pytest.warns(UserWarning) as w_list:
D = arima_utils.nsdiffs(y, m=52, max_D=2, test='ocsb')
assert D == 1
warnings_messages = pytest_warning_messages(w_list)
assert len(warnings_messages) == 1
assert 'shorter than m' in warnings_messages[0]