reconnect moved files to git repo
This commit is contained in:
252
venv/lib/python3.11/site-packages/stanio/reshape.py
Normal file
252
venv/lib/python3.11/site-packages/stanio/reshape.py
Normal file
@ -0,0 +1,252 @@
|
||||
"""
|
||||
Classes and functions for reshaping Stan output.
|
||||
|
||||
Especially with the addition of tuples, Stan writes
|
||||
flat arrays of data with a rich internal structure.
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from math import prod
|
||||
from typing import Any, Dict, Iterable, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
|
||||
class VariableType(Enum):
    """Kinds of values a Stan output variable can hold."""

    # real or integer
    SCALAR = 1
    # complex number - requires striding
    COMPLEX = 2
    # tuples - require recursive handling
    TUPLE = 3
|
||||
|
||||
|
||||
@dataclass
class Variable:
    """
    This class represents a single output variable of a Stan model.

    It contains information about the name, dimensions, and type of the
    variable, as well as the indices of where that variable is located in
    the flattened output array Stan models write.

    Generally, this class should not be instantiated directly, but rather
    created by the :func:`parse_header()` function.
    """

    # name of the parameter as given in stan. For nested parameters, this is a dummy name
    name: str
    # where to start (resp. end) reading from the flattened array.
    # For arrays with nested parameters, this will be for the first element
    # and is relative to the start of the parent
    start_idx: int
    end_idx: int
    # rectangular dimensions of the parameter (e.g. (2, 3) for a 2x3 matrix)
    # For nested parameters, this will be the dimensions of the outermost array.
    dimensions: Tuple[int, ...]
    # type of the parameter
    type: VariableType
    # list of nested parameters
    contents: List["Variable"]

    def dtype(self, top: bool = True) -> np.dtype:
        """
        Return the numpy dtype describing this variable.

        Parameters
        ----------
        top : bool
            If True, return the dtype of a single element. If False, return
            a sub-array dtype that also carries ``self.dimensions``; this is
            the form used for the fields of an enclosing tuple.

        Returns
        -------
        np.dtype
            ``float64`` for scalars, ``complex128`` for complex values, and a
            structured dtype with fields named ``"1"``, ``"2"``, ... for tuples.

        Raises
        ------
        ValueError
            If ``self.type`` is not a known :class:`VariableType`.
        """
        if self.type == VariableType.TUPLE:
            base = np.dtype(
                [
                    (str(position), member.dtype(top=False))
                    for position, member in enumerate(self.contents, start=1)
                ]
            )
        elif self.type == VariableType.SCALAR:
            # wrap in np.dtype so the return value matches the annotation
            base = np.dtype(np.float64)
        elif self.type == VariableType.COMPLEX:
            base = np.dtype(np.complex128)
        else:
            raise ValueError(f"Unknown variable type: {self.type!r}")

        if top:
            return base
        return np.dtype((base, self.dimensions))

    def columns(self) -> Iterable[int]:
        """Return the column indices ``start_idx`` (inclusive) to ``end_idx`` (exclusive)."""
        return range(self.start_idx, self.end_idx)

    def num_elts(self) -> int:
        """Return the product of ``self.dimensions`` (1 when there are none)."""
        return prod(self.dimensions)

    def elt_size(self) -> int:
        """
        Return the number of columns spanned, ``end_idx - start_idx``.

        NOTE: for tuple variables :meth:`_extract_helper` treats this span as
        covering every array element and derives the offset of element ``idx``
        as ``idx * elt_size() // num_elts()``.
        """
        return self.end_idx - self.start_idx

    def _extract_helper(self, src: np.ndarray, offset: int = 0) -> np.ndarray:
        """
        Extract this variable from ``src`` with all leading axes flattened.

        Parameters
        ----------
        src : np.ndarray
            Array whose final axis is the flattened Stan output.
        offset : int
            Amount added to ``start_idx`` and ``end_idx`` before slicing;
            used when recursing into the contents of a tuple.

        Returns
        -------
        np.ndarray
            Array of shape ``(-1, *self.dimensions)`` in Fortran order.
            Tuples produce an object array whose entries are Python tuples.

        Raises
        ------
        ValueError
            If ``self.type`` is not a known :class:`VariableType`.
        """
        start = self.start_idx + offset
        end = self.end_idx + offset

        if self.type == VariableType.SCALAR:
            return src[..., start:end].reshape(-1, *self.dimensions, order="F")

        if self.type == VariableType.COMPLEX:
            # axis 1 separates the real (index 0) and imaginary (index 1) parts
            parts = src[..., start:end].reshape(-1, 2, *self.dimensions, order="F")
            return parts[:, 0] + 1j * parts[:, 1]

        if self.type == VariableType.TUPLE:
            count = self.num_elts()
            span = self.elt_size()
            out: np.ndarray = np.empty(
                (prod(src.shape[:-1]), count), dtype=object
            )
            for idx in range(count):
                # offset of the idx-th tuple inside this variable's columns
                off = idx * span // count
                elts = [
                    param._extract_helper(src, offset=start + off)
                    for param in self.contents
                ]
                for i in range(elts[0].shape[0]):
                    out[i, idx] = tuple(elt[i] for elt in elts)
            return out.reshape(-1, *self.dimensions, order="F")

        raise ValueError(f"Unknown variable type: {self.type!r}")

    def extract_reshape(self, src: np.ndarray, object: bool = True) -> npt.NDArray[Any]:
        """
        Given an array where the final dimension is the flattened output of a
        Stan model, (e.g. one row of a Stan CSV file), extract the variable
        and reshape it to the correct type and dimensions.

        This will most likely result in copies of the data being made if
        the variable is not a scalar.

        Parameters
        ----------
        src : np.ndarray
            The array to extract from.

            Indices besides the final dimension are preserved
            in the output.

        object : bool
            If True, the output of tuple types will be an object array,
            otherwise it will use custom dtypes to represent tuples.

        Returns
        -------
        npt.NDArray[Any]
            The extracted variable, reshaped to the correct dimensions.
            If the variable is a tuple, this will be an object array,
            otherwise it will have a dtype of either float64 or complex128.
        """
        out = self._extract_helper(src)
        if not object:
            out = out.astype(self.dtype())

        if src.ndim > 1:
            # restore the leading axes that _extract_helper flattened
            out = out.reshape(*src.shape[:-1], *self.dimensions, order="F")
        else:
            # 1-D input: drop the length-1 leading axis added by the helper
            out = out.squeeze(axis=0)

        return out
|
||||
|
||||
|
||||
def _munge_first_tuple(tup: str) -> str:
|
||||
return "dummy_" + tup.split(":", 1)[1]
|
||||
|
||||
|
||||
def _get_base_name(param: str) -> str:
|
||||
return param.split(".")[0].split(":")[0]
|
||||
|
||||
|
||||
def _from_header(header: str) -> List[Variable]:
    """
    Parse a comma-separated header into a list of :class:`Variable` objects.

    Consecutive entries sharing the same base name (see
    :func:`_get_base_name`) form one variable. Its dimensions are read from
    the indices of the last entry of that run. A run whose last entry
    contains ``:`` is a tuple; its contents are parsed recursively from the
    de-duplicated, munged entries.

    Parameters
    ----------
    header : str
        Comma separated list of column names.

    Returns
    -------
    List[Variable]
        One :class:`Variable` per run of entries, in header order.
        An empty (or whitespace-only) header yields an empty list.
    """
    # local import: itertools is not imported at module level
    from itertools import groupby

    header = header.strip()
    if not header:
        # no columns, so no variables (avoids a phantom variable named "")
        return []

    params: List[Variable] = []
    start_idx = 0
    for name, run in groupby(header.split(","), key=_get_base_name):
        entries = list(run)
        end_idx = start_idx + len(entries)
        last = entries[-1]

        if ":" in last:
            # tuple: only the indices before the first ':' belong to the
            # outer array; the rest describes the tuple's contents
            dims = last.split(":", 1)[0].split(".")[1:]
            var_type = VariableType.TUPLE
            # dict.fromkeys de-duplicates while preserving order, so repeated
            # array elements collapse to a single description of one tuple
            munged_header = ",".join(
                dict.fromkeys(map(_munge_first_tuple, entries))
            )
            contents = _from_header(munged_header)
        else:
            dims = last.split(".")[1:]
            contents = []
            if last.endswith((".real", ".imag")):
                var_type = VariableType.COMPLEX
                # the trailing real/imag component is not a dimension
                dims = dims[:-1]
            else:
                var_type = VariableType.SCALAR

        params.append(
            Variable(
                name=name,
                start_idx=start_idx,
                end_idx=end_idx,
                dimensions=tuple(map(int, dims)),
                type=var_type,
                contents=contents,
            )
        )
        start_idx = end_idx

    return params
|
||||
|
||||
|
||||
def parse_header(header: str) -> Dict[str, Variable]:
    """
    Given a comma-separated list of names of Stan outputs, like
    that from the header row of a CSV file, parse it into a dictionary of
    :class:`Variable` objects.

    Parameters
    ----------
    header : str
        Comma separated list of Stan variables, including index information.
        For example, an ``array[2] real foo`` would be represented as
        ``foo.1,foo.2``.

    Returns
    -------
    Dict[str, Variable]
        A dictionary mapping the base name of each variable to a :class:`Variable`.
    """
    variables: Dict[str, Variable] = {}
    for variable in _from_header(header):
        variables[variable.name] = variable
    return variables
|
||||
|
||||
|
||||
def stan_variables(
    parameters: Dict[str, Variable],
    source: npt.NDArray[np.float64],
    *,
    object: bool = True,
) -> Dict[str, npt.NDArray[Any]]:
    """
    Given a dictionary of :class:`Variable` objects and a source array,
    extract the variables from the source array and reshape them to the
    correct dimensions.

    Parameters
    ----------
    parameters : Dict[str, Variable]
        A dictionary of :class:`Variable` objects,
        like that returned by :func:`parse_header()`.
    source : npt.NDArray[np.float64]
        The array to extract from.
    object : bool
        If True, the output of tuple types will be an object array,
        otherwise it will use custom dtypes to represent tuples.

    Returns
    -------
    Dict[str, npt.NDArray[Any]]
        A dictionary mapping the base name of each variable to the extracted
        and reshaped data.
    """
    extracted: Dict[str, npt.NDArray[Any]] = {}
    for variable in parameters.values():
        # keyed by the variable's own name, not by the key in ``parameters``
        extracted[variable.name] = variable.extract_reshape(source, object=object)
    return extracted
|
||||
Reference in New Issue
Block a user