some new features
This commit is contained in:
@ -0,0 +1,221 @@
|
||||
from numbers import Number
|
||||
import operator
|
||||
import os
|
||||
import threading
|
||||
import contextlib
|
||||
|
||||
import numpy as np
|
||||
|
||||
from scipy._lib._util import copy_if_needed
|
||||
|
||||
# good_size is exposed (and used) from this import
|
||||
from .pypocketfft import good_size, prev_good_size
|
||||
|
||||
|
||||
__all__ = ['good_size', 'prev_good_size', 'set_workers', 'get_workers']
|
||||
|
||||
_config = threading.local()
|
||||
_cpu_count = os.cpu_count()
|
||||
|
||||
|
||||
def _iterable_of_int(x, name=None):
|
||||
"""Convert ``x`` to an iterable sequence of int
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : value, or sequence of values, convertible to int
|
||||
name : str, optional
|
||||
Name of the argument being converted, only used in the error message
|
||||
|
||||
Returns
|
||||
-------
|
||||
y : ``List[int]``
|
||||
"""
|
||||
if isinstance(x, Number):
|
||||
x = (x,)
|
||||
|
||||
try:
|
||||
x = [operator.index(a) for a in x]
|
||||
except TypeError as e:
|
||||
name = name or "value"
|
||||
raise ValueError(f"{name} must be a scalar or iterable of integers") from e
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def _init_nd_shape_and_axes(x, shape, axes):
|
||||
"""Handles shape and axes arguments for nd transforms"""
|
||||
noshape = shape is None
|
||||
noaxes = axes is None
|
||||
|
||||
if not noaxes:
|
||||
axes = _iterable_of_int(axes, 'axes')
|
||||
axes = [a + x.ndim if a < 0 else a for a in axes]
|
||||
|
||||
if any(a >= x.ndim or a < 0 for a in axes):
|
||||
raise ValueError("axes exceeds dimensionality of input")
|
||||
if len(set(axes)) != len(axes):
|
||||
raise ValueError("all axes must be unique")
|
||||
|
||||
if not noshape:
|
||||
shape = _iterable_of_int(shape, 'shape')
|
||||
|
||||
if axes and len(axes) != len(shape):
|
||||
raise ValueError("when given, axes and shape arguments"
|
||||
" have to be of the same length")
|
||||
if noaxes:
|
||||
if len(shape) > x.ndim:
|
||||
raise ValueError("shape requires more axes than are present")
|
||||
axes = range(x.ndim - len(shape), x.ndim)
|
||||
|
||||
shape = [x.shape[a] if s == -1 else s for s, a in zip(shape, axes)]
|
||||
elif noaxes:
|
||||
shape = list(x.shape)
|
||||
axes = range(x.ndim)
|
||||
else:
|
||||
shape = [x.shape[a] for a in axes]
|
||||
|
||||
if any(s < 1 for s in shape):
|
||||
raise ValueError(
|
||||
f"invalid number of data points ({shape}) specified")
|
||||
|
||||
return tuple(shape), list(axes)
|
||||
|
||||
|
||||
def _asfarray(x):
|
||||
"""
|
||||
Convert to array with floating or complex dtype.
|
||||
|
||||
float16 values are also promoted to float32.
|
||||
"""
|
||||
if not hasattr(x, "dtype"):
|
||||
x = np.asarray(x)
|
||||
|
||||
if x.dtype == np.float16:
|
||||
return np.asarray(x, np.float32)
|
||||
elif x.dtype.kind not in 'fc':
|
||||
return np.asarray(x, np.float64)
|
||||
|
||||
# Require native byte order
|
||||
dtype = x.dtype.newbyteorder('=')
|
||||
# Always align input
|
||||
copy = True if not x.flags['ALIGNED'] else copy_if_needed
|
||||
return np.array(x, dtype=dtype, copy=copy)
|
||||
|
||||
def _datacopied(arr, original):
|
||||
"""
|
||||
Strict check for `arr` not sharing any data with `original`,
|
||||
under the assumption that arr = asarray(original)
|
||||
"""
|
||||
if arr is original:
|
||||
return False
|
||||
if not isinstance(original, np.ndarray) and hasattr(original, '__array__'):
|
||||
return False
|
||||
return arr.base is None
|
||||
|
||||
|
||||
def _fix_shape(x, shape, axes):
|
||||
"""Internal auxiliary function for _raw_fft, _raw_fftnd."""
|
||||
must_copy = False
|
||||
|
||||
# Build an nd slice with the dimensions to be read from x
|
||||
index = [slice(None)]*x.ndim
|
||||
for n, ax in zip(shape, axes):
|
||||
if x.shape[ax] >= n:
|
||||
index[ax] = slice(0, n)
|
||||
else:
|
||||
index[ax] = slice(0, x.shape[ax])
|
||||
must_copy = True
|
||||
|
||||
index = tuple(index)
|
||||
|
||||
if not must_copy:
|
||||
return x[index], False
|
||||
|
||||
s = list(x.shape)
|
||||
for n, axis in zip(shape, axes):
|
||||
s[axis] = n
|
||||
|
||||
z = np.zeros(s, x.dtype)
|
||||
z[index] = x[index]
|
||||
return z, True
|
||||
|
||||
|
||||
def _fix_shape_1d(x, n, axis):
|
||||
if n < 1:
|
||||
raise ValueError(
|
||||
f"invalid number of data points ({n}) specified")
|
||||
|
||||
return _fix_shape(x, (n,), (axis,))
|
||||
|
||||
|
||||
_NORM_MAP = {None: 0, 'backward': 0, 'ortho': 1, 'forward': 2}
|
||||
|
||||
|
||||
def _normalization(norm, forward):
|
||||
"""Returns the pypocketfft normalization mode from the norm argument"""
|
||||
try:
|
||||
inorm = _NORM_MAP[norm]
|
||||
return inorm if forward else (2 - inorm)
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
f'Invalid norm value {norm!r}, should '
|
||||
'be "backward", "ortho" or "forward"') from None
|
||||
|
||||
|
||||
def _workers(workers):
|
||||
if workers is None:
|
||||
return getattr(_config, 'default_workers', 1)
|
||||
|
||||
if workers < 0:
|
||||
if workers >= -_cpu_count:
|
||||
workers += 1 + _cpu_count
|
||||
else:
|
||||
raise ValueError(f"workers value out of range; got {workers}, must not be"
|
||||
f" less than {-_cpu_count}")
|
||||
elif workers == 0:
|
||||
raise ValueError("workers must not be zero")
|
||||
|
||||
return workers
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_workers(workers):
|
||||
"""Context manager for the default number of workers used in `scipy.fft`
|
||||
|
||||
Parameters
|
||||
----------
|
||||
workers : int
|
||||
The default number of workers to use
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import numpy as np
|
||||
>>> from scipy import fft, signal
|
||||
>>> rng = np.random.default_rng()
|
||||
>>> x = rng.standard_normal((128, 64))
|
||||
>>> with fft.set_workers(4):
|
||||
... y = signal.fftconvolve(x, x)
|
||||
|
||||
"""
|
||||
old_workers = get_workers()
|
||||
_config.default_workers = _workers(operator.index(workers))
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_config.default_workers = old_workers
|
||||
|
||||
|
||||
def get_workers():
|
||||
"""Returns the default number of workers within the current context
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> from scipy import fft
|
||||
>>> fft.get_workers()
|
||||
1
|
||||
>>> with fft.set_workers(4):
|
||||
... fft.get_workers()
|
||||
4
|
||||
"""
|
||||
return getattr(_config, 'default_workers', 1)
|
||||
Reference in New Issue
Block a user