some new features
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,101 @@
|
||||
""" Test for assert_deallocated context manager and gc utilities
|
||||
"""
|
||||
import gc
|
||||
|
||||
from scipy._lib._gcutils import (set_gc_state, gc_state, assert_deallocated,
|
||||
ReferenceError, IS_PYPY)
|
||||
|
||||
from numpy.testing import assert_equal
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def test_set_gc_state():
|
||||
gc_status = gc.isenabled()
|
||||
try:
|
||||
for state in (True, False):
|
||||
gc.enable()
|
||||
set_gc_state(state)
|
||||
assert_equal(gc.isenabled(), state)
|
||||
gc.disable()
|
||||
set_gc_state(state)
|
||||
assert_equal(gc.isenabled(), state)
|
||||
finally:
|
||||
if gc_status:
|
||||
gc.enable()
|
||||
|
||||
|
||||
def test_gc_state():
|
||||
# Test gc_state context manager
|
||||
gc_status = gc.isenabled()
|
||||
try:
|
||||
for pre_state in (True, False):
|
||||
set_gc_state(pre_state)
|
||||
for with_state in (True, False):
|
||||
# Check the gc state is with_state in with block
|
||||
with gc_state(with_state):
|
||||
assert_equal(gc.isenabled(), with_state)
|
||||
# And returns to previous state outside block
|
||||
assert_equal(gc.isenabled(), pre_state)
|
||||
# Even if the gc state is set explicitly within the block
|
||||
with gc_state(with_state):
|
||||
assert_equal(gc.isenabled(), with_state)
|
||||
set_gc_state(not with_state)
|
||||
assert_equal(gc.isenabled(), pre_state)
|
||||
finally:
|
||||
if gc_status:
|
||||
gc.enable()
|
||||
|
||||
|
||||
@pytest.mark.skipif(IS_PYPY, reason="Test not meaningful on PyPy")
|
||||
def test_assert_deallocated():
|
||||
# Ordinary use
|
||||
class C:
|
||||
def __init__(self, arg0, arg1, name='myname'):
|
||||
self.name = name
|
||||
for gc_current in (True, False):
|
||||
with gc_state(gc_current):
|
||||
# We are deleting from with-block context, so that's OK
|
||||
with assert_deallocated(C, 0, 2, 'another name') as c:
|
||||
assert_equal(c.name, 'another name')
|
||||
del c
|
||||
# Or not using the thing in with-block context, also OK
|
||||
with assert_deallocated(C, 0, 2, name='third name'):
|
||||
pass
|
||||
assert_equal(gc.isenabled(), gc_current)
|
||||
|
||||
|
||||
@pytest.mark.skipif(IS_PYPY, reason="Test not meaningful on PyPy")
|
||||
def test_assert_deallocated_nodel():
|
||||
class C:
|
||||
pass
|
||||
with pytest.raises(ReferenceError):
|
||||
# Need to delete after using if in with-block context
|
||||
# Note: assert_deallocated(C) needs to be assigned for the test
|
||||
# to function correctly. It is assigned to _, but _ itself is
|
||||
# not referenced in the body of the with, it is only there for
|
||||
# the refcount.
|
||||
with assert_deallocated(C) as _:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.skipif(IS_PYPY, reason="Test not meaningful on PyPy")
|
||||
def test_assert_deallocated_circular():
|
||||
class C:
|
||||
def __init__(self):
|
||||
self._circular = self
|
||||
with pytest.raises(ReferenceError):
|
||||
# Circular reference, no automatic garbage collection
|
||||
with assert_deallocated(C) as c:
|
||||
del c
|
||||
|
||||
|
||||
@pytest.mark.skipif(IS_PYPY, reason="Test not meaningful on PyPy")
|
||||
def test_assert_deallocated_circular2():
|
||||
class C:
|
||||
def __init__(self):
|
||||
self._circular = self
|
||||
with pytest.raises(ReferenceError):
|
||||
# Still circular reference, no automatic garbage collection
|
||||
with assert_deallocated(C):
|
||||
pass
|
||||
@ -0,0 +1,67 @@
|
||||
from pytest import raises as assert_raises
|
||||
from scipy._lib._pep440 import Version, parse
|
||||
|
||||
|
||||
def test_main_versions():
|
||||
assert Version('1.8.0') == Version('1.8.0')
|
||||
for ver in ['1.9.0', '2.0.0', '1.8.1']:
|
||||
assert Version('1.8.0') < Version(ver)
|
||||
|
||||
for ver in ['1.7.0', '1.7.1', '0.9.9']:
|
||||
assert Version('1.8.0') > Version(ver)
|
||||
|
||||
|
||||
def test_version_1_point_10():
|
||||
# regression test for gh-2998.
|
||||
assert Version('1.9.0') < Version('1.10.0')
|
||||
assert Version('1.11.0') < Version('1.11.1')
|
||||
assert Version('1.11.0') == Version('1.11.0')
|
||||
assert Version('1.99.11') < Version('1.99.12')
|
||||
|
||||
|
||||
def test_alpha_beta_rc():
|
||||
assert Version('1.8.0rc1') == Version('1.8.0rc1')
|
||||
for ver in ['1.8.0', '1.8.0rc2']:
|
||||
assert Version('1.8.0rc1') < Version(ver)
|
||||
|
||||
for ver in ['1.8.0a2', '1.8.0b3', '1.7.2rc4']:
|
||||
assert Version('1.8.0rc1') > Version(ver)
|
||||
|
||||
assert Version('1.8.0b1') > Version('1.8.0a2')
|
||||
|
||||
|
||||
def test_dev_version():
|
||||
assert Version('1.9.0.dev+Unknown') < Version('1.9.0')
|
||||
for ver in ['1.9.0', '1.9.0a1', '1.9.0b2', '1.9.0b2.dev+ffffffff', '1.9.0.dev1']:
|
||||
assert Version('1.9.0.dev+f16acvda') < Version(ver)
|
||||
|
||||
assert Version('1.9.0.dev+f16acvda') == Version('1.9.0.dev+f16acvda')
|
||||
|
||||
|
||||
def test_dev_a_b_rc_mixed():
|
||||
assert Version('1.9.0a2.dev+f16acvda') == Version('1.9.0a2.dev+f16acvda')
|
||||
assert Version('1.9.0a2.dev+6acvda54') < Version('1.9.0a2')
|
||||
|
||||
|
||||
def test_dev0_version():
|
||||
assert Version('1.9.0.dev0+Unknown') < Version('1.9.0')
|
||||
for ver in ['1.9.0', '1.9.0a1', '1.9.0b2', '1.9.0b2.dev0+ffffffff']:
|
||||
assert Version('1.9.0.dev0+f16acvda') < Version(ver)
|
||||
|
||||
assert Version('1.9.0.dev0+f16acvda') == Version('1.9.0.dev0+f16acvda')
|
||||
|
||||
|
||||
def test_dev0_a_b_rc_mixed():
|
||||
assert Version('1.9.0a2.dev0+f16acvda') == Version('1.9.0a2.dev0+f16acvda')
|
||||
assert Version('1.9.0a2.dev0+6acvda54') < Version('1.9.0a2')
|
||||
|
||||
|
||||
def test_raises():
|
||||
for ver in ['1,9.0', '1.7.x']:
|
||||
assert_raises(ValueError, Version, ver)
|
||||
|
||||
def test_legacy_version():
|
||||
# Non-PEP-440 version identifiers always compare less. For NumPy this only
|
||||
# occurs on dev builds prior to 1.10.0 which are unsupported anyway.
|
||||
assert parse('invalid') < Version('0.0.0')
|
||||
assert parse('1.9.0-f16acvda') < Version('1.0.0')
|
||||
@ -0,0 +1,32 @@
|
||||
import sys
|
||||
from scipy._lib._testutils import _parse_size, _get_mem_available
|
||||
import pytest
|
||||
|
||||
|
||||
def test__parse_size():
|
||||
expected = {
|
||||
'12': 12e6,
|
||||
'12 b': 12,
|
||||
'12k': 12e3,
|
||||
' 12 M ': 12e6,
|
||||
' 12 G ': 12e9,
|
||||
' 12Tb ': 12e12,
|
||||
'12 Mib ': 12 * 1024.0**2,
|
||||
'12Tib': 12 * 1024.0**4,
|
||||
}
|
||||
|
||||
for inp, outp in sorted(expected.items()):
|
||||
if outp is None:
|
||||
with pytest.raises(ValueError):
|
||||
_parse_size(inp)
|
||||
else:
|
||||
assert _parse_size(inp) == outp
|
||||
|
||||
|
||||
def test__mem_available():
|
||||
# May return None on non-Linux platforms
|
||||
available = _get_mem_available()
|
||||
if sys.platform.startswith('linux'):
|
||||
assert available >= 0
|
||||
else:
|
||||
assert available is None or available >= 0
|
||||
@ -0,0 +1,51 @@
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
|
||||
from numpy.testing import assert_
|
||||
from pytest import raises as assert_raises
|
||||
|
||||
from scipy._lib._threadsafety import ReentrancyLock, non_reentrant, ReentrancyError
|
||||
|
||||
|
||||
def test_parallel_threads():
|
||||
# Check that ReentrancyLock serializes work in parallel threads.
|
||||
#
|
||||
# The test is not fully deterministic, and may succeed falsely if
|
||||
# the timings go wrong.
|
||||
|
||||
lock = ReentrancyLock("failure")
|
||||
|
||||
failflag = [False]
|
||||
exceptions_raised = []
|
||||
|
||||
def worker(k):
|
||||
try:
|
||||
with lock:
|
||||
assert_(not failflag[0])
|
||||
failflag[0] = True
|
||||
time.sleep(0.1 * k)
|
||||
assert_(failflag[0])
|
||||
failflag[0] = False
|
||||
except Exception:
|
||||
exceptions_raised.append(traceback.format_exc(2))
|
||||
|
||||
threads = [threading.Thread(target=lambda k=k: worker(k))
|
||||
for k in range(3)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
exceptions_raised = "\n".join(exceptions_raised)
|
||||
assert_(not exceptions_raised, exceptions_raised)
|
||||
|
||||
|
||||
def test_reentering():
|
||||
# Check that ReentrancyLock prevents re-entering from the same thread.
|
||||
|
||||
@non_reentrant()
|
||||
def func(x):
|
||||
return func(x)
|
||||
|
||||
assert_raises(ReentrancyError, func, 0)
|
||||
@ -0,0 +1,447 @@
|
||||
from multiprocessing import Pool
|
||||
from multiprocessing.pool import Pool as PWL
|
||||
import re
|
||||
import math
|
||||
from fractions import Fraction
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import assert_equal, assert_
|
||||
import pytest
|
||||
from pytest import raises as assert_raises
|
||||
import hypothesis.extra.numpy as npst
|
||||
from hypothesis import given, strategies, reproduce_failure # noqa: F401
|
||||
from scipy.conftest import array_api_compatible, skip_xp_invalid_arg
|
||||
|
||||
from scipy._lib._array_api import (xp_assert_equal, xp_assert_close, is_numpy,
|
||||
copy as xp_copy)
|
||||
from scipy._lib._util import (_aligned_zeros, check_random_state, MapWrapper,
|
||||
getfullargspec_no_self, FullArgSpec,
|
||||
rng_integers, _validate_int, _rename_parameter,
|
||||
_contains_nan, _rng_html_rewrite, _lazywhere)
|
||||
|
||||
skip_xp_backends = pytest.mark.skip_xp_backends
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test__aligned_zeros():
|
||||
niter = 10
|
||||
|
||||
def check(shape, dtype, order, align):
|
||||
err_msg = repr((shape, dtype, order, align))
|
||||
x = _aligned_zeros(shape, dtype, order, align=align)
|
||||
if align is None:
|
||||
align = np.dtype(dtype).alignment
|
||||
assert_equal(x.__array_interface__['data'][0] % align, 0)
|
||||
if hasattr(shape, '__len__'):
|
||||
assert_equal(x.shape, shape, err_msg)
|
||||
else:
|
||||
assert_equal(x.shape, (shape,), err_msg)
|
||||
assert_equal(x.dtype, dtype)
|
||||
if order == "C":
|
||||
assert_(x.flags.c_contiguous, err_msg)
|
||||
elif order == "F":
|
||||
if x.size > 0:
|
||||
# Size-0 arrays get invalid flags on NumPy 1.5
|
||||
assert_(x.flags.f_contiguous, err_msg)
|
||||
elif order is None:
|
||||
assert_(x.flags.c_contiguous, err_msg)
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
# try various alignments
|
||||
for align in [1, 2, 3, 4, 8, 16, 32, 64, None]:
|
||||
for n in [0, 1, 3, 11]:
|
||||
for order in ["C", "F", None]:
|
||||
for dtype in [np.uint8, np.float64]:
|
||||
for shape in [n, (1, 2, 3, n)]:
|
||||
for j in range(niter):
|
||||
check(shape, dtype, order, align)
|
||||
|
||||
|
||||
def test_check_random_state():
|
||||
# If seed is None, return the RandomState singleton used by np.random.
|
||||
# If seed is an int, return a new RandomState instance seeded with seed.
|
||||
# If seed is already a RandomState instance, return it.
|
||||
# Otherwise raise ValueError.
|
||||
rsi = check_random_state(1)
|
||||
assert_equal(type(rsi), np.random.RandomState)
|
||||
rsi = check_random_state(rsi)
|
||||
assert_equal(type(rsi), np.random.RandomState)
|
||||
rsi = check_random_state(None)
|
||||
assert_equal(type(rsi), np.random.RandomState)
|
||||
assert_raises(ValueError, check_random_state, 'a')
|
||||
rg = np.random.Generator(np.random.PCG64())
|
||||
rsi = check_random_state(rg)
|
||||
assert_equal(type(rsi), np.random.Generator)
|
||||
|
||||
|
||||
def test_getfullargspec_no_self():
|
||||
p = MapWrapper(1)
|
||||
argspec = getfullargspec_no_self(p.__init__)
|
||||
assert_equal(argspec, FullArgSpec(['pool'], None, None, (1,), [],
|
||||
None, {}))
|
||||
argspec = getfullargspec_no_self(p.__call__)
|
||||
assert_equal(argspec, FullArgSpec(['func', 'iterable'], None, None, None,
|
||||
[], None, {}))
|
||||
|
||||
class _rv_generic:
|
||||
def _rvs(self, a, b=2, c=3, *args, size=None, **kwargs):
|
||||
return None
|
||||
|
||||
rv_obj = _rv_generic()
|
||||
argspec = getfullargspec_no_self(rv_obj._rvs)
|
||||
assert_equal(argspec, FullArgSpec(['a', 'b', 'c'], 'args', 'kwargs',
|
||||
(2, 3), ['size'], {'size': None}, {}))
|
||||
|
||||
|
||||
def test_mapwrapper_serial():
|
||||
in_arg = np.arange(10.)
|
||||
out_arg = np.sin(in_arg)
|
||||
|
||||
p = MapWrapper(1)
|
||||
assert_(p._mapfunc is map)
|
||||
assert_(p.pool is None)
|
||||
assert_(p._own_pool is False)
|
||||
out = list(p(np.sin, in_arg))
|
||||
assert_equal(out, out_arg)
|
||||
|
||||
with assert_raises(RuntimeError):
|
||||
p = MapWrapper(0)
|
||||
|
||||
|
||||
def test_pool():
|
||||
with Pool(2) as p:
|
||||
p.map(math.sin, [1, 2, 3, 4])
|
||||
|
||||
|
||||
def test_mapwrapper_parallel():
|
||||
in_arg = np.arange(10.)
|
||||
out_arg = np.sin(in_arg)
|
||||
|
||||
with MapWrapper(2) as p:
|
||||
out = p(np.sin, in_arg)
|
||||
assert_equal(list(out), out_arg)
|
||||
|
||||
assert_(p._own_pool is True)
|
||||
assert_(isinstance(p.pool, PWL))
|
||||
assert_(p._mapfunc is not None)
|
||||
|
||||
# the context manager should've closed the internal pool
|
||||
# check that it has by asking it to calculate again.
|
||||
with assert_raises(Exception) as excinfo:
|
||||
p(np.sin, in_arg)
|
||||
|
||||
assert_(excinfo.type is ValueError)
|
||||
|
||||
# can also set a PoolWrapper up with a map-like callable instance
|
||||
with Pool(2) as p:
|
||||
q = MapWrapper(p.map)
|
||||
|
||||
assert_(q._own_pool is False)
|
||||
q.close()
|
||||
|
||||
# closing the PoolWrapper shouldn't close the internal pool
|
||||
# because it didn't create it
|
||||
out = p.map(np.sin, in_arg)
|
||||
assert_equal(list(out), out_arg)
|
||||
|
||||
|
||||
def test_rng_integers():
|
||||
rng = np.random.RandomState()
|
||||
|
||||
# test that numbers are inclusive of high point
|
||||
arr = rng_integers(rng, low=2, high=5, size=100, endpoint=True)
|
||||
assert np.max(arr) == 5
|
||||
assert np.min(arr) == 2
|
||||
assert arr.shape == (100, )
|
||||
|
||||
# test that numbers are inclusive of high point
|
||||
arr = rng_integers(rng, low=5, size=100, endpoint=True)
|
||||
assert np.max(arr) == 5
|
||||
assert np.min(arr) == 0
|
||||
assert arr.shape == (100, )
|
||||
|
||||
# test that numbers are exclusive of high point
|
||||
arr = rng_integers(rng, low=2, high=5, size=100, endpoint=False)
|
||||
assert np.max(arr) == 4
|
||||
assert np.min(arr) == 2
|
||||
assert arr.shape == (100, )
|
||||
|
||||
# test that numbers are exclusive of high point
|
||||
arr = rng_integers(rng, low=5, size=100, endpoint=False)
|
||||
assert np.max(arr) == 4
|
||||
assert np.min(arr) == 0
|
||||
assert arr.shape == (100, )
|
||||
|
||||
# now try with np.random.Generator
|
||||
try:
|
||||
rng = np.random.default_rng()
|
||||
except AttributeError:
|
||||
return
|
||||
|
||||
# test that numbers are inclusive of high point
|
||||
arr = rng_integers(rng, low=2, high=5, size=100, endpoint=True)
|
||||
assert np.max(arr) == 5
|
||||
assert np.min(arr) == 2
|
||||
assert arr.shape == (100, )
|
||||
|
||||
# test that numbers are inclusive of high point
|
||||
arr = rng_integers(rng, low=5, size=100, endpoint=True)
|
||||
assert np.max(arr) == 5
|
||||
assert np.min(arr) == 0
|
||||
assert arr.shape == (100, )
|
||||
|
||||
# test that numbers are exclusive of high point
|
||||
arr = rng_integers(rng, low=2, high=5, size=100, endpoint=False)
|
||||
assert np.max(arr) == 4
|
||||
assert np.min(arr) == 2
|
||||
assert arr.shape == (100, )
|
||||
|
||||
# test that numbers are exclusive of high point
|
||||
arr = rng_integers(rng, low=5, size=100, endpoint=False)
|
||||
assert np.max(arr) == 4
|
||||
assert np.min(arr) == 0
|
||||
assert arr.shape == (100, )
|
||||
|
||||
|
||||
class TestValidateInt:
|
||||
|
||||
@pytest.mark.parametrize('n', [4, np.uint8(4), np.int16(4), np.array(4)])
|
||||
def test_validate_int(self, n):
|
||||
n = _validate_int(n, 'n')
|
||||
assert n == 4
|
||||
|
||||
@pytest.mark.parametrize('n', [4.0, np.array([4]), Fraction(4, 1)])
|
||||
def test_validate_int_bad(self, n):
|
||||
with pytest.raises(TypeError, match='n must be an integer'):
|
||||
_validate_int(n, 'n')
|
||||
|
||||
def test_validate_int_below_min(self):
|
||||
with pytest.raises(ValueError, match='n must be an integer not '
|
||||
'less than 0'):
|
||||
_validate_int(-1, 'n', 0)
|
||||
|
||||
|
||||
class TestRenameParameter:
|
||||
# check that wrapper `_rename_parameter` for backward-compatible
|
||||
# keyword renaming works correctly
|
||||
|
||||
# Example method/function that still accepts keyword `old`
|
||||
@_rename_parameter("old", "new")
|
||||
def old_keyword_still_accepted(self, new):
|
||||
return new
|
||||
|
||||
# Example method/function for which keyword `old` is deprecated
|
||||
@_rename_parameter("old", "new", dep_version="1.9.0")
|
||||
def old_keyword_deprecated(self, new):
|
||||
return new
|
||||
|
||||
def test_old_keyword_still_accepted(self):
|
||||
# positional argument and both keyword work identically
|
||||
res1 = self.old_keyword_still_accepted(10)
|
||||
res2 = self.old_keyword_still_accepted(new=10)
|
||||
res3 = self.old_keyword_still_accepted(old=10)
|
||||
assert res1 == res2 == res3 == 10
|
||||
|
||||
# unexpected keyword raises an error
|
||||
message = re.escape("old_keyword_still_accepted() got an unexpected")
|
||||
with pytest.raises(TypeError, match=message):
|
||||
self.old_keyword_still_accepted(unexpected=10)
|
||||
|
||||
# multiple values for the same parameter raises an error
|
||||
message = re.escape("old_keyword_still_accepted() got multiple")
|
||||
with pytest.raises(TypeError, match=message):
|
||||
self.old_keyword_still_accepted(10, new=10)
|
||||
with pytest.raises(TypeError, match=message):
|
||||
self.old_keyword_still_accepted(10, old=10)
|
||||
with pytest.raises(TypeError, match=message):
|
||||
self.old_keyword_still_accepted(new=10, old=10)
|
||||
|
||||
def test_old_keyword_deprecated(self):
|
||||
# positional argument and both keyword work identically,
|
||||
# but use of old keyword results in DeprecationWarning
|
||||
dep_msg = "Use of keyword argument `old` is deprecated"
|
||||
res1 = self.old_keyword_deprecated(10)
|
||||
res2 = self.old_keyword_deprecated(new=10)
|
||||
with pytest.warns(DeprecationWarning, match=dep_msg):
|
||||
res3 = self.old_keyword_deprecated(old=10)
|
||||
assert res1 == res2 == res3 == 10
|
||||
|
||||
# unexpected keyword raises an error
|
||||
message = re.escape("old_keyword_deprecated() got an unexpected")
|
||||
with pytest.raises(TypeError, match=message):
|
||||
self.old_keyword_deprecated(unexpected=10)
|
||||
|
||||
# multiple values for the same parameter raises an error and,
|
||||
# if old keyword is used, results in DeprecationWarning
|
||||
message = re.escape("old_keyword_deprecated() got multiple")
|
||||
with pytest.raises(TypeError, match=message):
|
||||
self.old_keyword_deprecated(10, new=10)
|
||||
with pytest.raises(TypeError, match=message), \
|
||||
pytest.warns(DeprecationWarning, match=dep_msg):
|
||||
self.old_keyword_deprecated(10, old=10)
|
||||
with pytest.raises(TypeError, match=message), \
|
||||
pytest.warns(DeprecationWarning, match=dep_msg):
|
||||
self.old_keyword_deprecated(new=10, old=10)
|
||||
|
||||
|
||||
class TestContainsNaNTest:
|
||||
|
||||
def test_policy(self):
|
||||
data = np.array([1, 2, 3, np.nan])
|
||||
|
||||
contains_nan, nan_policy = _contains_nan(data, nan_policy="propagate")
|
||||
assert contains_nan
|
||||
assert nan_policy == "propagate"
|
||||
|
||||
contains_nan, nan_policy = _contains_nan(data, nan_policy="omit")
|
||||
assert contains_nan
|
||||
assert nan_policy == "omit"
|
||||
|
||||
msg = "The input contains nan values"
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
_contains_nan(data, nan_policy="raise")
|
||||
|
||||
msg = "nan_policy must be one of"
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
_contains_nan(data, nan_policy="nan")
|
||||
|
||||
def test_contains_nan(self):
|
||||
data1 = np.array([1, 2, 3])
|
||||
assert not _contains_nan(data1)[0]
|
||||
|
||||
data2 = np.array([1, 2, 3, np.nan])
|
||||
assert _contains_nan(data2)[0]
|
||||
|
||||
data3 = np.array([np.nan, 2, 3, np.nan])
|
||||
assert _contains_nan(data3)[0]
|
||||
|
||||
data4 = np.array([[1, 2], [3, 4]])
|
||||
assert not _contains_nan(data4)[0]
|
||||
|
||||
data5 = np.array([[1, 2], [3, np.nan]])
|
||||
assert _contains_nan(data5)[0]
|
||||
|
||||
@skip_xp_invalid_arg
|
||||
def test_contains_nan_with_strings(self):
|
||||
data1 = np.array([1, 2, "3", np.nan]) # converted to string "nan"
|
||||
assert not _contains_nan(data1)[0]
|
||||
|
||||
data2 = np.array([1, 2, "3", np.nan], dtype='object')
|
||||
assert _contains_nan(data2)[0]
|
||||
|
||||
data3 = np.array([["1", 2], [3, np.nan]]) # converted to string "nan"
|
||||
assert not _contains_nan(data3)[0]
|
||||
|
||||
data4 = np.array([["1", 2], [3, np.nan]], dtype='object')
|
||||
assert _contains_nan(data4)[0]
|
||||
|
||||
@skip_xp_backends('jax.numpy',
|
||||
reasons=["JAX arrays do not support item assignment"])
|
||||
@pytest.mark.usefixtures("skip_xp_backends")
|
||||
@array_api_compatible
|
||||
@pytest.mark.parametrize("nan_policy", ['propagate', 'omit', 'raise'])
|
||||
def test_array_api(self, xp, nan_policy):
|
||||
rng = np.random.default_rng(932347235892482)
|
||||
x0 = rng.random(size=(2, 3, 4))
|
||||
x = xp.asarray(x0)
|
||||
x_nan = xp_copy(x, xp=xp)
|
||||
x_nan[1, 2, 1] = np.nan
|
||||
|
||||
contains_nan, nan_policy_out = _contains_nan(x, nan_policy=nan_policy)
|
||||
assert not contains_nan
|
||||
assert nan_policy_out == nan_policy
|
||||
|
||||
if nan_policy == 'raise':
|
||||
message = 'The input contains...'
|
||||
with pytest.raises(ValueError, match=message):
|
||||
_contains_nan(x_nan, nan_policy=nan_policy)
|
||||
elif nan_policy == 'omit' and not is_numpy(xp):
|
||||
message = "`nan_policy='omit' is incompatible..."
|
||||
with pytest.raises(ValueError, match=message):
|
||||
_contains_nan(x_nan, nan_policy=nan_policy)
|
||||
elif nan_policy == 'propagate':
|
||||
contains_nan, nan_policy_out = _contains_nan(
|
||||
x_nan, nan_policy=nan_policy)
|
||||
assert contains_nan
|
||||
assert nan_policy_out == nan_policy
|
||||
|
||||
|
||||
def test__rng_html_rewrite():
|
||||
def mock_str():
|
||||
lines = [
|
||||
'np.random.default_rng(8989843)',
|
||||
'np.random.default_rng(seed)',
|
||||
'np.random.default_rng(0x9a71b21474694f919882289dc1559ca)',
|
||||
' bob ',
|
||||
]
|
||||
return lines
|
||||
|
||||
res = _rng_html_rewrite(mock_str)()
|
||||
ref = [
|
||||
'np.random.default_rng()',
|
||||
'np.random.default_rng(seed)',
|
||||
'np.random.default_rng()',
|
||||
' bob ',
|
||||
]
|
||||
|
||||
assert res == ref
|
||||
|
||||
|
||||
class TestLazywhere:
|
||||
n_arrays = strategies.integers(min_value=1, max_value=3)
|
||||
rng_seed = strategies.integers(min_value=1000000000, max_value=9999999999)
|
||||
dtype = strategies.sampled_from((np.float32, np.float64))
|
||||
p = strategies.floats(min_value=0, max_value=1)
|
||||
data = strategies.data()
|
||||
|
||||
@pytest.mark.fail_slow(5)
|
||||
@pytest.mark.filterwarnings('ignore::RuntimeWarning') # overflows, etc.
|
||||
@skip_xp_backends('jax.numpy',
|
||||
reasons=["JAX arrays do not support item assignment"])
|
||||
@pytest.mark.usefixtures("skip_xp_backends")
|
||||
@array_api_compatible
|
||||
@given(n_arrays=n_arrays, rng_seed=rng_seed, dtype=dtype, p=p, data=data)
|
||||
def test_basic(self, n_arrays, rng_seed, dtype, p, data, xp):
|
||||
mbs = npst.mutually_broadcastable_shapes(num_shapes=n_arrays+1,
|
||||
min_side=0)
|
||||
input_shapes, result_shape = data.draw(mbs)
|
||||
cond_shape, *shapes = input_shapes
|
||||
fillvalue = xp.asarray(data.draw(npst.arrays(dtype=dtype, shape=tuple())))
|
||||
arrays = [xp.asarray(data.draw(npst.arrays(dtype=dtype, shape=shape)))
|
||||
for shape in shapes]
|
||||
|
||||
def f(*args):
|
||||
return sum(arg for arg in args)
|
||||
|
||||
def f2(*args):
|
||||
return sum(arg for arg in args) / 2
|
||||
|
||||
rng = np.random.default_rng(rng_seed)
|
||||
cond = xp.asarray(rng.random(size=cond_shape) > p)
|
||||
|
||||
res1 = _lazywhere(cond, arrays, f, fillvalue)
|
||||
res2 = _lazywhere(cond, arrays, f, f2=f2)
|
||||
|
||||
# Ensure arrays are at least 1d to follow sane type promotion rules.
|
||||
if xp == np:
|
||||
cond, fillvalue, *arrays = np.atleast_1d(cond, fillvalue, *arrays)
|
||||
|
||||
ref1 = xp.where(cond, f(*arrays), fillvalue)
|
||||
ref2 = xp.where(cond, f(*arrays), f2(*arrays))
|
||||
|
||||
if xp == np:
|
||||
ref1 = ref1.reshape(result_shape)
|
||||
ref2 = ref2.reshape(result_shape)
|
||||
res1 = xp.asarray(res1)[()]
|
||||
res2 = xp.asarray(res2)[()]
|
||||
|
||||
isinstance(res1, type(xp.asarray([])))
|
||||
xp_assert_close(res1, ref1, rtol=2e-16)
|
||||
assert_equal(res1.shape, ref1.shape)
|
||||
assert_equal(res1.dtype, ref1.dtype)
|
||||
|
||||
isinstance(res2, type(xp.asarray([])))
|
||||
xp_assert_equal(res2, ref2)
|
||||
assert_equal(res2.shape, ref2.shape)
|
||||
assert_equal(res2.dtype, ref2.dtype)
|
||||
@ -0,0 +1,114 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from scipy.conftest import array_api_compatible
|
||||
from scipy._lib._array_api import (
|
||||
_GLOBAL_CONFIG, array_namespace, _asarray, copy, xp_assert_equal, is_numpy
|
||||
)
|
||||
import scipy._lib.array_api_compat.numpy as np_compat
|
||||
|
||||
skip_xp_backends = pytest.mark.skip_xp_backends
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _GLOBAL_CONFIG["SCIPY_ARRAY_API"],
|
||||
reason="Array API test; set environment variable SCIPY_ARRAY_API=1 to run it")
|
||||
class TestArrayAPI:
|
||||
|
||||
def test_array_namespace(self):
|
||||
x, y = np.array([0, 1, 2]), np.array([0, 1, 2])
|
||||
xp = array_namespace(x, y)
|
||||
assert 'array_api_compat.numpy' in xp.__name__
|
||||
|
||||
_GLOBAL_CONFIG["SCIPY_ARRAY_API"] = False
|
||||
xp = array_namespace(x, y)
|
||||
assert 'array_api_compat.numpy' in xp.__name__
|
||||
_GLOBAL_CONFIG["SCIPY_ARRAY_API"] = True
|
||||
|
||||
@array_api_compatible
|
||||
def test_asarray(self, xp):
|
||||
x, y = _asarray([0, 1, 2], xp=xp), _asarray(np.arange(3), xp=xp)
|
||||
ref = xp.asarray([0, 1, 2])
|
||||
xp_assert_equal(x, ref)
|
||||
xp_assert_equal(y, ref)
|
||||
|
||||
@pytest.mark.filterwarnings("ignore: the matrix subclass")
|
||||
def test_raises(self):
|
||||
msg = "of type `numpy.ma.MaskedArray` are not supported"
|
||||
with pytest.raises(TypeError, match=msg):
|
||||
array_namespace(np.ma.array(1), np.array(1))
|
||||
|
||||
msg = "of type `numpy.matrix` are not supported"
|
||||
with pytest.raises(TypeError, match=msg):
|
||||
array_namespace(np.array(1), np.matrix(1))
|
||||
|
||||
msg = "only boolean and numerical dtypes are supported"
|
||||
with pytest.raises(TypeError, match=msg):
|
||||
array_namespace([object()])
|
||||
with pytest.raises(TypeError, match=msg):
|
||||
array_namespace('abc')
|
||||
|
||||
def test_array_likes(self):
|
||||
# should be no exceptions
|
||||
array_namespace([0, 1, 2])
|
||||
array_namespace(1, 2, 3)
|
||||
array_namespace(1)
|
||||
|
||||
@skip_xp_backends('jax.numpy',
|
||||
reasons=["JAX arrays do not support item assignment"])
|
||||
@pytest.mark.usefixtures("skip_xp_backends")
|
||||
@array_api_compatible
|
||||
def test_copy(self, xp):
|
||||
for _xp in [xp, None]:
|
||||
x = xp.asarray([1, 2, 3])
|
||||
y = copy(x, xp=_xp)
|
||||
# with numpy we'd want to use np.shared_memory, but that's not specified
|
||||
# in the array-api
|
||||
x[0] = 10
|
||||
x[1] = 11
|
||||
x[2] = 12
|
||||
|
||||
assert x[0] != y[0]
|
||||
assert x[1] != y[1]
|
||||
assert x[2] != y[2]
|
||||
assert id(x) != id(y)
|
||||
|
||||
@array_api_compatible
|
||||
@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float32', 'float64'])
|
||||
@pytest.mark.parametrize('shape', [(), (3,)])
|
||||
def test_strict_checks(self, xp, dtype, shape):
|
||||
# Check that `_strict_check` behaves as expected
|
||||
dtype = getattr(xp, dtype)
|
||||
x = xp.broadcast_to(xp.asarray(1, dtype=dtype), shape)
|
||||
x = x if shape else x[()]
|
||||
y = np_compat.asarray(1)[()]
|
||||
|
||||
options = dict(check_namespace=True, check_dtype=False, check_shape=False)
|
||||
if xp == np:
|
||||
xp_assert_equal(x, y, **options)
|
||||
else:
|
||||
with pytest.raises(AssertionError, match="Namespaces do not match."):
|
||||
xp_assert_equal(x, y, **options)
|
||||
|
||||
options = dict(check_namespace=False, check_dtype=True, check_shape=False)
|
||||
if y.dtype.name in str(x.dtype):
|
||||
xp_assert_equal(x, y, **options)
|
||||
else:
|
||||
with pytest.raises(AssertionError, match="dtypes do not match."):
|
||||
xp_assert_equal(x, y, **options)
|
||||
|
||||
options = dict(check_namespace=False, check_dtype=False, check_shape=True)
|
||||
if x.shape == y.shape:
|
||||
xp_assert_equal(x, y, **options)
|
||||
else:
|
||||
with pytest.raises(AssertionError, match="Shapes do not match."):
|
||||
xp_assert_equal(x, y, **options)
|
||||
|
||||
@array_api_compatible
|
||||
def test_check_scalar(self, xp):
|
||||
if not is_numpy(xp):
|
||||
pytest.skip("Scalars only exist in NumPy")
|
||||
|
||||
if is_numpy(xp):
|
||||
with pytest.raises(AssertionError, match="Types do not match."):
|
||||
xp_assert_equal(xp.asarray(0.), xp.float64(0))
|
||||
xp_assert_equal(xp.float64(0), xp.asarray(0.))
|
||||
@ -0,0 +1,162 @@
|
||||
import pytest
|
||||
import pickle
|
||||
from numpy.testing import assert_equal
|
||||
from scipy._lib._bunch import _make_tuple_bunch
|
||||
|
||||
|
||||
# `Result` is defined at the top level of the module so it can be
|
||||
# used to test pickling.
|
||||
Result = _make_tuple_bunch('Result', ['x', 'y', 'z'], ['w', 'beta'])
|
||||
|
||||
|
||||
class TestMakeTupleBunch:
|
||||
|
||||
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
# Tests with Result
|
||||
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
|
||||
def setup_method(self):
|
||||
# Set up an instance of Result.
|
||||
self.result = Result(x=1, y=2, z=3, w=99, beta=0.5)
|
||||
|
||||
def test_attribute_access(self):
|
||||
assert_equal(self.result.x, 1)
|
||||
assert_equal(self.result.y, 2)
|
||||
assert_equal(self.result.z, 3)
|
||||
assert_equal(self.result.w, 99)
|
||||
assert_equal(self.result.beta, 0.5)
|
||||
|
||||
def test_indexing(self):
|
||||
assert_equal(self.result[0], 1)
|
||||
assert_equal(self.result[1], 2)
|
||||
assert_equal(self.result[2], 3)
|
||||
assert_equal(self.result[-1], 3)
|
||||
with pytest.raises(IndexError, match='index out of range'):
|
||||
self.result[3]
|
||||
|
||||
def test_unpacking(self):
|
||||
x0, y0, z0 = self.result
|
||||
assert_equal((x0, y0, z0), (1, 2, 3))
|
||||
assert_equal(self.result, (1, 2, 3))
|
||||
|
||||
def test_slice(self):
|
||||
assert_equal(self.result[1:], (2, 3))
|
||||
assert_equal(self.result[::2], (1, 3))
|
||||
assert_equal(self.result[::-1], (3, 2, 1))
|
||||
|
||||
def test_len(self):
|
||||
assert_equal(len(self.result), 3)
|
||||
|
||||
def test_repr(self):
|
||||
s = repr(self.result)
|
||||
assert_equal(s, 'Result(x=1, y=2, z=3, w=99, beta=0.5)')
|
||||
|
||||
def test_hash(self):
|
||||
assert_equal(hash(self.result), hash((1, 2, 3)))
|
||||
|
||||
def test_pickle(self):
|
||||
s = pickle.dumps(self.result)
|
||||
obj = pickle.loads(s)
|
||||
assert isinstance(obj, Result)
|
||||
assert_equal(obj.x, self.result.x)
|
||||
assert_equal(obj.y, self.result.y)
|
||||
assert_equal(obj.z, self.result.z)
|
||||
assert_equal(obj.w, self.result.w)
|
||||
assert_equal(obj.beta, self.result.beta)
|
||||
|
||||
def test_read_only_existing(self):
|
||||
with pytest.raises(AttributeError, match="can't set attribute"):
|
||||
self.result.x = -1
|
||||
|
||||
def test_read_only_new(self):
|
||||
self.result.plate_of_shrimp = "lattice of coincidence"
|
||||
assert self.result.plate_of_shrimp == "lattice of coincidence"
|
||||
|
||||
def test_constructor_missing_parameter(self):
|
||||
with pytest.raises(TypeError, match='missing'):
|
||||
# `w` is missing.
|
||||
Result(x=1, y=2, z=3, beta=0.75)
|
||||
|
||||
def test_constructor_incorrect_parameter(self):
|
||||
with pytest.raises(TypeError, match='unexpected'):
|
||||
# `foo` is not an existing field.
|
||||
Result(x=1, y=2, z=3, w=123, beta=0.75, foo=999)
|
||||
|
||||
def test_module(self):
|
||||
m = 'scipy._lib.tests.test_bunch'
|
||||
assert_equal(Result.__module__, m)
|
||||
assert_equal(self.result.__module__, m)
|
||||
|
||||
def test_extra_fields_per_instance(self):
|
||||
# This test exists to ensure that instances of the same class
|
||||
# store their own values for the extra fields. That is, the values
|
||||
# are stored per instance and not in the class.
|
||||
result1 = Result(x=1, y=2, z=3, w=-1, beta=0.0)
|
||||
result2 = Result(x=4, y=5, z=6, w=99, beta=1.0)
|
||||
assert_equal(result1.w, -1)
|
||||
assert_equal(result1.beta, 0.0)
|
||||
# The rest of these checks aren't essential, but let's check
|
||||
# them anyway.
|
||||
assert_equal(result1[:], (1, 2, 3))
|
||||
assert_equal(result2.w, 99)
|
||||
assert_equal(result2.beta, 1.0)
|
||||
assert_equal(result2[:], (4, 5, 6))
|
||||
|
||||
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
# Other tests
|
||||
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
|
||||
def test_extra_field_names_is_optional(self):
|
||||
Square = _make_tuple_bunch('Square', ['width', 'height'])
|
||||
sq = Square(width=1, height=2)
|
||||
assert_equal(sq.width, 1)
|
||||
assert_equal(sq.height, 2)
|
||||
s = repr(sq)
|
||||
assert_equal(s, 'Square(width=1, height=2)')
|
||||
|
||||
def test_tuple_like(self):
|
||||
Tup = _make_tuple_bunch('Tup', ['a', 'b'])
|
||||
tu = Tup(a=1, b=2)
|
||||
assert isinstance(tu, tuple)
|
||||
assert isinstance(tu + (1,), tuple)
|
||||
|
||||
def test_explicit_module(self):
|
||||
m = 'some.module.name'
|
||||
Foo = _make_tuple_bunch('Foo', ['x'], ['a', 'b'], module=m)
|
||||
foo = Foo(x=1, a=355, b=113)
|
||||
assert_equal(Foo.__module__, m)
|
||||
assert_equal(foo.__module__, m)
|
||||
|
||||
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
# Argument validation
|
||||
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
|
||||
@pytest.mark.parametrize('args', [('123', ['a'], ['b']),
|
||||
('Foo', ['-3'], ['x']),
|
||||
('Foo', ['a'], ['+-*/'])])
|
||||
def test_identifiers_not_allowed(self, args):
|
||||
with pytest.raises(ValueError, match='identifiers'):
|
||||
_make_tuple_bunch(*args)
|
||||
|
||||
@pytest.mark.parametrize('args', [('Foo', ['a', 'b', 'a'], ['x']),
|
||||
('Foo', ['a', 'b'], ['b', 'x'])])
|
||||
def test_repeated_field_names(self, args):
|
||||
with pytest.raises(ValueError, match='Duplicate'):
|
||||
_make_tuple_bunch(*args)
|
||||
|
||||
@pytest.mark.parametrize('args', [('Foo', ['_a'], ['x']),
|
||||
('Foo', ['a'], ['_x'])])
|
||||
def test_leading_underscore_not_allowed(self, args):
|
||||
with pytest.raises(ValueError, match='underscore'):
|
||||
_make_tuple_bunch(*args)
|
||||
|
||||
@pytest.mark.parametrize('args', [('Foo', ['def'], ['x']),
|
||||
('Foo', ['a'], ['or']),
|
||||
('and', ['a'], ['x'])])
|
||||
def test_keyword_not_allowed_in_fields(self, args):
|
||||
with pytest.raises(ValueError, match='keyword'):
|
||||
_make_tuple_bunch(*args)
|
||||
|
||||
def test_at_least_one_field_name_required(self):
|
||||
with pytest.raises(ValueError, match='at least one name'):
|
||||
_make_tuple_bunch('Qwerty', [], ['a', 'b'])
|
||||
@ -0,0 +1,204 @@
|
||||
from numpy.testing import assert_equal, assert_
|
||||
from pytest import raises as assert_raises
|
||||
|
||||
import time
|
||||
import pytest
|
||||
import ctypes
|
||||
import threading
|
||||
from scipy._lib import _ccallback_c as _test_ccallback_cython
|
||||
from scipy._lib import _test_ccallback
|
||||
from scipy._lib._ccallback import LowLevelCallable
|
||||
|
||||
try:
|
||||
import cffi
|
||||
HAVE_CFFI = True
|
||||
except ImportError:
|
||||
HAVE_CFFI = False
|
||||
|
||||
|
||||
ERROR_VALUE = 2.0
|
||||
|
||||
|
||||
def callback_python(a, user_data=None):
|
||||
if a == ERROR_VALUE:
|
||||
raise ValueError("bad value")
|
||||
|
||||
if user_data is None:
|
||||
return a + 1
|
||||
else:
|
||||
return a + user_data
|
||||
|
||||
def _get_cffi_func(base, signature):
|
||||
if not HAVE_CFFI:
|
||||
pytest.skip("cffi not installed")
|
||||
|
||||
# Get function address
|
||||
voidp = ctypes.cast(base, ctypes.c_void_p)
|
||||
address = voidp.value
|
||||
|
||||
# Create corresponding cffi handle
|
||||
ffi = cffi.FFI()
|
||||
func = ffi.cast(signature, address)
|
||||
return func
|
||||
|
||||
|
||||
def _get_ctypes_data():
|
||||
value = ctypes.c_double(2.0)
|
||||
return ctypes.cast(ctypes.pointer(value), ctypes.c_voidp)
|
||||
|
||||
|
||||
def _get_cffi_data():
|
||||
if not HAVE_CFFI:
|
||||
pytest.skip("cffi not installed")
|
||||
ffi = cffi.FFI()
|
||||
return ffi.new('double *', 2.0)
|
||||
|
||||
|
||||
CALLERS = {
|
||||
'simple': _test_ccallback.test_call_simple,
|
||||
'nodata': _test_ccallback.test_call_nodata,
|
||||
'nonlocal': _test_ccallback.test_call_nonlocal,
|
||||
'cython': _test_ccallback_cython.test_call_cython,
|
||||
}
|
||||
|
||||
# These functions have signatures known to the callers
|
||||
FUNCS = {
|
||||
'python': lambda: callback_python,
|
||||
'capsule': lambda: _test_ccallback.test_get_plus1_capsule(),
|
||||
'cython': lambda: LowLevelCallable.from_cython(_test_ccallback_cython,
|
||||
"plus1_cython"),
|
||||
'ctypes': lambda: _test_ccallback_cython.plus1_ctypes,
|
||||
'cffi': lambda: _get_cffi_func(_test_ccallback_cython.plus1_ctypes,
|
||||
'double (*)(double, int *, void *)'),
|
||||
'capsule_b': lambda: _test_ccallback.test_get_plus1b_capsule(),
|
||||
'cython_b': lambda: LowLevelCallable.from_cython(_test_ccallback_cython,
|
||||
"plus1b_cython"),
|
||||
'ctypes_b': lambda: _test_ccallback_cython.plus1b_ctypes,
|
||||
'cffi_b': lambda: _get_cffi_func(_test_ccallback_cython.plus1b_ctypes,
|
||||
'double (*)(double, double, int *, void *)'),
|
||||
}
|
||||
|
||||
# These functions have signatures the callers don't know
|
||||
BAD_FUNCS = {
|
||||
'capsule_bc': lambda: _test_ccallback.test_get_plus1bc_capsule(),
|
||||
'cython_bc': lambda: LowLevelCallable.from_cython(_test_ccallback_cython,
|
||||
"plus1bc_cython"),
|
||||
'ctypes_bc': lambda: _test_ccallback_cython.plus1bc_ctypes,
|
||||
'cffi_bc': lambda: _get_cffi_func(
|
||||
_test_ccallback_cython.plus1bc_ctypes,
|
||||
'double (*)(double, double, double, int *, void *)'
|
||||
),
|
||||
}
|
||||
|
||||
USER_DATAS = {
|
||||
'ctypes': _get_ctypes_data,
|
||||
'cffi': _get_cffi_data,
|
||||
'capsule': _test_ccallback.test_get_data_capsule,
|
||||
}
|
||||
|
||||
|
||||
def test_callbacks():
|
||||
def check(caller, func, user_data):
|
||||
caller = CALLERS[caller]
|
||||
func = FUNCS[func]()
|
||||
user_data = USER_DATAS[user_data]()
|
||||
|
||||
if func is callback_python:
|
||||
def func2(x):
|
||||
return func(x, 2.0)
|
||||
else:
|
||||
func2 = LowLevelCallable(func, user_data)
|
||||
func = LowLevelCallable(func)
|
||||
|
||||
# Test basic call
|
||||
assert_equal(caller(func, 1.0), 2.0)
|
||||
|
||||
# Test 'bad' value resulting to an error
|
||||
assert_raises(ValueError, caller, func, ERROR_VALUE)
|
||||
|
||||
# Test passing in user_data
|
||||
assert_equal(caller(func2, 1.0), 3.0)
|
||||
|
||||
for caller in sorted(CALLERS.keys()):
|
||||
for func in sorted(FUNCS.keys()):
|
||||
for user_data in sorted(USER_DATAS.keys()):
|
||||
check(caller, func, user_data)
|
||||
|
||||
|
||||
def test_bad_callbacks():
|
||||
def check(caller, func, user_data):
|
||||
caller = CALLERS[caller]
|
||||
user_data = USER_DATAS[user_data]()
|
||||
func = BAD_FUNCS[func]()
|
||||
|
||||
if func is callback_python:
|
||||
def func2(x):
|
||||
return func(x, 2.0)
|
||||
else:
|
||||
func2 = LowLevelCallable(func, user_data)
|
||||
func = LowLevelCallable(func)
|
||||
|
||||
# Test that basic call fails
|
||||
assert_raises(ValueError, caller, LowLevelCallable(func), 1.0)
|
||||
|
||||
# Test that passing in user_data also fails
|
||||
assert_raises(ValueError, caller, func2, 1.0)
|
||||
|
||||
# Test error message
|
||||
llfunc = LowLevelCallable(func)
|
||||
try:
|
||||
caller(llfunc, 1.0)
|
||||
except ValueError as err:
|
||||
msg = str(err)
|
||||
assert_(llfunc.signature in msg, msg)
|
||||
assert_('double (double, double, int *, void *)' in msg, msg)
|
||||
|
||||
for caller in sorted(CALLERS.keys()):
|
||||
for func in sorted(BAD_FUNCS.keys()):
|
||||
for user_data in sorted(USER_DATAS.keys()):
|
||||
check(caller, func, user_data)
|
||||
|
||||
|
||||
def test_signature_override():
|
||||
caller = _test_ccallback.test_call_simple
|
||||
func = _test_ccallback.test_get_plus1_capsule()
|
||||
|
||||
llcallable = LowLevelCallable(func, signature="bad signature")
|
||||
assert_equal(llcallable.signature, "bad signature")
|
||||
assert_raises(ValueError, caller, llcallable, 3)
|
||||
|
||||
llcallable = LowLevelCallable(func, signature="double (double, int *, void *)")
|
||||
assert_equal(llcallable.signature, "double (double, int *, void *)")
|
||||
assert_equal(caller(llcallable, 3), 4)
|
||||
|
||||
|
||||
def test_threadsafety():
|
||||
def callback(a, caller):
|
||||
if a <= 0:
|
||||
return 1
|
||||
else:
|
||||
res = caller(lambda x: callback(x, caller), a - 1)
|
||||
return 2*res
|
||||
|
||||
def check(caller):
|
||||
caller = CALLERS[caller]
|
||||
|
||||
results = []
|
||||
|
||||
count = 10
|
||||
|
||||
def run():
|
||||
time.sleep(0.01)
|
||||
r = caller(lambda x: callback(x, caller), count)
|
||||
results.append(r)
|
||||
|
||||
threads = [threading.Thread(target=run) for j in range(20)]
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
assert_equal(results, [2.0**count]*len(threads))
|
||||
|
||||
for caller in CALLERS.keys():
|
||||
check(caller)
|
||||
@ -0,0 +1,10 @@
|
||||
import pytest
|
||||
|
||||
|
||||
def test_cython_api_deprecation():
|
||||
match = ("`scipy._lib._test_deprecation_def.foo_deprecated` "
|
||||
"is deprecated, use `foo` instead!\n"
|
||||
"Deprecated in Scipy 42.0.0")
|
||||
with pytest.warns(DeprecationWarning, match=match):
|
||||
from .. import _test_deprecation_call
|
||||
assert _test_deprecation_call.call() == (1, 1)
|
||||
@ -0,0 +1,17 @@
|
||||
import pytest
|
||||
import sys
|
||||
import subprocess
|
||||
|
||||
from .test_public_api import PUBLIC_MODULES
|
||||
|
||||
# Regression tests for gh-6793.
|
||||
# Check that all modules are importable in a new Python process.
|
||||
# This is not necessarily true if there are import cycles present.
|
||||
|
||||
@pytest.mark.fail_slow(20)
|
||||
@pytest.mark.slow
|
||||
def test_public_modules_importable():
|
||||
pids = [subprocess.Popen([sys.executable, '-c', f'import {module}'])
|
||||
for module in PUBLIC_MODULES]
|
||||
for i, pid in enumerate(pids):
|
||||
assert pid.wait() == 0, f'Failed to import {PUBLIC_MODULES[i]}'
|
||||
@ -0,0 +1,496 @@
|
||||
"""
|
||||
This test script is adopted from:
|
||||
https://github.com/numpy/numpy/blob/main/numpy/tests/test_public_api.py
|
||||
"""
|
||||
|
||||
import pkgutil
|
||||
import types
|
||||
import importlib
|
||||
import warnings
|
||||
from importlib import import_module
|
||||
|
||||
import pytest
|
||||
|
||||
import scipy
|
||||
|
||||
from scipy.conftest import xp_available_backends
|
||||
|
||||
|
||||
def test_dir_testing():
|
||||
"""Assert that output of dir has only one "testing/tester"
|
||||
attribute without duplicate"""
|
||||
assert len(dir(scipy)) == len(set(dir(scipy)))
|
||||
|
||||
|
||||
# Historically SciPy has not used leading underscores for private submodules
|
||||
# much. This has resulted in lots of things that look like public modules
|
||||
# (i.e. things that can be imported as `import scipy.somesubmodule.somefile`),
|
||||
# but were never intended to be public. The PUBLIC_MODULES list contains
|
||||
# modules that are either public because they were meant to be, or because they
|
||||
# contain public functions/objects that aren't present in any other namespace
|
||||
# for whatever reason and therefore should be treated as public.
|
||||
PUBLIC_MODULES = ["scipy." + s for s in [
|
||||
"cluster",
|
||||
"cluster.vq",
|
||||
"cluster.hierarchy",
|
||||
"constants",
|
||||
"datasets",
|
||||
"fft",
|
||||
"fftpack",
|
||||
"integrate",
|
||||
"interpolate",
|
||||
"io",
|
||||
"io.arff",
|
||||
"io.matlab",
|
||||
"io.wavfile",
|
||||
"linalg",
|
||||
"linalg.blas",
|
||||
"linalg.cython_blas",
|
||||
"linalg.lapack",
|
||||
"linalg.cython_lapack",
|
||||
"linalg.interpolative",
|
||||
"misc",
|
||||
"ndimage",
|
||||
"odr",
|
||||
"optimize",
|
||||
"signal",
|
||||
"signal.windows",
|
||||
"sparse",
|
||||
"sparse.linalg",
|
||||
"sparse.csgraph",
|
||||
"spatial",
|
||||
"spatial.distance",
|
||||
"spatial.transform",
|
||||
"special",
|
||||
"stats",
|
||||
"stats.contingency",
|
||||
"stats.distributions",
|
||||
"stats.mstats",
|
||||
"stats.qmc",
|
||||
"stats.sampling"
|
||||
]]
|
||||
|
||||
# The PRIVATE_BUT_PRESENT_MODULES list contains modules that lacked underscores
|
||||
# in their name and hence looked public, but weren't meant to be. All these
|
||||
# namespace were deprecated in the 1.8.0 release - see "clear split between
|
||||
# public and private API" in the 1.8.0 release notes.
|
||||
# These private modules support will be removed in SciPy v2.0.0, as the
|
||||
# deprecation messages emitted by each of these modules say.
|
||||
PRIVATE_BUT_PRESENT_MODULES = [
|
||||
'scipy.constants.codata',
|
||||
'scipy.constants.constants',
|
||||
'scipy.fftpack.basic',
|
||||
'scipy.fftpack.convolve',
|
||||
'scipy.fftpack.helper',
|
||||
'scipy.fftpack.pseudo_diffs',
|
||||
'scipy.fftpack.realtransforms',
|
||||
'scipy.integrate.dop',
|
||||
'scipy.integrate.lsoda',
|
||||
'scipy.integrate.odepack',
|
||||
'scipy.integrate.quadpack',
|
||||
'scipy.integrate.vode',
|
||||
'scipy.interpolate.dfitpack',
|
||||
'scipy.interpolate.fitpack',
|
||||
'scipy.interpolate.fitpack2',
|
||||
'scipy.interpolate.interpnd',
|
||||
'scipy.interpolate.interpolate',
|
||||
'scipy.interpolate.ndgriddata',
|
||||
'scipy.interpolate.polyint',
|
||||
'scipy.interpolate.rbf',
|
||||
'scipy.io.arff.arffread',
|
||||
'scipy.io.harwell_boeing',
|
||||
'scipy.io.idl',
|
||||
'scipy.io.matlab.byteordercodes',
|
||||
'scipy.io.matlab.mio',
|
||||
'scipy.io.matlab.mio4',
|
||||
'scipy.io.matlab.mio5',
|
||||
'scipy.io.matlab.mio5_params',
|
||||
'scipy.io.matlab.mio5_utils',
|
||||
'scipy.io.matlab.mio_utils',
|
||||
'scipy.io.matlab.miobase',
|
||||
'scipy.io.matlab.streams',
|
||||
'scipy.io.mmio',
|
||||
'scipy.io.netcdf',
|
||||
'scipy.linalg.basic',
|
||||
'scipy.linalg.decomp',
|
||||
'scipy.linalg.decomp_cholesky',
|
||||
'scipy.linalg.decomp_lu',
|
||||
'scipy.linalg.decomp_qr',
|
||||
'scipy.linalg.decomp_schur',
|
||||
'scipy.linalg.decomp_svd',
|
||||
'scipy.linalg.matfuncs',
|
||||
'scipy.linalg.misc',
|
||||
'scipy.linalg.special_matrices',
|
||||
'scipy.misc.common',
|
||||
'scipy.misc.doccer',
|
||||
'scipy.ndimage.filters',
|
||||
'scipy.ndimage.fourier',
|
||||
'scipy.ndimage.interpolation',
|
||||
'scipy.ndimage.measurements',
|
||||
'scipy.ndimage.morphology',
|
||||
'scipy.odr.models',
|
||||
'scipy.odr.odrpack',
|
||||
'scipy.optimize.cobyla',
|
||||
'scipy.optimize.cython_optimize',
|
||||
'scipy.optimize.lbfgsb',
|
||||
'scipy.optimize.linesearch',
|
||||
'scipy.optimize.minpack',
|
||||
'scipy.optimize.minpack2',
|
||||
'scipy.optimize.moduleTNC',
|
||||
'scipy.optimize.nonlin',
|
||||
'scipy.optimize.optimize',
|
||||
'scipy.optimize.slsqp',
|
||||
'scipy.optimize.tnc',
|
||||
'scipy.optimize.zeros',
|
||||
'scipy.signal.bsplines',
|
||||
'scipy.signal.filter_design',
|
||||
'scipy.signal.fir_filter_design',
|
||||
'scipy.signal.lti_conversion',
|
||||
'scipy.signal.ltisys',
|
||||
'scipy.signal.signaltools',
|
||||
'scipy.signal.spectral',
|
||||
'scipy.signal.spline',
|
||||
'scipy.signal.waveforms',
|
||||
'scipy.signal.wavelets',
|
||||
'scipy.signal.windows.windows',
|
||||
'scipy.sparse.base',
|
||||
'scipy.sparse.bsr',
|
||||
'scipy.sparse.compressed',
|
||||
'scipy.sparse.construct',
|
||||
'scipy.sparse.coo',
|
||||
'scipy.sparse.csc',
|
||||
'scipy.sparse.csr',
|
||||
'scipy.sparse.data',
|
||||
'scipy.sparse.dia',
|
||||
'scipy.sparse.dok',
|
||||
'scipy.sparse.extract',
|
||||
'scipy.sparse.lil',
|
||||
'scipy.sparse.linalg.dsolve',
|
||||
'scipy.sparse.linalg.eigen',
|
||||
'scipy.sparse.linalg.interface',
|
||||
'scipy.sparse.linalg.isolve',
|
||||
'scipy.sparse.linalg.matfuncs',
|
||||
'scipy.sparse.sparsetools',
|
||||
'scipy.sparse.spfuncs',
|
||||
'scipy.sparse.sputils',
|
||||
'scipy.spatial.ckdtree',
|
||||
'scipy.spatial.kdtree',
|
||||
'scipy.spatial.qhull',
|
||||
'scipy.spatial.transform.rotation',
|
||||
'scipy.special.add_newdocs',
|
||||
'scipy.special.basic',
|
||||
'scipy.special.cython_special',
|
||||
'scipy.special.orthogonal',
|
||||
'scipy.special.sf_error',
|
||||
'scipy.special.specfun',
|
||||
'scipy.special.spfun_stats',
|
||||
'scipy.stats.biasedurn',
|
||||
'scipy.stats.kde',
|
||||
'scipy.stats.morestats',
|
||||
'scipy.stats.mstats_basic',
|
||||
'scipy.stats.mstats_extras',
|
||||
'scipy.stats.mvn',
|
||||
'scipy.stats.stats',
|
||||
]
|
||||
|
||||
|
||||
def is_unexpected(name):
|
||||
"""Check if this needs to be considered."""
|
||||
if '._' in name or '.tests' in name or '.setup' in name:
|
||||
return False
|
||||
|
||||
if name in PUBLIC_MODULES:
|
||||
return False
|
||||
|
||||
if name in PRIVATE_BUT_PRESENT_MODULES:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
SKIP_LIST = [
|
||||
'scipy.conftest',
|
||||
'scipy.version',
|
||||
'scipy.special.libsf_error_state'
|
||||
]
|
||||
|
||||
|
||||
# XXX: this test does more than it says on the tin - in using `pkgutil.walk_packages`,
|
||||
# it will raise if it encounters any exceptions which are not handled by `ignore_errors`
|
||||
# while attempting to import each discovered package.
|
||||
# For now, `ignore_errors` only ignores what is necessary, but this could be expanded -
|
||||
# for example, to all errors from private modules or git subpackages - if desired.
|
||||
def test_all_modules_are_expected():
|
||||
"""
|
||||
Test that we don't add anything that looks like a new public module by
|
||||
accident. Check is based on filenames.
|
||||
"""
|
||||
|
||||
def ignore_errors(name):
|
||||
# if versions of other array libraries are installed which are incompatible
|
||||
# with the installed NumPy version, there can be errors on importing
|
||||
# `array_api_compat`. This should only raise if SciPy is configured with
|
||||
# that library as an available backend.
|
||||
backends = {'cupy': 'cupy',
|
||||
'pytorch': 'torch',
|
||||
'dask.array': 'dask.array'}
|
||||
for backend, dir_name in backends.items():
|
||||
path = f'array_api_compat.{dir_name}'
|
||||
if path in name and backend not in xp_available_backends:
|
||||
return
|
||||
raise
|
||||
|
||||
modnames = []
|
||||
|
||||
for _, modname, _ in pkgutil.walk_packages(path=scipy.__path__,
|
||||
prefix=scipy.__name__ + '.',
|
||||
onerror=ignore_errors):
|
||||
if is_unexpected(modname) and modname not in SKIP_LIST:
|
||||
# We have a name that is new. If that's on purpose, add it to
|
||||
# PUBLIC_MODULES. We don't expect to have to add anything to
|
||||
# PRIVATE_BUT_PRESENT_MODULES. Use an underscore in the name!
|
||||
modnames.append(modname)
|
||||
|
||||
if modnames:
|
||||
raise AssertionError(f'Found unexpected modules: {modnames}')
|
||||
|
||||
|
||||
# Stuff that clearly shouldn't be in the API and is detected by the next test
|
||||
# below
|
||||
SKIP_LIST_2 = [
|
||||
'scipy.char',
|
||||
'scipy.rec',
|
||||
'scipy.emath',
|
||||
'scipy.math',
|
||||
'scipy.random',
|
||||
'scipy.ctypeslib',
|
||||
'scipy.ma'
|
||||
]
|
||||
|
||||
|
||||
def test_all_modules_are_expected_2():
|
||||
"""
|
||||
Method checking all objects. The pkgutil-based method in
|
||||
`test_all_modules_are_expected` does not catch imports into a namespace,
|
||||
only filenames.
|
||||
"""
|
||||
|
||||
def find_unexpected_members(mod_name):
|
||||
members = []
|
||||
module = importlib.import_module(mod_name)
|
||||
if hasattr(module, '__all__'):
|
||||
objnames = module.__all__
|
||||
else:
|
||||
objnames = dir(module)
|
||||
|
||||
for objname in objnames:
|
||||
if not objname.startswith('_'):
|
||||
fullobjname = mod_name + '.' + objname
|
||||
if isinstance(getattr(module, objname), types.ModuleType):
|
||||
if is_unexpected(fullobjname) and fullobjname not in SKIP_LIST_2:
|
||||
members.append(fullobjname)
|
||||
|
||||
return members
|
||||
|
||||
unexpected_members = find_unexpected_members("scipy")
|
||||
for modname in PUBLIC_MODULES:
|
||||
unexpected_members.extend(find_unexpected_members(modname))
|
||||
|
||||
if unexpected_members:
|
||||
raise AssertionError("Found unexpected object(s) that look like "
|
||||
f"modules: {unexpected_members}")
|
||||
|
||||
|
||||
def test_api_importable():
|
||||
"""
|
||||
Check that all submodules listed higher up in this file can be imported
|
||||
Note that if a PRIVATE_BUT_PRESENT_MODULES entry goes missing, it may
|
||||
simply need to be removed from the list (deprecation may or may not be
|
||||
needed - apply common sense).
|
||||
"""
|
||||
def check_importable(module_name):
|
||||
try:
|
||||
importlib.import_module(module_name)
|
||||
except (ImportError, AttributeError):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
module_names = []
|
||||
for module_name in PUBLIC_MODULES:
|
||||
if not check_importable(module_name):
|
||||
module_names.append(module_name)
|
||||
|
||||
if module_names:
|
||||
raise AssertionError("Modules in the public API that cannot be "
|
||||
f"imported: {module_names}")
|
||||
|
||||
with warnings.catch_warnings(record=True):
|
||||
warnings.filterwarnings('always', category=DeprecationWarning)
|
||||
warnings.filterwarnings('always', category=ImportWarning)
|
||||
for module_name in PRIVATE_BUT_PRESENT_MODULES:
|
||||
if not check_importable(module_name):
|
||||
module_names.append(module_name)
|
||||
|
||||
if module_names:
|
||||
raise AssertionError("Modules that are not really public but looked "
|
||||
"public and can not be imported: "
|
||||
f"{module_names}")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("module_name", "correct_module"),
|
||||
[('scipy.constants.codata', None),
|
||||
('scipy.constants.constants', None),
|
||||
('scipy.fftpack.basic', None),
|
||||
('scipy.fftpack.helper', None),
|
||||
('scipy.fftpack.pseudo_diffs', None),
|
||||
('scipy.fftpack.realtransforms', None),
|
||||
('scipy.integrate.dop', None),
|
||||
('scipy.integrate.lsoda', None),
|
||||
('scipy.integrate.odepack', None),
|
||||
('scipy.integrate.quadpack', None),
|
||||
('scipy.integrate.vode', None),
|
||||
('scipy.interpolate.fitpack', None),
|
||||
('scipy.interpolate.fitpack2', None),
|
||||
('scipy.interpolate.interpolate', None),
|
||||
('scipy.interpolate.ndgriddata', None),
|
||||
('scipy.interpolate.polyint', None),
|
||||
('scipy.interpolate.rbf', None),
|
||||
('scipy.io.harwell_boeing', None),
|
||||
('scipy.io.idl', None),
|
||||
('scipy.io.mmio', None),
|
||||
('scipy.io.netcdf', None),
|
||||
('scipy.io.arff.arffread', 'arff'),
|
||||
('scipy.io.matlab.byteordercodes', 'matlab'),
|
||||
('scipy.io.matlab.mio_utils', 'matlab'),
|
||||
('scipy.io.matlab.mio', 'matlab'),
|
||||
('scipy.io.matlab.mio4', 'matlab'),
|
||||
('scipy.io.matlab.mio5_params', 'matlab'),
|
||||
('scipy.io.matlab.mio5_utils', 'matlab'),
|
||||
('scipy.io.matlab.mio5', 'matlab'),
|
||||
('scipy.io.matlab.miobase', 'matlab'),
|
||||
('scipy.io.matlab.streams', 'matlab'),
|
||||
('scipy.linalg.basic', None),
|
||||
('scipy.linalg.decomp', None),
|
||||
('scipy.linalg.decomp_cholesky', None),
|
||||
('scipy.linalg.decomp_lu', None),
|
||||
('scipy.linalg.decomp_qr', None),
|
||||
('scipy.linalg.decomp_schur', None),
|
||||
('scipy.linalg.decomp_svd', None),
|
||||
('scipy.linalg.matfuncs', None),
|
||||
('scipy.linalg.misc', None),
|
||||
('scipy.linalg.special_matrices', None),
|
||||
('scipy.misc.common', None),
|
||||
('scipy.ndimage.filters', None),
|
||||
('scipy.ndimage.fourier', None),
|
||||
('scipy.ndimage.interpolation', None),
|
||||
('scipy.ndimage.measurements', None),
|
||||
('scipy.ndimage.morphology', None),
|
||||
('scipy.odr.models', None),
|
||||
('scipy.odr.odrpack', None),
|
||||
('scipy.optimize.cobyla', None),
|
||||
('scipy.optimize.lbfgsb', None),
|
||||
('scipy.optimize.linesearch', None),
|
||||
('scipy.optimize.minpack', None),
|
||||
('scipy.optimize.minpack2', None),
|
||||
('scipy.optimize.moduleTNC', None),
|
||||
('scipy.optimize.nonlin', None),
|
||||
('scipy.optimize.optimize', None),
|
||||
('scipy.optimize.slsqp', None),
|
||||
('scipy.optimize.tnc', None),
|
||||
('scipy.optimize.zeros', None),
|
||||
('scipy.signal.bsplines', None),
|
||||
('scipy.signal.filter_design', None),
|
||||
('scipy.signal.fir_filter_design', None),
|
||||
('scipy.signal.lti_conversion', None),
|
||||
('scipy.signal.ltisys', None),
|
||||
('scipy.signal.signaltools', None),
|
||||
('scipy.signal.spectral', None),
|
||||
('scipy.signal.waveforms', None),
|
||||
('scipy.signal.wavelets', None),
|
||||
('scipy.signal.windows.windows', 'windows'),
|
||||
('scipy.sparse.lil', None),
|
||||
('scipy.sparse.linalg.dsolve', 'linalg'),
|
||||
('scipy.sparse.linalg.eigen', 'linalg'),
|
||||
('scipy.sparse.linalg.interface', 'linalg'),
|
||||
('scipy.sparse.linalg.isolve', 'linalg'),
|
||||
('scipy.sparse.linalg.matfuncs', 'linalg'),
|
||||
('scipy.sparse.sparsetools', None),
|
||||
('scipy.sparse.spfuncs', None),
|
||||
('scipy.sparse.sputils', None),
|
||||
('scipy.spatial.ckdtree', None),
|
||||
('scipy.spatial.kdtree', None),
|
||||
('scipy.spatial.qhull', None),
|
||||
('scipy.spatial.transform.rotation', 'transform'),
|
||||
('scipy.special.add_newdocs', None),
|
||||
('scipy.special.basic', None),
|
||||
('scipy.special.orthogonal', None),
|
||||
('scipy.special.sf_error', None),
|
||||
('scipy.special.specfun', None),
|
||||
('scipy.special.spfun_stats', None),
|
||||
('scipy.stats.biasedurn', None),
|
||||
('scipy.stats.kde', None),
|
||||
('scipy.stats.morestats', None),
|
||||
('scipy.stats.mstats_basic', 'mstats'),
|
||||
('scipy.stats.mstats_extras', 'mstats'),
|
||||
('scipy.stats.mvn', None),
|
||||
('scipy.stats.stats', None)])
|
||||
def test_private_but_present_deprecation(module_name, correct_module):
|
||||
# gh-18279, gh-17572, gh-17771 noted that deprecation warnings
|
||||
# for imports from private modules
|
||||
# were misleading. Check that this is resolved.
|
||||
module = import_module(module_name)
|
||||
if correct_module is None:
|
||||
import_name = f'scipy.{module_name.split(".")[1]}'
|
||||
else:
|
||||
import_name = f'scipy.{module_name.split(".")[1]}.{correct_module}'
|
||||
|
||||
correct_import = import_module(import_name)
|
||||
|
||||
# Attributes that were formerly in `module_name` can still be imported from
|
||||
# `module_name`, albeit with a deprecation warning.
|
||||
for attr_name in module.__all__:
|
||||
if attr_name == "varmats_from_mat":
|
||||
# defer handling this case, see
|
||||
# https://github.com/scipy/scipy/issues/19223
|
||||
continue
|
||||
# ensure attribute is present where the warning is pointing
|
||||
assert getattr(correct_import, attr_name, None) is not None
|
||||
message = f"Please import `{attr_name}` from the `{import_name}`..."
|
||||
with pytest.deprecated_call(match=message):
|
||||
getattr(module, attr_name)
|
||||
|
||||
# Attributes that were not in `module_name` get an error notifying the user
|
||||
# that the attribute is not in `module_name` and that `module_name` is deprecated.
|
||||
message = f"`{module_name}` is deprecated..."
|
||||
with pytest.raises(AttributeError, match=message):
|
||||
getattr(module, "ekki")
|
||||
|
||||
|
||||
def test_misc_doccer_deprecation():
|
||||
# gh-18279, gh-17572, gh-17771 noted that deprecation warnings
|
||||
# for imports from private modules were misleading.
|
||||
# Check that this is resolved.
|
||||
# `test_private_but_present_deprecation` cannot be used since `correct_import`
|
||||
# is a different subpackage (`_lib` instead of `misc`).
|
||||
module = import_module('scipy.misc.doccer')
|
||||
correct_import = import_module('scipy._lib.doccer')
|
||||
|
||||
# Attributes that were formerly in `scipy.misc.doccer` can still be imported from
|
||||
# `scipy.misc.doccer`, albeit with a deprecation warning. The specific message
|
||||
# depends on whether the attribute is in `scipy._lib.doccer` or not.
|
||||
for attr_name in module.__all__:
|
||||
attr = getattr(correct_import, attr_name, None)
|
||||
if attr is None:
|
||||
message = f"`scipy.misc.{attr_name}` is deprecated..."
|
||||
else:
|
||||
message = f"Please import `{attr_name}` from the `scipy._lib.doccer`..."
|
||||
with pytest.deprecated_call(match=message):
|
||||
getattr(module, attr_name)
|
||||
|
||||
# Attributes that were not in `scipy.misc.doccer` get an error
|
||||
# notifying the user that the attribute is not in `scipy.misc.doccer`
|
||||
# and that `scipy.misc.doccer` is deprecated.
|
||||
message = "`scipy.misc.doccer` is deprecated..."
|
||||
with pytest.raises(AttributeError, match=message):
|
||||
getattr(module, "ekki")
|
||||
@ -0,0 +1,18 @@
|
||||
import re
|
||||
|
||||
import scipy
|
||||
from numpy.testing import assert_
|
||||
|
||||
|
||||
def test_valid_scipy_version():
|
||||
# Verify that the SciPy version is a valid one (no .post suffix or other
|
||||
# nonsense). See NumPy issue gh-6431 for an issue caused by an invalid
|
||||
# version.
|
||||
version_pattern = r"^[0-9]+\.[0-9]+\.[0-9]+(|a[0-9]|b[0-9]|rc[0-9])"
|
||||
dev_suffix = r"(\.dev0\+.+([0-9a-f]{7}|Unknown))"
|
||||
if scipy.version.release:
|
||||
res = re.match(version_pattern, scipy.__version__)
|
||||
else:
|
||||
res = re.match(version_pattern + dev_suffix, scipy.__version__)
|
||||
|
||||
assert_(res is not None, scipy.__version__)
|
||||
@ -0,0 +1,42 @@
|
||||
""" Test tmpdirs module """
|
||||
from os import getcwd
|
||||
from os.path import realpath, abspath, dirname, isfile, join as pjoin, exists
|
||||
|
||||
from scipy._lib._tmpdirs import tempdir, in_tempdir, in_dir
|
||||
|
||||
from numpy.testing import assert_, assert_equal
|
||||
|
||||
MY_PATH = abspath(__file__)
|
||||
MY_DIR = dirname(MY_PATH)
|
||||
|
||||
|
||||
def test_tempdir():
|
||||
with tempdir() as tmpdir:
|
||||
fname = pjoin(tmpdir, 'example_file.txt')
|
||||
with open(fname, "w") as fobj:
|
||||
fobj.write('a string\\n')
|
||||
assert_(not exists(tmpdir))
|
||||
|
||||
|
||||
def test_in_tempdir():
|
||||
my_cwd = getcwd()
|
||||
with in_tempdir() as tmpdir:
|
||||
with open('test.txt', "w") as f:
|
||||
f.write('some text')
|
||||
assert_(isfile('test.txt'))
|
||||
assert_(isfile(pjoin(tmpdir, 'test.txt')))
|
||||
assert_(not exists(tmpdir))
|
||||
assert_equal(getcwd(), my_cwd)
|
||||
|
||||
|
||||
def test_given_directory():
|
||||
# Test InGivenDirectory
|
||||
cwd = getcwd()
|
||||
with in_dir() as tmpdir:
|
||||
assert_equal(tmpdir, abspath(cwd))
|
||||
assert_equal(tmpdir, abspath(getcwd()))
|
||||
with in_dir(MY_DIR) as tmpdir:
|
||||
assert_equal(tmpdir, MY_DIR)
|
||||
assert_equal(realpath(MY_DIR), realpath(abspath(getcwd())))
|
||||
# We were deleting the given directory! Check not so now.
|
||||
assert_(isfile(MY_PATH))
|
||||
@ -0,0 +1,135 @@
|
||||
"""
|
||||
Tests which scan for certain occurrences in the code, they may not find
|
||||
all of these occurrences but should catch almost all. This file was adapted
|
||||
from NumPy.
|
||||
"""
|
||||
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
import ast
|
||||
import tokenize
|
||||
|
||||
import scipy
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class ParseCall(ast.NodeVisitor):
|
||||
def __init__(self):
|
||||
self.ls = []
|
||||
|
||||
def visit_Attribute(self, node):
|
||||
ast.NodeVisitor.generic_visit(self, node)
|
||||
self.ls.append(node.attr)
|
||||
|
||||
def visit_Name(self, node):
|
||||
self.ls.append(node.id)
|
||||
|
||||
|
||||
class FindFuncs(ast.NodeVisitor):
|
||||
def __init__(self, filename):
|
||||
super().__init__()
|
||||
self.__filename = filename
|
||||
self.bad_filters = []
|
||||
self.bad_stacklevels = []
|
||||
|
||||
def visit_Call(self, node):
|
||||
p = ParseCall()
|
||||
p.visit(node.func)
|
||||
ast.NodeVisitor.generic_visit(self, node)
|
||||
|
||||
if p.ls[-1] == 'simplefilter' or p.ls[-1] == 'filterwarnings':
|
||||
# get first argument of the `args` node of the filter call
|
||||
match node.args[0]:
|
||||
case ast.Constant() as c:
|
||||
argtext = c.value
|
||||
case ast.JoinedStr() as js:
|
||||
# if we get an f-string, discard the templated pieces, which
|
||||
# are likely the type or specific message; we're interested
|
||||
# in the action, which is less likely to use a template
|
||||
argtext = "".join(
|
||||
x.value for x in js.values if isinstance(x, ast.Constant)
|
||||
)
|
||||
case _:
|
||||
raise ValueError("unknown ast node type")
|
||||
# check if filter is set to ignore
|
||||
if argtext == "ignore":
|
||||
self.bad_filters.append(
|
||||
f"{self.__filename}:{node.lineno}")
|
||||
|
||||
if p.ls[-1] == 'warn' and (
|
||||
len(p.ls) == 1 or p.ls[-2] == 'warnings'):
|
||||
|
||||
if self.__filename == "_lib/tests/test_warnings.py":
|
||||
# This file
|
||||
return
|
||||
|
||||
# See if stacklevel exists:
|
||||
if len(node.args) == 3:
|
||||
return
|
||||
args = {kw.arg for kw in node.keywords}
|
||||
if "stacklevel" not in args:
|
||||
self.bad_stacklevels.append(
|
||||
f"{self.__filename}:{node.lineno}")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def warning_calls():
|
||||
# combined "ignore" and stacklevel error
|
||||
base = Path(scipy.__file__).parent
|
||||
|
||||
bad_filters = []
|
||||
bad_stacklevels = []
|
||||
|
||||
for path in base.rglob("*.py"):
|
||||
# use tokenize to auto-detect encoding on systems where no
|
||||
# default encoding is defined (e.g., LANG='C')
|
||||
with tokenize.open(str(path)) as file:
|
||||
tree = ast.parse(file.read(), filename=str(path))
|
||||
finder = FindFuncs(path.relative_to(base))
|
||||
finder.visit(tree)
|
||||
bad_filters.extend(finder.bad_filters)
|
||||
bad_stacklevels.extend(finder.bad_stacklevels)
|
||||
|
||||
return bad_filters, bad_stacklevels
|
||||
|
||||
|
||||
@pytest.mark.fail_slow(20)
|
||||
@pytest.mark.slow
|
||||
def test_warning_calls_filters(warning_calls):
|
||||
bad_filters, bad_stacklevels = warning_calls
|
||||
|
||||
# We try not to add filters in the code base, because those filters aren't
|
||||
# thread-safe. We aim to only filter in tests with
|
||||
# np.testing.suppress_warnings. However, in some cases it may prove
|
||||
# necessary to filter out warnings, because we can't (easily) fix the root
|
||||
# cause for them and we don't want users to see some warnings when they use
|
||||
# SciPy correctly. So we list exceptions here. Add new entries only if
|
||||
# there's a good reason.
|
||||
allowed_filters = (
|
||||
os.path.join('datasets', '_fetchers.py'),
|
||||
os.path.join('datasets', '__init__.py'),
|
||||
os.path.join('optimize', '_optimize.py'),
|
||||
os.path.join('optimize', '_constraints.py'),
|
||||
os.path.join('optimize', '_nnls.py'),
|
||||
os.path.join('signal', '_ltisys.py'),
|
||||
os.path.join('sparse', '__init__.py'), # np.matrix pending-deprecation
|
||||
os.path.join('stats', '_discrete_distns.py'), # gh-14901
|
||||
os.path.join('stats', '_continuous_distns.py'),
|
||||
os.path.join('stats', '_binned_statistic.py'), # gh-19345
|
||||
os.path.join('stats', 'tests', 'test_axis_nan_policy.py'), # gh-20694
|
||||
os.path.join('_lib', '_util.py'), # gh-19341
|
||||
os.path.join('sparse', 'linalg', '_dsolve', 'linsolve.py'), # gh-17924
|
||||
"conftest.py",
|
||||
)
|
||||
bad_filters = [item for item in bad_filters if item.split(':')[0] not in
|
||||
allowed_filters]
|
||||
|
||||
if bad_filters:
|
||||
raise AssertionError(
|
||||
"warning ignore filter should not be used, instead, use\n"
|
||||
"numpy.testing.suppress_warnings (in tests only);\n"
|
||||
"found in:\n {}".format(
|
||||
"\n ".join(bad_filters)))
|
||||
|
||||
Reference in New Issue
Block a user