some new features

ilgazca
2025-07-30 17:09:11 +03:00
parent db5d46760a
commit 8019bd3b7c
20616 changed files with 4375466 additions and 8 deletions

View File

@@ -0,0 +1,101 @@
""" Test for assert_deallocated context manager and gc utilities
"""
import gc
from scipy._lib._gcutils import (set_gc_state, gc_state, assert_deallocated,
ReferenceError, IS_PYPY)
from numpy.testing import assert_equal
import pytest
def test_set_gc_state():
gc_status = gc.isenabled()
try:
for state in (True, False):
gc.enable()
set_gc_state(state)
assert_equal(gc.isenabled(), state)
gc.disable()
set_gc_state(state)
assert_equal(gc.isenabled(), state)
finally:
if gc_status:
gc.enable()
def test_gc_state():
# Test gc_state context manager
gc_status = gc.isenabled()
try:
for pre_state in (True, False):
set_gc_state(pre_state)
for with_state in (True, False):
# Check the gc state is with_state in with block
with gc_state(with_state):
assert_equal(gc.isenabled(), with_state)
# And returns to previous state outside block
assert_equal(gc.isenabled(), pre_state)
# Even if the gc state is set explicitly within the block
with gc_state(with_state):
assert_equal(gc.isenabled(), with_state)
set_gc_state(not with_state)
assert_equal(gc.isenabled(), pre_state)
finally:
if gc_status:
gc.enable()
@pytest.mark.skipif(IS_PYPY, reason="Test not meaningful on PyPy")
def test_assert_deallocated():
# Ordinary use
class C:
def __init__(self, arg0, arg1, name='myname'):
self.name = name
for gc_current in (True, False):
with gc_state(gc_current):
# We are deleting from with-block context, so that's OK
with assert_deallocated(C, 0, 2, 'another name') as c:
assert_equal(c.name, 'another name')
del c
# Or not using the thing in with-block context, also OK
with assert_deallocated(C, 0, 2, name='third name'):
pass
assert_equal(gc.isenabled(), gc_current)
@pytest.mark.skipif(IS_PYPY, reason="Test not meaningful on PyPy")
def test_assert_deallocated_nodel():
class C:
pass
with pytest.raises(ReferenceError):
# Need to delete after using if in with-block context
# Note: assert_deallocated(C) needs to be assigned for the test
# to function correctly. It is assigned to _, but _ itself is
# not referenced in the body of the with, it is only there for
# the refcount.
with assert_deallocated(C) as _:
pass
@pytest.mark.skipif(IS_PYPY, reason="Test not meaningful on PyPy")
def test_assert_deallocated_circular():
class C:
def __init__(self):
self._circular = self
with pytest.raises(ReferenceError):
# Circular reference, no automatic garbage collection
with assert_deallocated(C) as c:
del c
@pytest.mark.skipif(IS_PYPY, reason="Test not meaningful on PyPy")
def test_assert_deallocated_circular2():
class C:
def __init__(self):
self._circular = self
with pytest.raises(ReferenceError):
# Still circular reference, no automatic garbage collection
with assert_deallocated(C):
pass
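
Taken together, these tests pin down a small contract; a minimal usage sketch (CPython-only, as the PyPy skips above indicate, and assuming only the behavior asserted here):

    import gc
    from scipy._lib._gcutils import gc_state, assert_deallocated

    class Leafy:
        pass

    with gc_state(False):              # collector disabled inside the block
        assert not gc.isenabled()
        with assert_deallocated(Leafy) as obj:
            del obj                    # drop the last reference inside the block
    # the previous collector state is restored on exit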

View File

@@ -0,0 +1,67 @@
from pytest import raises as assert_raises
from scipy._lib._pep440 import Version, parse
def test_main_versions():
assert Version('1.8.0') == Version('1.8.0')
for ver in ['1.9.0', '2.0.0', '1.8.1']:
assert Version('1.8.0') < Version(ver)
for ver in ['1.7.0', '1.7.1', '0.9.9']:
assert Version('1.8.0') > Version(ver)
def test_version_1_point_10():
# regression test for gh-2998.
assert Version('1.9.0') < Version('1.10.0')
assert Version('1.11.0') < Version('1.11.1')
assert Version('1.11.0') == Version('1.11.0')
assert Version('1.99.11') < Version('1.99.12')
def test_alpha_beta_rc():
assert Version('1.8.0rc1') == Version('1.8.0rc1')
for ver in ['1.8.0', '1.8.0rc2']:
assert Version('1.8.0rc1') < Version(ver)
for ver in ['1.8.0a2', '1.8.0b3', '1.7.2rc4']:
assert Version('1.8.0rc1') > Version(ver)
assert Version('1.8.0b1') > Version('1.8.0a2')
def test_dev_version():
assert Version('1.9.0.dev+Unknown') < Version('1.9.0')
for ver in ['1.9.0', '1.9.0a1', '1.9.0b2', '1.9.0b2.dev+ffffffff', '1.9.0.dev1']:
assert Version('1.9.0.dev+f16acvda') < Version(ver)
assert Version('1.9.0.dev+f16acvda') == Version('1.9.0.dev+f16acvda')
def test_dev_a_b_rc_mixed():
assert Version('1.9.0a2.dev+f16acvda') == Version('1.9.0a2.dev+f16acvda')
assert Version('1.9.0a2.dev+6acvda54') < Version('1.9.0a2')
def test_dev0_version():
assert Version('1.9.0.dev0+Unknown') < Version('1.9.0')
for ver in ['1.9.0', '1.9.0a1', '1.9.0b2', '1.9.0b2.dev0+ffffffff']:
assert Version('1.9.0.dev0+f16acvda') < Version(ver)
assert Version('1.9.0.dev0+f16acvda') == Version('1.9.0.dev0+f16acvda')
def test_dev0_a_b_rc_mixed():
assert Version('1.9.0a2.dev0+f16acvda') == Version('1.9.0a2.dev0+f16acvda')
assert Version('1.9.0a2.dev0+6acvda54') < Version('1.9.0a2')
def test_raises():
for ver in ['1,9.0', '1.7.x']:
assert_raises(ValueError, Version, ver)
def test_legacy_version():
# Non-PEP-440 version identifiers always compare less. For NumPy this only
# occurs on dev builds prior to 1.10.0 which are unsupported anyway.
assert parse('invalid') < Version('0.0.0')
assert parse('1.9.0-f16acvda') < Version('1.0.0')
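
For reference, the ordering these assertions rely on can be summarized in one chain; a small sketch using only comparisons the tests above establish:

    from scipy._lib._pep440 import Version, parse

    # pre-releases sort before the final release: a < b < rc < final
    assert (Version('1.8.0a2') < Version('1.8.0b3')
            < Version('1.8.0rc1') < Version('1.8.0'))
    # a dev build sorts before the release it leads up to
    assert Version('1.9.0.dev0+f16acvda') < Version('1.9.0')
    # non-PEP-440 strings become "legacy" versions, less than anything valid
    assert parse('not-a-version') < Version('0.0.0')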

View File

@@ -0,0 +1,32 @@
import sys
from scipy._lib._testutils import _parse_size, _get_mem_available
import pytest
def test__parse_size():
expected = {
'12': 12e6,
'12 b': 12,
'12k': 12e3,
' 12 M ': 12e6,
' 12 G ': 12e9,
' 12Tb ': 12e12,
'12 Mib ': 12 * 1024.0**2,
'12Tib': 12 * 1024.0**4,
}
for inp, outp in sorted(expected.items()):
if outp is None:
with pytest.raises(ValueError):
_parse_size(inp)
else:
assert _parse_size(inp) == outp
def test__mem_available():
# May return None on non-Linux platforms
available = _get_mem_available()
if sys.platform.startswith('linux'):
assert available >= 0
else:
assert available is None or available >= 0
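
The expected-values table implies a parsing rule: an integer, optional whitespace, then an optional unit, where a bare number means megabytes, decimal suffixes scale by powers of 1000, and `ib` suffixes by powers of 1024. A hypothetical re-implementation of that rule, for illustration only (the real `_parse_size` may be written differently):

    import re

    # Hypothetical sketch of the rule encoded by the `expected` table above.
    _MULT = {'': 1e6, 'b': 1.0,
             'k': 1e3, 'm': 1e6, 'g': 1e9, 't': 1e12,
             'kib': 1024.0, 'mib': 1024.0**2, 'gib': 1024.0**3, 'tib': 1024.0**4}

    def parse_size_sketch(size_str):
        m = re.fullmatch(r'\s*(\d+)\s*([a-zA-Z]*)\s*', size_str)
        if m is None:
            raise ValueError(f'could not parse {size_str!r}')
        value, suffix = m.groups()
        suffix = suffix.lower()
        if suffix.endswith('b') and suffix not in _MULT:   # '12 Tb' -> 't'
            suffix = suffix[:-1]
        if suffix not in _MULT:
            raise ValueError(f'unknown suffix in {size_str!r}')
        return float(value) * _MULT[suffix]

    assert parse_size_sketch(' 12Tib ') == 12 * 1024.0**4
    assert parse_size_sketch('12') == 12e6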

View File

@@ -0,0 +1,51 @@
import threading
import time
import traceback
from numpy.testing import assert_
from pytest import raises as assert_raises
from scipy._lib._threadsafety import ReentrancyLock, non_reentrant, ReentrancyError
def test_parallel_threads():
# Check that ReentrancyLock serializes work in parallel threads.
#
# The test is not fully deterministic, and may succeed falsely if
# the timings go wrong.
lock = ReentrancyLock("failure")
failflag = [False]
exceptions_raised = []
def worker(k):
try:
with lock:
assert_(not failflag[0])
failflag[0] = True
time.sleep(0.1 * k)
assert_(failflag[0])
failflag[0] = False
except Exception:
exceptions_raised.append(traceback.format_exc(2))
threads = [threading.Thread(target=lambda k=k: worker(k))
for k in range(3)]
for t in threads:
t.start()
for t in threads:
t.join()
exceptions_raised = "\n".join(exceptions_raised)
assert_(not exceptions_raised, exceptions_raised)
def test_reentering():
# Check that ReentrancyLock prevents re-entering from the same thread.
@non_reentrant()
def func(x):
return func(x)
assert_raises(ReentrancyError, func, 0)
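
A minimal sketch of the decorator in use, assuming only what `test_reentering` asserts (the wrapped function refuses re-entry from the same thread):

    from scipy._lib._threadsafety import non_reentrant, ReentrancyError

    @non_reentrant()
    def fragile():
        fragile()          # second entry on the same thread

    try:
        fragile()
    except ReentrancyError:
        print('re-entry from the same thread was refused')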

View File

@@ -0,0 +1,447 @@
from multiprocessing import Pool
from multiprocessing.pool import Pool as PWL
import re
import math
from fractions import Fraction
import numpy as np
from numpy.testing import assert_equal, assert_
import pytest
from pytest import raises as assert_raises
import hypothesis.extra.numpy as npst
from hypothesis import given, strategies, reproduce_failure # noqa: F401
from scipy.conftest import array_api_compatible, skip_xp_invalid_arg
from scipy._lib._array_api import (xp_assert_equal, xp_assert_close, is_numpy,
copy as xp_copy)
from scipy._lib._util import (_aligned_zeros, check_random_state, MapWrapper,
getfullargspec_no_self, FullArgSpec,
rng_integers, _validate_int, _rename_parameter,
_contains_nan, _rng_html_rewrite, _lazywhere)
skip_xp_backends = pytest.mark.skip_xp_backends
@pytest.mark.slow
def test__aligned_zeros():
niter = 10
def check(shape, dtype, order, align):
err_msg = repr((shape, dtype, order, align))
x = _aligned_zeros(shape, dtype, order, align=align)
if align is None:
align = np.dtype(dtype).alignment
assert_equal(x.__array_interface__['data'][0] % align, 0)
if hasattr(shape, '__len__'):
assert_equal(x.shape, shape, err_msg)
else:
assert_equal(x.shape, (shape,), err_msg)
assert_equal(x.dtype, dtype)
if order == "C":
assert_(x.flags.c_contiguous, err_msg)
elif order == "F":
if x.size > 0:
# Size-0 arrays get invalid flags on NumPy 1.5
assert_(x.flags.f_contiguous, err_msg)
elif order is None:
assert_(x.flags.c_contiguous, err_msg)
else:
raise ValueError()
# try various alignments
for align in [1, 2, 3, 4, 8, 16, 32, 64, None]:
for n in [0, 1, 3, 11]:
for order in ["C", "F", None]:
for dtype in [np.uint8, np.float64]:
for shape in [n, (1, 2, 3, n)]:
for j in range(niter):
check(shape, dtype, order, align)
def test_check_random_state():
# If seed is None, return the RandomState singleton used by np.random.
# If seed is an int, return a new RandomState instance seeded with seed.
    # If seed is already a RandomState or Generator instance, return it.
# Otherwise raise ValueError.
rsi = check_random_state(1)
assert_equal(type(rsi), np.random.RandomState)
rsi = check_random_state(rsi)
assert_equal(type(rsi), np.random.RandomState)
rsi = check_random_state(None)
assert_equal(type(rsi), np.random.RandomState)
assert_raises(ValueError, check_random_state, 'a')
rg = np.random.Generator(np.random.PCG64())
rsi = check_random_state(rg)
assert_equal(type(rsi), np.random.Generator)
def test_getfullargspec_no_self():
p = MapWrapper(1)
argspec = getfullargspec_no_self(p.__init__)
assert_equal(argspec, FullArgSpec(['pool'], None, None, (1,), [],
None, {}))
argspec = getfullargspec_no_self(p.__call__)
assert_equal(argspec, FullArgSpec(['func', 'iterable'], None, None, None,
[], None, {}))
class _rv_generic:
def _rvs(self, a, b=2, c=3, *args, size=None, **kwargs):
return None
rv_obj = _rv_generic()
argspec = getfullargspec_no_self(rv_obj._rvs)
assert_equal(argspec, FullArgSpec(['a', 'b', 'c'], 'args', 'kwargs',
(2, 3), ['size'], {'size': None}, {}))
def test_mapwrapper_serial():
in_arg = np.arange(10.)
out_arg = np.sin(in_arg)
p = MapWrapper(1)
assert_(p._mapfunc is map)
assert_(p.pool is None)
assert_(p._own_pool is False)
out = list(p(np.sin, in_arg))
assert_equal(out, out_arg)
with assert_raises(RuntimeError):
p = MapWrapper(0)
def test_pool():
with Pool(2) as p:
p.map(math.sin, [1, 2, 3, 4])
def test_mapwrapper_parallel():
in_arg = np.arange(10.)
out_arg = np.sin(in_arg)
with MapWrapper(2) as p:
out = p(np.sin, in_arg)
assert_equal(list(out), out_arg)
assert_(p._own_pool is True)
assert_(isinstance(p.pool, PWL))
assert_(p._mapfunc is not None)
# the context manager should've closed the internal pool
# check that it has by asking it to calculate again.
with assert_raises(Exception) as excinfo:
p(np.sin, in_arg)
assert_(excinfo.type is ValueError)
# can also set a PoolWrapper up with a map-like callable instance
with Pool(2) as p:
q = MapWrapper(p.map)
assert_(q._own_pool is False)
q.close()
# closing the PoolWrapper shouldn't close the internal pool
# because it didn't create it
out = p.map(np.sin, in_arg)
assert_equal(list(out), out_arg)
def test_rng_integers():
rng = np.random.RandomState()
# test that numbers are inclusive of high point
arr = rng_integers(rng, low=2, high=5, size=100, endpoint=True)
assert np.max(arr) == 5
assert np.min(arr) == 2
assert arr.shape == (100, )
# test that numbers are inclusive of high point
arr = rng_integers(rng, low=5, size=100, endpoint=True)
assert np.max(arr) == 5
assert np.min(arr) == 0
assert arr.shape == (100, )
# test that numbers are exclusive of high point
arr = rng_integers(rng, low=2, high=5, size=100, endpoint=False)
assert np.max(arr) == 4
assert np.min(arr) == 2
assert arr.shape == (100, )
# test that numbers are exclusive of high point
arr = rng_integers(rng, low=5, size=100, endpoint=False)
assert np.max(arr) == 4
assert np.min(arr) == 0
assert arr.shape == (100, )
# now try with np.random.Generator
try:
rng = np.random.default_rng()
except AttributeError:
return
# test that numbers are inclusive of high point
arr = rng_integers(rng, low=2, high=5, size=100, endpoint=True)
assert np.max(arr) == 5
assert np.min(arr) == 2
assert arr.shape == (100, )
# test that numbers are inclusive of high point
arr = rng_integers(rng, low=5, size=100, endpoint=True)
assert np.max(arr) == 5
assert np.min(arr) == 0
assert arr.shape == (100, )
# test that numbers are exclusive of high point
arr = rng_integers(rng, low=2, high=5, size=100, endpoint=False)
assert np.max(arr) == 4
assert np.min(arr) == 2
assert arr.shape == (100, )
# test that numbers are exclusive of high point
arr = rng_integers(rng, low=5, size=100, endpoint=False)
assert np.max(arr) == 4
assert np.min(arr) == 0
assert arr.shape == (100, )
class TestValidateInt:
@pytest.mark.parametrize('n', [4, np.uint8(4), np.int16(4), np.array(4)])
def test_validate_int(self, n):
n = _validate_int(n, 'n')
assert n == 4
@pytest.mark.parametrize('n', [4.0, np.array([4]), Fraction(4, 1)])
def test_validate_int_bad(self, n):
with pytest.raises(TypeError, match='n must be an integer'):
_validate_int(n, 'n')
def test_validate_int_below_min(self):
with pytest.raises(ValueError, match='n must be an integer not '
'less than 0'):
_validate_int(-1, 'n', 0)
class TestRenameParameter:
# check that wrapper `_rename_parameter` for backward-compatible
# keyword renaming works correctly
# Example method/function that still accepts keyword `old`
@_rename_parameter("old", "new")
def old_keyword_still_accepted(self, new):
return new
# Example method/function for which keyword `old` is deprecated
@_rename_parameter("old", "new", dep_version="1.9.0")
def old_keyword_deprecated(self, new):
return new
def test_old_keyword_still_accepted(self):
# positional argument and both keyword work identically
res1 = self.old_keyword_still_accepted(10)
res2 = self.old_keyword_still_accepted(new=10)
res3 = self.old_keyword_still_accepted(old=10)
assert res1 == res2 == res3 == 10
# unexpected keyword raises an error
message = re.escape("old_keyword_still_accepted() got an unexpected")
with pytest.raises(TypeError, match=message):
self.old_keyword_still_accepted(unexpected=10)
# multiple values for the same parameter raises an error
message = re.escape("old_keyword_still_accepted() got multiple")
with pytest.raises(TypeError, match=message):
self.old_keyword_still_accepted(10, new=10)
with pytest.raises(TypeError, match=message):
self.old_keyword_still_accepted(10, old=10)
with pytest.raises(TypeError, match=message):
self.old_keyword_still_accepted(new=10, old=10)
def test_old_keyword_deprecated(self):
# positional argument and both keyword work identically,
# but use of old keyword results in DeprecationWarning
dep_msg = "Use of keyword argument `old` is deprecated"
res1 = self.old_keyword_deprecated(10)
res2 = self.old_keyword_deprecated(new=10)
with pytest.warns(DeprecationWarning, match=dep_msg):
res3 = self.old_keyword_deprecated(old=10)
assert res1 == res2 == res3 == 10
# unexpected keyword raises an error
message = re.escape("old_keyword_deprecated() got an unexpected")
with pytest.raises(TypeError, match=message):
self.old_keyword_deprecated(unexpected=10)
# multiple values for the same parameter raises an error and,
# if old keyword is used, results in DeprecationWarning
message = re.escape("old_keyword_deprecated() got multiple")
with pytest.raises(TypeError, match=message):
self.old_keyword_deprecated(10, new=10)
with pytest.raises(TypeError, match=message), \
pytest.warns(DeprecationWarning, match=dep_msg):
self.old_keyword_deprecated(10, old=10)
with pytest.raises(TypeError, match=message), \
pytest.warns(DeprecationWarning, match=dep_msg):
self.old_keyword_deprecated(new=10, old=10)
class TestContainsNaNTest:
def test_policy(self):
data = np.array([1, 2, 3, np.nan])
contains_nan, nan_policy = _contains_nan(data, nan_policy="propagate")
assert contains_nan
assert nan_policy == "propagate"
contains_nan, nan_policy = _contains_nan(data, nan_policy="omit")
assert contains_nan
assert nan_policy == "omit"
msg = "The input contains nan values"
with pytest.raises(ValueError, match=msg):
_contains_nan(data, nan_policy="raise")
msg = "nan_policy must be one of"
with pytest.raises(ValueError, match=msg):
_contains_nan(data, nan_policy="nan")
def test_contains_nan(self):
data1 = np.array([1, 2, 3])
assert not _contains_nan(data1)[0]
data2 = np.array([1, 2, 3, np.nan])
assert _contains_nan(data2)[0]
data3 = np.array([np.nan, 2, 3, np.nan])
assert _contains_nan(data3)[0]
data4 = np.array([[1, 2], [3, 4]])
assert not _contains_nan(data4)[0]
data5 = np.array([[1, 2], [3, np.nan]])
assert _contains_nan(data5)[0]
@skip_xp_invalid_arg
def test_contains_nan_with_strings(self):
data1 = np.array([1, 2, "3", np.nan]) # converted to string "nan"
assert not _contains_nan(data1)[0]
data2 = np.array([1, 2, "3", np.nan], dtype='object')
assert _contains_nan(data2)[0]
data3 = np.array([["1", 2], [3, np.nan]]) # converted to string "nan"
assert not _contains_nan(data3)[0]
data4 = np.array([["1", 2], [3, np.nan]], dtype='object')
assert _contains_nan(data4)[0]
@skip_xp_backends('jax.numpy',
reasons=["JAX arrays do not support item assignment"])
@pytest.mark.usefixtures("skip_xp_backends")
@array_api_compatible
@pytest.mark.parametrize("nan_policy", ['propagate', 'omit', 'raise'])
def test_array_api(self, xp, nan_policy):
rng = np.random.default_rng(932347235892482)
x0 = rng.random(size=(2, 3, 4))
x = xp.asarray(x0)
x_nan = xp_copy(x, xp=xp)
x_nan[1, 2, 1] = np.nan
contains_nan, nan_policy_out = _contains_nan(x, nan_policy=nan_policy)
assert not contains_nan
assert nan_policy_out == nan_policy
if nan_policy == 'raise':
message = 'The input contains...'
with pytest.raises(ValueError, match=message):
_contains_nan(x_nan, nan_policy=nan_policy)
elif nan_policy == 'omit' and not is_numpy(xp):
message = "`nan_policy='omit' is incompatible..."
with pytest.raises(ValueError, match=message):
_contains_nan(x_nan, nan_policy=nan_policy)
elif nan_policy == 'propagate':
contains_nan, nan_policy_out = _contains_nan(
x_nan, nan_policy=nan_policy)
assert contains_nan
assert nan_policy_out == nan_policy
def test__rng_html_rewrite():
def mock_str():
lines = [
'np.random.default_rng(8989843)',
'np.random.default_rng(seed)',
'np.random.default_rng(0x9a71b21474694f919882289dc1559ca)',
' bob ',
]
return lines
res = _rng_html_rewrite(mock_str)()
ref = [
'np.random.default_rng()',
'np.random.default_rng(seed)',
'np.random.default_rng()',
' bob ',
]
assert res == ref
class TestLazywhere:
n_arrays = strategies.integers(min_value=1, max_value=3)
rng_seed = strategies.integers(min_value=1000000000, max_value=9999999999)
dtype = strategies.sampled_from((np.float32, np.float64))
p = strategies.floats(min_value=0, max_value=1)
data = strategies.data()
@pytest.mark.fail_slow(5)
@pytest.mark.filterwarnings('ignore::RuntimeWarning') # overflows, etc.
@skip_xp_backends('jax.numpy',
reasons=["JAX arrays do not support item assignment"])
@pytest.mark.usefixtures("skip_xp_backends")
@array_api_compatible
@given(n_arrays=n_arrays, rng_seed=rng_seed, dtype=dtype, p=p, data=data)
def test_basic(self, n_arrays, rng_seed, dtype, p, data, xp):
mbs = npst.mutually_broadcastable_shapes(num_shapes=n_arrays+1,
min_side=0)
input_shapes, result_shape = data.draw(mbs)
cond_shape, *shapes = input_shapes
fillvalue = xp.asarray(data.draw(npst.arrays(dtype=dtype, shape=tuple())))
arrays = [xp.asarray(data.draw(npst.arrays(dtype=dtype, shape=shape)))
for shape in shapes]
def f(*args):
return sum(arg for arg in args)
def f2(*args):
return sum(arg for arg in args) / 2
rng = np.random.default_rng(rng_seed)
cond = xp.asarray(rng.random(size=cond_shape) > p)
res1 = _lazywhere(cond, arrays, f, fillvalue)
res2 = _lazywhere(cond, arrays, f, f2=f2)
# Ensure arrays are at least 1d to follow sane type promotion rules.
if xp == np:
cond, fillvalue, *arrays = np.atleast_1d(cond, fillvalue, *arrays)
ref1 = xp.where(cond, f(*arrays), fillvalue)
ref2 = xp.where(cond, f(*arrays), f2(*arrays))
if xp == np:
ref1 = ref1.reshape(result_shape)
ref2 = ref2.reshape(result_shape)
res1 = xp.asarray(res1)[()]
res2 = xp.asarray(res2)[()]
        assert isinstance(res1, type(xp.asarray([])))
xp_assert_close(res1, ref1, rtol=2e-16)
assert_equal(res1.shape, ref1.shape)
assert_equal(res1.dtype, ref1.dtype)
        assert isinstance(res2, type(xp.asarray([])))
xp_assert_equal(res2, ref2)
assert_equal(res2.shape, ref2.shape)
assert_equal(res2.dtype, ref2.dtype)
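
Stripped of the hypothesis machinery, `test_basic` checks an equivalence that is easy to state; a NumPy-only sketch of it, using the same private `_lazywhere` helper:

    import numpy as np
    from scipy._lib._util import _lazywhere

    cond = np.array([True, False, True])
    a = np.array([1.0, 2.0, 3.0])

    # where cond holds, apply f; elsewhere use the fill value
    res = _lazywhere(cond, (a,), lambda x: 2 * x, fillvalue=-1.0)
    ref = np.where(cond, 2 * a, -1.0)
    np.testing.assert_allclose(res, ref)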

View File

@@ -0,0 +1,114 @@
import numpy as np
import pytest
from scipy.conftest import array_api_compatible
from scipy._lib._array_api import (
_GLOBAL_CONFIG, array_namespace, _asarray, copy, xp_assert_equal, is_numpy
)
import scipy._lib.array_api_compat.numpy as np_compat
skip_xp_backends = pytest.mark.skip_xp_backends
@pytest.mark.skipif(not _GLOBAL_CONFIG["SCIPY_ARRAY_API"],
reason="Array API test; set environment variable SCIPY_ARRAY_API=1 to run it")
class TestArrayAPI:
def test_array_namespace(self):
x, y = np.array([0, 1, 2]), np.array([0, 1, 2])
xp = array_namespace(x, y)
assert 'array_api_compat.numpy' in xp.__name__
_GLOBAL_CONFIG["SCIPY_ARRAY_API"] = False
xp = array_namespace(x, y)
assert 'array_api_compat.numpy' in xp.__name__
_GLOBAL_CONFIG["SCIPY_ARRAY_API"] = True
@array_api_compatible
def test_asarray(self, xp):
x, y = _asarray([0, 1, 2], xp=xp), _asarray(np.arange(3), xp=xp)
ref = xp.asarray([0, 1, 2])
xp_assert_equal(x, ref)
xp_assert_equal(y, ref)
@pytest.mark.filterwarnings("ignore: the matrix subclass")
def test_raises(self):
msg = "of type `numpy.ma.MaskedArray` are not supported"
with pytest.raises(TypeError, match=msg):
array_namespace(np.ma.array(1), np.array(1))
msg = "of type `numpy.matrix` are not supported"
with pytest.raises(TypeError, match=msg):
array_namespace(np.array(1), np.matrix(1))
msg = "only boolean and numerical dtypes are supported"
with pytest.raises(TypeError, match=msg):
array_namespace([object()])
with pytest.raises(TypeError, match=msg):
array_namespace('abc')
def test_array_likes(self):
# should be no exceptions
array_namespace([0, 1, 2])
array_namespace(1, 2, 3)
array_namespace(1)
@skip_xp_backends('jax.numpy',
reasons=["JAX arrays do not support item assignment"])
@pytest.mark.usefixtures("skip_xp_backends")
@array_api_compatible
def test_copy(self, xp):
for _xp in [xp, None]:
x = xp.asarray([1, 2, 3])
y = copy(x, xp=_xp)
            # with numpy we'd want to use np.shares_memory, but that's not
            # specified in the array-api
x[0] = 10
x[1] = 11
x[2] = 12
assert x[0] != y[0]
assert x[1] != y[1]
assert x[2] != y[2]
assert id(x) != id(y)
@array_api_compatible
@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float32', 'float64'])
@pytest.mark.parametrize('shape', [(), (3,)])
def test_strict_checks(self, xp, dtype, shape):
# Check that `_strict_check` behaves as expected
dtype = getattr(xp, dtype)
x = xp.broadcast_to(xp.asarray(1, dtype=dtype), shape)
x = x if shape else x[()]
y = np_compat.asarray(1)[()]
options = dict(check_namespace=True, check_dtype=False, check_shape=False)
if xp == np:
xp_assert_equal(x, y, **options)
else:
with pytest.raises(AssertionError, match="Namespaces do not match."):
xp_assert_equal(x, y, **options)
options = dict(check_namespace=False, check_dtype=True, check_shape=False)
if y.dtype.name in str(x.dtype):
xp_assert_equal(x, y, **options)
else:
with pytest.raises(AssertionError, match="dtypes do not match."):
xp_assert_equal(x, y, **options)
options = dict(check_namespace=False, check_dtype=False, check_shape=True)
if x.shape == y.shape:
xp_assert_equal(x, y, **options)
else:
with pytest.raises(AssertionError, match="Shapes do not match."):
xp_assert_equal(x, y, **options)
@array_api_compatible
def test_check_scalar(self, xp):
        if not is_numpy(xp):
            pytest.skip("Scalars only exist in NumPy")
        # after the skip, xp is NumPy, so no further guard is needed
        with pytest.raises(AssertionError, match="Types do not match."):
            xp_assert_equal(xp.asarray(0.), xp.float64(0))
        xp_assert_equal(xp.float64(0), xp.asarray(0.))
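
The pattern under test is dispatch through a namespace object; a short sketch of the NumPy case, using only behavior the tests above assert:

    import numpy as np
    from scipy._lib._array_api import array_namespace

    x = np.asarray([0.0, 1.0, 2.0])
    xp = array_namespace(x)           # compat-wrapped namespace of the inputs
    y = xp.exp(x)                     # same call works on any supported backend
    assert 'numpy' in xp.__name__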

View File

@@ -0,0 +1,162 @@
import pytest
import pickle
from numpy.testing import assert_equal
from scipy._lib._bunch import _make_tuple_bunch
# `Result` is defined at the top level of the module so it can be
# used to test pickling.
Result = _make_tuple_bunch('Result', ['x', 'y', 'z'], ['w', 'beta'])
class TestMakeTupleBunch:
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# Tests with Result
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
def setup_method(self):
# Set up an instance of Result.
self.result = Result(x=1, y=2, z=3, w=99, beta=0.5)
def test_attribute_access(self):
assert_equal(self.result.x, 1)
assert_equal(self.result.y, 2)
assert_equal(self.result.z, 3)
assert_equal(self.result.w, 99)
assert_equal(self.result.beta, 0.5)
def test_indexing(self):
assert_equal(self.result[0], 1)
assert_equal(self.result[1], 2)
assert_equal(self.result[2], 3)
assert_equal(self.result[-1], 3)
with pytest.raises(IndexError, match='index out of range'):
self.result[3]
def test_unpacking(self):
x0, y0, z0 = self.result
assert_equal((x0, y0, z0), (1, 2, 3))
assert_equal(self.result, (1, 2, 3))
def test_slice(self):
assert_equal(self.result[1:], (2, 3))
assert_equal(self.result[::2], (1, 3))
assert_equal(self.result[::-1], (3, 2, 1))
def test_len(self):
assert_equal(len(self.result), 3)
def test_repr(self):
s = repr(self.result)
assert_equal(s, 'Result(x=1, y=2, z=3, w=99, beta=0.5)')
def test_hash(self):
assert_equal(hash(self.result), hash((1, 2, 3)))
def test_pickle(self):
s = pickle.dumps(self.result)
obj = pickle.loads(s)
assert isinstance(obj, Result)
assert_equal(obj.x, self.result.x)
assert_equal(obj.y, self.result.y)
assert_equal(obj.z, self.result.z)
assert_equal(obj.w, self.result.w)
assert_equal(obj.beta, self.result.beta)
def test_read_only_existing(self):
with pytest.raises(AttributeError, match="can't set attribute"):
self.result.x = -1
def test_read_only_new(self):
self.result.plate_of_shrimp = "lattice of coincidence"
assert self.result.plate_of_shrimp == "lattice of coincidence"
def test_constructor_missing_parameter(self):
with pytest.raises(TypeError, match='missing'):
# `w` is missing.
Result(x=1, y=2, z=3, beta=0.75)
def test_constructor_incorrect_parameter(self):
with pytest.raises(TypeError, match='unexpected'):
# `foo` is not an existing field.
Result(x=1, y=2, z=3, w=123, beta=0.75, foo=999)
def test_module(self):
m = 'scipy._lib.tests.test_bunch'
assert_equal(Result.__module__, m)
assert_equal(self.result.__module__, m)
def test_extra_fields_per_instance(self):
# This test exists to ensure that instances of the same class
# store their own values for the extra fields. That is, the values
# are stored per instance and not in the class.
result1 = Result(x=1, y=2, z=3, w=-1, beta=0.0)
result2 = Result(x=4, y=5, z=6, w=99, beta=1.0)
assert_equal(result1.w, -1)
assert_equal(result1.beta, 0.0)
# The rest of these checks aren't essential, but let's check
# them anyway.
assert_equal(result1[:], (1, 2, 3))
assert_equal(result2.w, 99)
assert_equal(result2.beta, 1.0)
assert_equal(result2[:], (4, 5, 6))
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# Other tests
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
def test_extra_field_names_is_optional(self):
Square = _make_tuple_bunch('Square', ['width', 'height'])
sq = Square(width=1, height=2)
assert_equal(sq.width, 1)
assert_equal(sq.height, 2)
s = repr(sq)
assert_equal(s, 'Square(width=1, height=2)')
def test_tuple_like(self):
Tup = _make_tuple_bunch('Tup', ['a', 'b'])
tu = Tup(a=1, b=2)
assert isinstance(tu, tuple)
assert isinstance(tu + (1,), tuple)
def test_explicit_module(self):
m = 'some.module.name'
Foo = _make_tuple_bunch('Foo', ['x'], ['a', 'b'], module=m)
foo = Foo(x=1, a=355, b=113)
assert_equal(Foo.__module__, m)
assert_equal(foo.__module__, m)
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# Argument validation
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
@pytest.mark.parametrize('args', [('123', ['a'], ['b']),
('Foo', ['-3'], ['x']),
('Foo', ['a'], ['+-*/'])])
def test_identifiers_not_allowed(self, args):
with pytest.raises(ValueError, match='identifiers'):
_make_tuple_bunch(*args)
@pytest.mark.parametrize('args', [('Foo', ['a', 'b', 'a'], ['x']),
('Foo', ['a', 'b'], ['b', 'x'])])
def test_repeated_field_names(self, args):
with pytest.raises(ValueError, match='Duplicate'):
_make_tuple_bunch(*args)
@pytest.mark.parametrize('args', [('Foo', ['_a'], ['x']),
('Foo', ['a'], ['_x'])])
def test_leading_underscore_not_allowed(self, args):
with pytest.raises(ValueError, match='underscore'):
_make_tuple_bunch(*args)
@pytest.mark.parametrize('args', [('Foo', ['def'], ['x']),
('Foo', ['a'], ['or']),
('and', ['a'], ['x'])])
def test_keyword_not_allowed_in_fields(self, args):
with pytest.raises(ValueError, match='keyword'):
_make_tuple_bunch(*args)
def test_at_least_one_field_name_required(self):
with pytest.raises(ValueError, match='at least one name'):
_make_tuple_bunch('Qwerty', [], ['a', 'b'])
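
For orientation, a usage sketch of the factory these tests cover (`Fit` and its fields are hypothetical names): the first name list forms the tuple part, the second becomes extra per-instance attributes that do not affect indexing, unpacking, or len().

    from scipy._lib._bunch import _make_tuple_bunch

    Fit = _make_tuple_bunch('Fit', ['slope', 'intercept'], ['rvalue'])
    res = Fit(slope=2.0, intercept=0.5, rvalue=0.99)
    slope, intercept = res            # unpacks like a plain 2-tuple
    assert len(res) == 2 and res.rvalue == 0.99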

View File

@@ -0,0 +1,204 @@
from numpy.testing import assert_equal, assert_
from pytest import raises as assert_raises
import time
import pytest
import ctypes
import threading
from scipy._lib import _ccallback_c as _test_ccallback_cython
from scipy._lib import _test_ccallback
from scipy._lib._ccallback import LowLevelCallable
try:
import cffi
HAVE_CFFI = True
except ImportError:
HAVE_CFFI = False
ERROR_VALUE = 2.0
def callback_python(a, user_data=None):
if a == ERROR_VALUE:
raise ValueError("bad value")
if user_data is None:
return a + 1
else:
return a + user_data
def _get_cffi_func(base, signature):
if not HAVE_CFFI:
pytest.skip("cffi not installed")
# Get function address
voidp = ctypes.cast(base, ctypes.c_void_p)
address = voidp.value
# Create corresponding cffi handle
ffi = cffi.FFI()
func = ffi.cast(signature, address)
return func
def _get_ctypes_data():
value = ctypes.c_double(2.0)
return ctypes.cast(ctypes.pointer(value), ctypes.c_voidp)
def _get_cffi_data():
if not HAVE_CFFI:
pytest.skip("cffi not installed")
ffi = cffi.FFI()
return ffi.new('double *', 2.0)
CALLERS = {
'simple': _test_ccallback.test_call_simple,
'nodata': _test_ccallback.test_call_nodata,
'nonlocal': _test_ccallback.test_call_nonlocal,
'cython': _test_ccallback_cython.test_call_cython,
}
# These functions have signatures known to the callers
FUNCS = {
'python': lambda: callback_python,
'capsule': lambda: _test_ccallback.test_get_plus1_capsule(),
'cython': lambda: LowLevelCallable.from_cython(_test_ccallback_cython,
"plus1_cython"),
'ctypes': lambda: _test_ccallback_cython.plus1_ctypes,
'cffi': lambda: _get_cffi_func(_test_ccallback_cython.plus1_ctypes,
'double (*)(double, int *, void *)'),
'capsule_b': lambda: _test_ccallback.test_get_plus1b_capsule(),
'cython_b': lambda: LowLevelCallable.from_cython(_test_ccallback_cython,
"plus1b_cython"),
'ctypes_b': lambda: _test_ccallback_cython.plus1b_ctypes,
'cffi_b': lambda: _get_cffi_func(_test_ccallback_cython.plus1b_ctypes,
'double (*)(double, double, int *, void *)'),
}
# These functions have signatures the callers don't know
BAD_FUNCS = {
'capsule_bc': lambda: _test_ccallback.test_get_plus1bc_capsule(),
'cython_bc': lambda: LowLevelCallable.from_cython(_test_ccallback_cython,
"plus1bc_cython"),
'ctypes_bc': lambda: _test_ccallback_cython.plus1bc_ctypes,
'cffi_bc': lambda: _get_cffi_func(
_test_ccallback_cython.plus1bc_ctypes,
'double (*)(double, double, double, int *, void *)'
),
}
USER_DATAS = {
'ctypes': _get_ctypes_data,
'cffi': _get_cffi_data,
'capsule': _test_ccallback.test_get_data_capsule,
}
def test_callbacks():
def check(caller, func, user_data):
caller = CALLERS[caller]
func = FUNCS[func]()
user_data = USER_DATAS[user_data]()
if func is callback_python:
def func2(x):
return func(x, 2.0)
else:
func2 = LowLevelCallable(func, user_data)
func = LowLevelCallable(func)
# Test basic call
assert_equal(caller(func, 1.0), 2.0)
        # Test 'bad' value resulting in an error
assert_raises(ValueError, caller, func, ERROR_VALUE)
# Test passing in user_data
assert_equal(caller(func2, 1.0), 3.0)
for caller in sorted(CALLERS.keys()):
for func in sorted(FUNCS.keys()):
for user_data in sorted(USER_DATAS.keys()):
check(caller, func, user_data)
def test_bad_callbacks():
def check(caller, func, user_data):
caller = CALLERS[caller]
user_data = USER_DATAS[user_data]()
func = BAD_FUNCS[func]()
if func is callback_python:
def func2(x):
return func(x, 2.0)
else:
func2 = LowLevelCallable(func, user_data)
func = LowLevelCallable(func)
# Test that basic call fails
assert_raises(ValueError, caller, LowLevelCallable(func), 1.0)
# Test that passing in user_data also fails
assert_raises(ValueError, caller, func2, 1.0)
# Test error message
llfunc = LowLevelCallable(func)
try:
caller(llfunc, 1.0)
except ValueError as err:
msg = str(err)
assert_(llfunc.signature in msg, msg)
assert_('double (double, double, int *, void *)' in msg, msg)
for caller in sorted(CALLERS.keys()):
for func in sorted(BAD_FUNCS.keys()):
for user_data in sorted(USER_DATAS.keys()):
check(caller, func, user_data)
def test_signature_override():
caller = _test_ccallback.test_call_simple
func = _test_ccallback.test_get_plus1_capsule()
llcallable = LowLevelCallable(func, signature="bad signature")
assert_equal(llcallable.signature, "bad signature")
assert_raises(ValueError, caller, llcallable, 3)
llcallable = LowLevelCallable(func, signature="double (double, int *, void *)")
assert_equal(llcallable.signature, "double (double, int *, void *)")
assert_equal(caller(llcallable, 3), 4)
def test_threadsafety():
def callback(a, caller):
if a <= 0:
return 1
else:
res = caller(lambda x: callback(x, caller), a - 1)
return 2*res
def check(caller):
caller = CALLERS[caller]
results = []
count = 10
def run():
time.sleep(0.01)
r = caller(lambda x: callback(x, caller), count)
results.append(r)
threads = [threading.Thread(target=run) for j in range(20)]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
assert_equal(results, [2.0**count]*len(threads))
for caller in CALLERS.keys():
check(caller)
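
A minimal sketch of the `LowLevelCallable` plumbing these tests exercise, reusing the Cython test module this file itself imports:

    from scipy._lib._ccallback import LowLevelCallable
    from scipy._lib import _ccallback_c as _test_ccallback_cython

    # Wrap a Cython-exported function, as the FUNCS table above does, and
    # inspect the signature the caller will be matched against.
    llc = LowLevelCallable.from_cython(_test_ccallback_cython, 'plus1_cython')
    print(llc.signature)   # e.g. 'double (double, int *, void *)'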

View File

@@ -0,0 +1,10 @@
import pytest
def test_cython_api_deprecation():
match = ("`scipy._lib._test_deprecation_def.foo_deprecated` "
"is deprecated, use `foo` instead!\n"
"Deprecated in Scipy 42.0.0")
with pytest.warns(DeprecationWarning, match=match):
from .. import _test_deprecation_call
assert _test_deprecation_call.call() == (1, 1)

View File

@@ -0,0 +1,17 @@
import pytest
import sys
import subprocess
from .test_public_api import PUBLIC_MODULES
# Regression tests for gh-6793.
# Check that all modules are importable in a new Python process.
# This is not necessarily true if there are import cycles present.
@pytest.mark.fail_slow(20)
@pytest.mark.slow
def test_public_modules_importable():
pids = [subprocess.Popen([sys.executable, '-c', f'import {module}'])
for module in PUBLIC_MODULES]
for i, pid in enumerate(pids):
assert pid.wait() == 0, f'Failed to import {PUBLIC_MODULES[i]}'

View File

@@ -0,0 +1,496 @@
"""
This test script is adopted from:
https://github.com/numpy/numpy/blob/main/numpy/tests/test_public_api.py
"""
import pkgutil
import types
import importlib
import warnings
from importlib import import_module
import pytest
import scipy
from scipy.conftest import xp_available_backends
def test_dir_testing():
"""Assert that output of dir has only one "testing/tester"
attribute without duplicate"""
assert len(dir(scipy)) == len(set(dir(scipy)))
# Historically SciPy has not used leading underscores for private submodules
# much. This has resulted in lots of things that look like public modules
# (i.e. things that can be imported as `import scipy.somesubmodule.somefile`),
# but were never intended to be public. The PUBLIC_MODULES list contains
# modules that are either public because they were meant to be, or because they
# contain public functions/objects that aren't present in any other namespace
# for whatever reason and therefore should be treated as public.
PUBLIC_MODULES = ["scipy." + s for s in [
"cluster",
"cluster.vq",
"cluster.hierarchy",
"constants",
"datasets",
"fft",
"fftpack",
"integrate",
"interpolate",
"io",
"io.arff",
"io.matlab",
"io.wavfile",
"linalg",
"linalg.blas",
"linalg.cython_blas",
"linalg.lapack",
"linalg.cython_lapack",
"linalg.interpolative",
"misc",
"ndimage",
"odr",
"optimize",
"signal",
"signal.windows",
"sparse",
"sparse.linalg",
"sparse.csgraph",
"spatial",
"spatial.distance",
"spatial.transform",
"special",
"stats",
"stats.contingency",
"stats.distributions",
"stats.mstats",
"stats.qmc",
"stats.sampling"
]]
# The PRIVATE_BUT_PRESENT_MODULES list contains modules that lacked underscores
# in their name and hence looked public, but weren't meant to be. All these
# namespaces were deprecated in the 1.8.0 release - see "clear split between
# public and private API" in the 1.8.0 release notes.
# Support for these private modules will be removed in SciPy v2.0.0, as the
# deprecation messages emitted by each of these modules say.
PRIVATE_BUT_PRESENT_MODULES = [
'scipy.constants.codata',
'scipy.constants.constants',
'scipy.fftpack.basic',
'scipy.fftpack.convolve',
'scipy.fftpack.helper',
'scipy.fftpack.pseudo_diffs',
'scipy.fftpack.realtransforms',
'scipy.integrate.dop',
'scipy.integrate.lsoda',
'scipy.integrate.odepack',
'scipy.integrate.quadpack',
'scipy.integrate.vode',
'scipy.interpolate.dfitpack',
'scipy.interpolate.fitpack',
'scipy.interpolate.fitpack2',
'scipy.interpolate.interpnd',
'scipy.interpolate.interpolate',
'scipy.interpolate.ndgriddata',
'scipy.interpolate.polyint',
'scipy.interpolate.rbf',
'scipy.io.arff.arffread',
'scipy.io.harwell_boeing',
'scipy.io.idl',
'scipy.io.matlab.byteordercodes',
'scipy.io.matlab.mio',
'scipy.io.matlab.mio4',
'scipy.io.matlab.mio5',
'scipy.io.matlab.mio5_params',
'scipy.io.matlab.mio5_utils',
'scipy.io.matlab.mio_utils',
'scipy.io.matlab.miobase',
'scipy.io.matlab.streams',
'scipy.io.mmio',
'scipy.io.netcdf',
'scipy.linalg.basic',
'scipy.linalg.decomp',
'scipy.linalg.decomp_cholesky',
'scipy.linalg.decomp_lu',
'scipy.linalg.decomp_qr',
'scipy.linalg.decomp_schur',
'scipy.linalg.decomp_svd',
'scipy.linalg.matfuncs',
'scipy.linalg.misc',
'scipy.linalg.special_matrices',
'scipy.misc.common',
'scipy.misc.doccer',
'scipy.ndimage.filters',
'scipy.ndimage.fourier',
'scipy.ndimage.interpolation',
'scipy.ndimage.measurements',
'scipy.ndimage.morphology',
'scipy.odr.models',
'scipy.odr.odrpack',
'scipy.optimize.cobyla',
'scipy.optimize.cython_optimize',
'scipy.optimize.lbfgsb',
'scipy.optimize.linesearch',
'scipy.optimize.minpack',
'scipy.optimize.minpack2',
'scipy.optimize.moduleTNC',
'scipy.optimize.nonlin',
'scipy.optimize.optimize',
'scipy.optimize.slsqp',
'scipy.optimize.tnc',
'scipy.optimize.zeros',
'scipy.signal.bsplines',
'scipy.signal.filter_design',
'scipy.signal.fir_filter_design',
'scipy.signal.lti_conversion',
'scipy.signal.ltisys',
'scipy.signal.signaltools',
'scipy.signal.spectral',
'scipy.signal.spline',
'scipy.signal.waveforms',
'scipy.signal.wavelets',
'scipy.signal.windows.windows',
'scipy.sparse.base',
'scipy.sparse.bsr',
'scipy.sparse.compressed',
'scipy.sparse.construct',
'scipy.sparse.coo',
'scipy.sparse.csc',
'scipy.sparse.csr',
'scipy.sparse.data',
'scipy.sparse.dia',
'scipy.sparse.dok',
'scipy.sparse.extract',
'scipy.sparse.lil',
'scipy.sparse.linalg.dsolve',
'scipy.sparse.linalg.eigen',
'scipy.sparse.linalg.interface',
'scipy.sparse.linalg.isolve',
'scipy.sparse.linalg.matfuncs',
'scipy.sparse.sparsetools',
'scipy.sparse.spfuncs',
'scipy.sparse.sputils',
'scipy.spatial.ckdtree',
'scipy.spatial.kdtree',
'scipy.spatial.qhull',
'scipy.spatial.transform.rotation',
'scipy.special.add_newdocs',
'scipy.special.basic',
'scipy.special.cython_special',
'scipy.special.orthogonal',
'scipy.special.sf_error',
'scipy.special.specfun',
'scipy.special.spfun_stats',
'scipy.stats.biasedurn',
'scipy.stats.kde',
'scipy.stats.morestats',
'scipy.stats.mstats_basic',
'scipy.stats.mstats_extras',
'scipy.stats.mvn',
'scipy.stats.stats',
]
def is_unexpected(name):
"""Check if this needs to be considered."""
if '._' in name or '.tests' in name or '.setup' in name:
return False
if name in PUBLIC_MODULES:
return False
if name in PRIVATE_BUT_PRESENT_MODULES:
return False
return True
SKIP_LIST = [
'scipy.conftest',
'scipy.version',
'scipy.special.libsf_error_state'
]
# XXX: this test does more than it says on the tin - in using `pkgutil.walk_packages`,
# it will raise if it encounters any exceptions which are not handled by `ignore_errors`
# while attempting to import each discovered package.
# For now, `ignore_errors` only ignores what is necessary, but this could be expanded -
# for example, to all errors from private modules or git subpackages - if desired.
def test_all_modules_are_expected():
"""
Test that we don't add anything that looks like a new public module by
accident. Check is based on filenames.
"""
def ignore_errors(name):
# if versions of other array libraries are installed which are incompatible
# with the installed NumPy version, there can be errors on importing
# `array_api_compat`. This should only raise if SciPy is configured with
# that library as an available backend.
backends = {'cupy': 'cupy',
'pytorch': 'torch',
'dask.array': 'dask.array'}
for backend, dir_name in backends.items():
path = f'array_api_compat.{dir_name}'
if path in name and backend not in xp_available_backends:
return
raise
modnames = []
for _, modname, _ in pkgutil.walk_packages(path=scipy.__path__,
prefix=scipy.__name__ + '.',
onerror=ignore_errors):
if is_unexpected(modname) and modname not in SKIP_LIST:
# We have a name that is new. If that's on purpose, add it to
# PUBLIC_MODULES. We don't expect to have to add anything to
# PRIVATE_BUT_PRESENT_MODULES. Use an underscore in the name!
modnames.append(modname)
if modnames:
raise AssertionError(f'Found unexpected modules: {modnames}')
# Stuff that clearly shouldn't be in the API and is detected by the next test
# below
SKIP_LIST_2 = [
'scipy.char',
'scipy.rec',
'scipy.emath',
'scipy.math',
'scipy.random',
'scipy.ctypeslib',
'scipy.ma'
]
def test_all_modules_are_expected_2():
"""
Method checking all objects. The pkgutil-based method in
`test_all_modules_are_expected` does not catch imports into a namespace,
only filenames.
"""
def find_unexpected_members(mod_name):
members = []
module = importlib.import_module(mod_name)
if hasattr(module, '__all__'):
objnames = module.__all__
else:
objnames = dir(module)
for objname in objnames:
if not objname.startswith('_'):
fullobjname = mod_name + '.' + objname
if isinstance(getattr(module, objname), types.ModuleType):
if is_unexpected(fullobjname) and fullobjname not in SKIP_LIST_2:
members.append(fullobjname)
return members
unexpected_members = find_unexpected_members("scipy")
for modname in PUBLIC_MODULES:
unexpected_members.extend(find_unexpected_members(modname))
if unexpected_members:
raise AssertionError("Found unexpected object(s) that look like "
f"modules: {unexpected_members}")
def test_api_importable():
"""
Check that all submodules listed higher up in this file can be imported
Note that if a PRIVATE_BUT_PRESENT_MODULES entry goes missing, it may
simply need to be removed from the list (deprecation may or may not be
needed - apply common sense).
"""
def check_importable(module_name):
try:
importlib.import_module(module_name)
except (ImportError, AttributeError):
return False
return True
module_names = []
for module_name in PUBLIC_MODULES:
if not check_importable(module_name):
module_names.append(module_name)
if module_names:
raise AssertionError("Modules in the public API that cannot be "
f"imported: {module_names}")
with warnings.catch_warnings(record=True):
warnings.filterwarnings('always', category=DeprecationWarning)
warnings.filterwarnings('always', category=ImportWarning)
for module_name in PRIVATE_BUT_PRESENT_MODULES:
if not check_importable(module_name):
module_names.append(module_name)
if module_names:
raise AssertionError("Modules that are not really public but looked "
"public and can not be imported: "
f"{module_names}")
@pytest.mark.parametrize(("module_name", "correct_module"),
[('scipy.constants.codata', None),
('scipy.constants.constants', None),
('scipy.fftpack.basic', None),
('scipy.fftpack.helper', None),
('scipy.fftpack.pseudo_diffs', None),
('scipy.fftpack.realtransforms', None),
('scipy.integrate.dop', None),
('scipy.integrate.lsoda', None),
('scipy.integrate.odepack', None),
('scipy.integrate.quadpack', None),
('scipy.integrate.vode', None),
('scipy.interpolate.fitpack', None),
('scipy.interpolate.fitpack2', None),
('scipy.interpolate.interpolate', None),
('scipy.interpolate.ndgriddata', None),
('scipy.interpolate.polyint', None),
('scipy.interpolate.rbf', None),
('scipy.io.harwell_boeing', None),
('scipy.io.idl', None),
('scipy.io.mmio', None),
('scipy.io.netcdf', None),
('scipy.io.arff.arffread', 'arff'),
('scipy.io.matlab.byteordercodes', 'matlab'),
('scipy.io.matlab.mio_utils', 'matlab'),
('scipy.io.matlab.mio', 'matlab'),
('scipy.io.matlab.mio4', 'matlab'),
('scipy.io.matlab.mio5_params', 'matlab'),
('scipy.io.matlab.mio5_utils', 'matlab'),
('scipy.io.matlab.mio5', 'matlab'),
('scipy.io.matlab.miobase', 'matlab'),
('scipy.io.matlab.streams', 'matlab'),
('scipy.linalg.basic', None),
('scipy.linalg.decomp', None),
('scipy.linalg.decomp_cholesky', None),
('scipy.linalg.decomp_lu', None),
('scipy.linalg.decomp_qr', None),
('scipy.linalg.decomp_schur', None),
('scipy.linalg.decomp_svd', None),
('scipy.linalg.matfuncs', None),
('scipy.linalg.misc', None),
('scipy.linalg.special_matrices', None),
('scipy.misc.common', None),
('scipy.ndimage.filters', None),
('scipy.ndimage.fourier', None),
('scipy.ndimage.interpolation', None),
('scipy.ndimage.measurements', None),
('scipy.ndimage.morphology', None),
('scipy.odr.models', None),
('scipy.odr.odrpack', None),
('scipy.optimize.cobyla', None),
('scipy.optimize.lbfgsb', None),
('scipy.optimize.linesearch', None),
('scipy.optimize.minpack', None),
('scipy.optimize.minpack2', None),
('scipy.optimize.moduleTNC', None),
('scipy.optimize.nonlin', None),
('scipy.optimize.optimize', None),
('scipy.optimize.slsqp', None),
('scipy.optimize.tnc', None),
('scipy.optimize.zeros', None),
('scipy.signal.bsplines', None),
('scipy.signal.filter_design', None),
('scipy.signal.fir_filter_design', None),
('scipy.signal.lti_conversion', None),
('scipy.signal.ltisys', None),
('scipy.signal.signaltools', None),
('scipy.signal.spectral', None),
('scipy.signal.waveforms', None),
('scipy.signal.wavelets', None),
('scipy.signal.windows.windows', 'windows'),
('scipy.sparse.lil', None),
('scipy.sparse.linalg.dsolve', 'linalg'),
('scipy.sparse.linalg.eigen', 'linalg'),
('scipy.sparse.linalg.interface', 'linalg'),
('scipy.sparse.linalg.isolve', 'linalg'),
('scipy.sparse.linalg.matfuncs', 'linalg'),
('scipy.sparse.sparsetools', None),
('scipy.sparse.spfuncs', None),
('scipy.sparse.sputils', None),
('scipy.spatial.ckdtree', None),
('scipy.spatial.kdtree', None),
('scipy.spatial.qhull', None),
('scipy.spatial.transform.rotation', 'transform'),
('scipy.special.add_newdocs', None),
('scipy.special.basic', None),
('scipy.special.orthogonal', None),
('scipy.special.sf_error', None),
('scipy.special.specfun', None),
('scipy.special.spfun_stats', None),
('scipy.stats.biasedurn', None),
('scipy.stats.kde', None),
('scipy.stats.morestats', None),
('scipy.stats.mstats_basic', 'mstats'),
('scipy.stats.mstats_extras', 'mstats'),
('scipy.stats.mvn', None),
('scipy.stats.stats', None)])
def test_private_but_present_deprecation(module_name, correct_module):
# gh-18279, gh-17572, gh-17771 noted that deprecation warnings
# for imports from private modules
# were misleading. Check that this is resolved.
module = import_module(module_name)
if correct_module is None:
import_name = f'scipy.{module_name.split(".")[1]}'
else:
import_name = f'scipy.{module_name.split(".")[1]}.{correct_module}'
correct_import = import_module(import_name)
# Attributes that were formerly in `module_name` can still be imported from
# `module_name`, albeit with a deprecation warning.
for attr_name in module.__all__:
if attr_name == "varmats_from_mat":
# defer handling this case, see
# https://github.com/scipy/scipy/issues/19223
continue
# ensure attribute is present where the warning is pointing
assert getattr(correct_import, attr_name, None) is not None
message = f"Please import `{attr_name}` from the `{import_name}`..."
with pytest.deprecated_call(match=message):
getattr(module, attr_name)
# Attributes that were not in `module_name` get an error notifying the user
# that the attribute is not in `module_name` and that `module_name` is deprecated.
message = f"`{module_name}` is deprecated..."
with pytest.raises(AttributeError, match=message):
getattr(module, "ekki")
def test_misc_doccer_deprecation():
# gh-18279, gh-17572, gh-17771 noted that deprecation warnings
# for imports from private modules were misleading.
# Check that this is resolved.
# `test_private_but_present_deprecation` cannot be used since `correct_import`
# is a different subpackage (`_lib` instead of `misc`).
module = import_module('scipy.misc.doccer')
correct_import = import_module('scipy._lib.doccer')
# Attributes that were formerly in `scipy.misc.doccer` can still be imported from
# `scipy.misc.doccer`, albeit with a deprecation warning. The specific message
# depends on whether the attribute is in `scipy._lib.doccer` or not.
for attr_name in module.__all__:
attr = getattr(correct_import, attr_name, None)
if attr is None:
message = f"`scipy.misc.{attr_name}` is deprecated..."
else:
message = f"Please import `{attr_name}` from the `scipy._lib.doccer`..."
with pytest.deprecated_call(match=message):
getattr(module, attr_name)
# Attributes that were not in `scipy.misc.doccer` get an error
# notifying the user that the attribute is not in `scipy.misc.doccer`
# and that `scipy.misc.doccer` is deprecated.
message = "`scipy.misc.doccer` is deprecated..."
with pytest.raises(AttributeError, match=message):
getattr(module, "ekki")
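
The behavior `test_private_but_present_deprecation` pins down can be seen directly; a sketch, assuming `gaussian_filter` is among the names forwarded by the deprecated `scipy.ndimage.filters` namespace (as it historically was):

    import warnings
    from importlib import import_module

    mod = import_module('scipy.ndimage.filters')    # deprecated alias namespace
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter('always')
        mod.gaussian_filter                         # attribute access warns
    assert any(issubclass(wi.category, DeprecationWarning) for wi in w)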

View File

@@ -0,0 +1,18 @@
import re
import scipy
from numpy.testing import assert_
def test_valid_scipy_version():
# Verify that the SciPy version is a valid one (no .post suffix or other
# nonsense). See NumPy issue gh-6431 for an issue caused by an invalid
# version.
version_pattern = r"^[0-9]+\.[0-9]+\.[0-9]+(|a[0-9]|b[0-9]|rc[0-9])"
dev_suffix = r"(\.dev0\+.+([0-9a-f]{7}|Unknown))"
if scipy.version.release:
res = re.match(version_pattern, scipy.__version__)
else:
res = re.match(version_pattern + dev_suffix, scipy.__version__)
assert_(res is not None, scipy.__version__)
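
To see what the two regexes accept, a couple of hand-checked examples (the version strings are hypothetical, chosen to match the patterns defined above):

    import re

    version_pattern = r"^[0-9]+\.[0-9]+\.[0-9]+(|a[0-9]|b[0-9]|rc[0-9])"
    dev_suffix = r"(\.dev0\+.+([0-9a-f]{7}|Unknown))"

    assert re.match(version_pattern, '1.11.0rc1')            # release-style
    assert re.match(version_pattern + dev_suffix, '1.12.0.dev0+1234567abcdef')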

View File

@@ -0,0 +1,42 @@
""" Test tmpdirs module """
from os import getcwd
from os.path import realpath, abspath, dirname, isfile, join as pjoin, exists
from scipy._lib._tmpdirs import tempdir, in_tempdir, in_dir
from numpy.testing import assert_, assert_equal
MY_PATH = abspath(__file__)
MY_DIR = dirname(MY_PATH)
def test_tempdir():
with tempdir() as tmpdir:
fname = pjoin(tmpdir, 'example_file.txt')
with open(fname, "w") as fobj:
            fobj.write('a string\n')
assert_(not exists(tmpdir))
def test_in_tempdir():
my_cwd = getcwd()
with in_tempdir() as tmpdir:
with open('test.txt', "w") as f:
f.write('some text')
assert_(isfile('test.txt'))
assert_(isfile(pjoin(tmpdir, 'test.txt')))
assert_(not exists(tmpdir))
assert_equal(getcwd(), my_cwd)
def test_given_directory():
# Test InGivenDirectory
cwd = getcwd()
with in_dir() as tmpdir:
assert_equal(tmpdir, abspath(cwd))
assert_equal(tmpdir, abspath(getcwd()))
with in_dir(MY_DIR) as tmpdir:
assert_equal(tmpdir, MY_DIR)
assert_equal(realpath(MY_DIR), realpath(abspath(getcwd())))
    # We used to delete the given directory; check that it is left intact now.
assert_(isfile(MY_PATH))
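
The contract these tests establish for `in_tempdir`, as a sketch (the realpath comparison guards against symlinked temp directories, e.g. on macOS):

    import os
    from scipy._lib._tmpdirs import in_tempdir

    before = os.getcwd()
    with in_tempdir() as tmpdir:
        assert os.path.realpath(os.getcwd()) == os.path.realpath(tmpdir)
        open('scratch.txt', 'w').close()    # lands in the temporary directory
    assert os.getcwd() == before and not os.path.exists(tmpdir)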

View File

@@ -0,0 +1,135 @@
"""
Tests which scan for certain occurrences in the code. They may not find
all of these occurrences but should catch almost all. This file was adapted
from NumPy.
"""
import os
from pathlib import Path
import ast
import tokenize
import scipy
import pytest
class ParseCall(ast.NodeVisitor):
def __init__(self):
self.ls = []
def visit_Attribute(self, node):
ast.NodeVisitor.generic_visit(self, node)
self.ls.append(node.attr)
def visit_Name(self, node):
self.ls.append(node.id)
class FindFuncs(ast.NodeVisitor):
def __init__(self, filename):
super().__init__()
self.__filename = filename
self.bad_filters = []
self.bad_stacklevels = []
def visit_Call(self, node):
p = ParseCall()
p.visit(node.func)
ast.NodeVisitor.generic_visit(self, node)
if p.ls[-1] == 'simplefilter' or p.ls[-1] == 'filterwarnings':
# get first argument of the `args` node of the filter call
match node.args[0]:
case ast.Constant() as c:
argtext = c.value
case ast.JoinedStr() as js:
# if we get an f-string, discard the templated pieces, which
# are likely the type or specific message; we're interested
# in the action, which is less likely to use a template
argtext = "".join(
x.value for x in js.values if isinstance(x, ast.Constant)
)
case _:
raise ValueError("unknown ast node type")
# check if filter is set to ignore
if argtext == "ignore":
self.bad_filters.append(
f"{self.__filename}:{node.lineno}")
if p.ls[-1] == 'warn' and (
len(p.ls) == 1 or p.ls[-2] == 'warnings'):
if self.__filename == "_lib/tests/test_warnings.py":
                # This file is the test module itself; skip it
return
# See if stacklevel exists:
if len(node.args) == 3:
return
args = {kw.arg for kw in node.keywords}
if "stacklevel" not in args:
self.bad_stacklevels.append(
f"{self.__filename}:{node.lineno}")
@pytest.fixture(scope="session")
def warning_calls():
# combined "ignore" and stacklevel error
base = Path(scipy.__file__).parent
bad_filters = []
bad_stacklevels = []
for path in base.rglob("*.py"):
# use tokenize to auto-detect encoding on systems where no
# default encoding is defined (e.g., LANG='C')
with tokenize.open(str(path)) as file:
tree = ast.parse(file.read(), filename=str(path))
finder = FindFuncs(path.relative_to(base))
finder.visit(tree)
bad_filters.extend(finder.bad_filters)
bad_stacklevels.extend(finder.bad_stacklevels)
return bad_filters, bad_stacklevels
@pytest.mark.fail_slow(20)
@pytest.mark.slow
def test_warning_calls_filters(warning_calls):
bad_filters, bad_stacklevels = warning_calls
# We try not to add filters in the code base, because those filters aren't
# thread-safe. We aim to only filter in tests with
# np.testing.suppress_warnings. However, in some cases it may prove
# necessary to filter out warnings, because we can't (easily) fix the root
# cause for them and we don't want users to see some warnings when they use
# SciPy correctly. So we list exceptions here. Add new entries only if
# there's a good reason.
allowed_filters = (
os.path.join('datasets', '_fetchers.py'),
os.path.join('datasets', '__init__.py'),
os.path.join('optimize', '_optimize.py'),
os.path.join('optimize', '_constraints.py'),
os.path.join('optimize', '_nnls.py'),
os.path.join('signal', '_ltisys.py'),
os.path.join('sparse', '__init__.py'), # np.matrix pending-deprecation
os.path.join('stats', '_discrete_distns.py'), # gh-14901
os.path.join('stats', '_continuous_distns.py'),
os.path.join('stats', '_binned_statistic.py'), # gh-19345
os.path.join('stats', 'tests', 'test_axis_nan_policy.py'), # gh-20694
os.path.join('_lib', '_util.py'), # gh-19341
os.path.join('sparse', 'linalg', '_dsolve', 'linsolve.py'), # gh-17924
"conftest.py",
)
bad_filters = [item for item in bad_filters if item.split(':')[0] not in
allowed_filters]
if bad_filters:
raise AssertionError(
"warning ignore filter should not be used, instead, use\n"
"numpy.testing.suppress_warnings (in tests only);\n"
"found in:\n {}".format(
"\n ".join(bad_filters)))