253 lines
8.3 KiB
Python
253 lines
8.3 KiB
Python
"""
|
|
Classes and functions for reshaping Stan output.
|
|
|
|
Especially with the addition of tuples, Stan writes
|
|
flat arrays of data with a rich internal structure.
|
|
"""
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
from math import prod
|
|
from typing import Any, Dict, Iterable, List, Tuple
|
|
|
|
import numpy as np
|
|
import numpy.typing as npt
|
|
|
|
|
|
class VariableType(Enum):
    """Kind of value held by a Stan output variable."""

    # real or integer value, read straight from the flat columns
    SCALAR = 1
    # complex number; real and imaginary parts sit in adjacent columns
    COMPLEX = 2
    # tuple; its elements are described recursively by ``Variable.contents``
    TUPLE = 3
|
|
|
|
|
|
@dataclass
class Variable:
    """
    This class represents a single output variable of a Stan model.

    It contains information about the name, dimensions, and type of the
    variable, as well as the indices of where that variable is located in
    the flattened output array Stan models write.

    Generally, this class should not be instantiated directly, but rather
    created by the :func:`parse_header()` function.
    """

    # name of the parameter as given in stan. For nested parameters (the
    # elements of a tuple) this is a generated ``dummy_*`` name.
    name: str
    # half-open column range [start_idx, end_idx) in the flattened array.
    # For a top-level variable this spans every element, including every
    # element of an array of tuples. For nested parameters (a tuple's
    # ``contents``) the range covers a single tuple element only and is
    # relative to the start of the parent.
    start_idx: int
    end_idx: int
    # rectangular dimensions of the parameter (e.g. (2, 3) for a 2x3 matrix)
    # For tuples, these are the dimensions of the outermost array.
    dimensions: Tuple[int, ...]
    # type of the parameter
    type: VariableType
    # nested parameters; non-empty only for tuples
    contents: List["Variable"]

    def dtype(self, top: bool = True) -> np.dtype:
        """
        Return the NumPy dtype used to represent this variable.

        Scalars map to ``float64`` and complex numbers to ``complex128``.
        Tuples map to a structured dtype with fields ``"1"``, ``"2"``, ...
        (one per element, each carrying that element's own dimensions).

        Parameters
        ----------
        top : bool
            If True, return the per-element dtype. If False (used for the
            fields of an enclosing tuple), return a sub-array dtype whose
            shape is :attr:`dimensions`.

        Raises
        ------
        ValueError
            If :attr:`type` is not a known :class:`VariableType`.
        """
        if self.type == VariableType.TUPLE:
            base = np.dtype(
                [
                    (str(i + 1), param.dtype(top=False))
                    for i, param in enumerate(self.contents)
                ]
            )
        elif self.type == VariableType.SCALAR:
            base = np.dtype(np.float64)
        elif self.type == VariableType.COMPLEX:
            base = np.dtype(np.complex128)
        else:
            raise ValueError(f"Unknown variable type: {self.type}")

        if top:
            return base
        return np.dtype((base, self.dimensions))

    def columns(self) -> Iterable[int]:
        """Return the column indices ``start_idx`` .. ``end_idx - 1``."""
        return range(self.start_idx, self.end_idx)

    def num_elts(self) -> int:
        """Return the product of :attr:`dimensions` (1 for a scalar)."""
        return prod(self.dimensions)

    def elt_size(self) -> int:
        """
        Return the number of flat columns between ``start_idx`` and ``end_idx``.

        Despite the name, for a top-level array this is the width of the
        whole variable, not of one element; the width of a single element
        of a tuple array is ``elt_size() // num_elts()``.
        """
        return self.end_idx - self.start_idx

    def _extract_helper(self, src: np.ndarray, offset: int = 0) -> np.ndarray:
        """
        Extract this variable from the last axis of ``src``.

        All leading axes of ``src`` are collapsed (in Fortran order) into
        the first axis of the result, so the return value has shape
        ``(prod(src.shape[:-1]), *self.dimensions)``.

        ``offset`` is added to ``start_idx``/``end_idx``; it is non-zero
        when this variable is nested inside a tuple.
        """
        start = self.start_idx + offset
        end = self.end_idx + offset
        if self.type == VariableType.SCALAR:
            # columns are read in column-major order, hence order="F"
            return src[..., start:end].reshape(-1, *self.dimensions, order="F")
        elif self.type == VariableType.COMPLEX:
            # real/imag parts are interleaved and vary fastest, so after a
            # column-major reshape they occupy axis 1 (0 = real, 1 = imag)
            parts = src[..., start:end].reshape(-1, 2, *self.dimensions, order="F")
            return parts[:, 0] + 1j * parts[:, 1]
        elif self.type == VariableType.TUPLE:
            n_rows = prod(src.shape[:-1])
            count = self.num_elts()
            total = self.elt_size()  # width of the whole tuple array
            out: np.ndarray = np.empty((n_rows, count), dtype=object)
            for idx in range(count):
                # column offset of the idx-th tuple relative to `start`;
                # child offsets are relative to the parent's start
                off = idx * total // count
                elts = [
                    param._extract_helper(src, offset=start + off)
                    for param in self.contents
                ]
                for row in range(n_rows):
                    out[row, idx] = tuple(elt[row] for elt in elts)
            return out.reshape(-1, *self.dimensions, order="F")

    def extract_reshape(self, src: np.ndarray, object: bool = True) -> npt.NDArray[Any]:
        """
        Given an array where the final dimension is the flattened output of a
        Stan model, (e.g. one row of a Stan CSV file), extract the variable
        and reshape it to the correct type and dimensions.

        This will most likely result in copies of the data being made if
        the variable is not a scalar.

        Parameters
        ----------
        src : np.ndarray
            The array to extract from.

            Indices besides the final dimension are preserved
            in the output.

        object : bool
            If True, the output of tuple types will be an object array,
            otherwise it will use custom dtypes to represent tuples.

        Returns
        -------
        npt.NDArray[Any]
            The extracted variable, reshaped to the correct dimensions.
            If ``src`` has more than one dimension the shape is
            ``(*src.shape[:-1], *self.dimensions)``; for a 1-D ``src`` the
            leading axis is dropped and the shape is ``self.dimensions``.
            Tuples are returned as an object array of Python tuples when
            ``object`` is True, or with the structured dtype from
            :meth:`dtype` otherwise. When ``object`` is False, scalars are
            cast to float64 and complex numbers to complex128.
        """
        out = self._extract_helper(src)
        if not object:
            out = out.astype(self.dtype())
        if src.ndim > 1:
            out = out.reshape(*src.shape[:-1], *self.dimensions, order="F")
        else:
            out = out.squeeze(axis=0)

        return out
|
|
|
|
|
|
def _munge_first_tuple(tup: str) -> str:
|
|
return "dummy_" + tup.split(":", 1)[1]
|
|
|
|
|
|
def _get_base_name(param: str) -> str:
|
|
return param.split(".")[0].split(":")[0]
|
|
|
|
|
|
def _from_header(header: str) -> List[Variable]:
    """
    Parse a comma-separated Stan header into a list of :class:`Variable`.

    Consecutive entries sharing a base name (the text before the first
    ``.`` or ``:``) are grouped into one variable. Entries containing ``:``
    are tuple components; the column names of one tuple element are
    rewritten into a ``dummy_*`` header and parsed recursively to build
    ``contents``.

    Parameters
    ----------
    header : str
        Comma-separated column names, e.g. ``"lp__,foo.1,foo.2"``.

    Returns
    -------
    List[Variable]
        One entry per top-level variable, in header order. An empty or
        whitespace-only header yields an empty list.
    """
    header = header.strip()
    if not header:
        # Without this guard the loop below would emit a bogus scalar
        # variable named "" that refers to a non-existent column.
        return []

    # appending __dummy ensures one extra iteration in the later loop, so
    # the last real variable is flushed like all the others
    entries = (header + ",__dummy").split(",")
    params: List[Variable] = []
    start_idx = 0
    name = _get_base_name(entries[0])
    for i in range(len(entries) - 1):
        entry = entries[i]
        next_name = _get_base_name(entries[i + 1])
        if next_name == name:
            continue  # still inside the current variable's columns

        # `entry` is the last column of the variable `name`; its indices are
        # taken as the variable's dimensions.
        contents: List[Variable] = []
        if ":" not in entry:
            dims = entry.split(".")[1:]
            if ".real" in entry or ".imag" in entry:
                var_type = VariableType.COMPLEX
                dims = dims[:-1]  # drop the trailing real/imag label
            else:
                var_type = VariableType.SCALAR
        else:
            var_type = VariableType.TUPLE
            dims = entry.split(":")[0].split(".")[1:]
            # Strip the outer name/indices from every column and de-duplicate
            # (order-preserving), leaving the header of one tuple element.
            munged_header = ",".join(
                dict.fromkeys(map(_munge_first_tuple, entries[start_idx : i + 1]))
            )
            contents = _from_header(munged_header)

        params.append(
            Variable(
                name=name,
                start_idx=start_idx,
                end_idx=i + 1,
                dimensions=tuple(map(int, dims)),
                type=var_type,
                contents=contents,
            )
        )

        start_idx = i + 1
        name = next_name

    return params
|
|
|
|
|
|
def parse_header(header: str) -> Dict[str, Variable]:
    """
    Given a comma-separated list of names of Stan outputs, like
    that from the header row of a CSV file, parse it into a dictionary of
    :class:`Variable` objects.

    Parameters
    ----------
    header : str
        Comma separated list of Stan variables, including index information.
        For example, an ``array[2] real foo`` would be represented as
        ``foo.1,foo.2``. Tuple components are separated by ``:``, as in
        ``bar:1,bar:2``.

    Returns
    -------
    Dict[str, Variable]
        A dictionary mapping the base name of each variable to a
        :class:`Variable`, in the order the variables appear in ``header``.
    """
    return {param.name: param for param in _from_header(header)}
|
|
|
|
|
|
def stan_variables(
    parameters: Dict[str, Variable],
    source: npt.NDArray[np.float64],
    *,
    object: bool = True,
) -> Dict[str, npt.NDArray[Any]]:
    """
    Extract and reshape every variable in ``parameters`` from ``source``.

    Parameters
    ----------
    parameters : Dict[str, Variable]
        A dictionary of :class:`Variable` objects,
        like that returned by :func:`parse_header()`.
    source : npt.NDArray[np.float64]
        The array to extract from; its last axis holds the flattened
        Stan output.
    object : bool
        If True, the output of tuple types will be an object array,
        otherwise it will use custom dtypes to represent tuples.

    Returns
    -------
    Dict[str, npt.NDArray[Any]]
        Maps each variable's ``name`` to the array produced by
        :meth:`Variable.extract_reshape`, in the iteration order of
        ``parameters``.
    """
    extracted: Dict[str, npt.NDArray[Any]] = {}
    for var in parameters.values():
        extracted[var.name] = var.extract_reshape(source, object=object)
    return extracted
|