Files
Time-Series-Analysis/.venv/lib/python3.12/site-packages/cmdstanpy/utils/data_munging.py
2025-07-30 18:53:50 +03:00

45 lines
1.2 KiB
Python

"""
Common functions for reshaping numpy arrays
"""
from typing import Hashable, MutableMapping, Tuple
import numpy as np
import stanio
def flatten_chains(draws_array: np.ndarray) -> np.ndarray:
    """
    Collapse the draws and chains axes of a 3D array into a single axis.

    The input is laid out as draws X chains X variables. The result is a
    2D array of shape (draws * chains, variables). Within every variable
    column, all draws of the first chain come first, then all draws of
    the second chain, and so on.

    :param draws_array: 3D array of draws

    :return: 2D array with one column per variable

    :raises ValueError: if ``draws_array`` does not have exactly 3 dims
    """
    shape = draws_array.shape
    if len(shape) != 3:
        raise ValueError(
            f'Expecting 3D array, found array with {len(shape)} dims'
        )
    n_draws, n_chains, n_vars = shape
    # Fortran ordering makes the draw index vary fastest, which keeps each
    # chain's draws contiguous (chain-by-chain) within every output column.
    return draws_array.reshape((n_draws * n_chains, n_vars), order='F')
def build_xarray_data(
    data: MutableMapping[Hashable, Tuple[Tuple[str, ...], np.ndarray]],
    var: stanio.Variable,
    drawset: np.ndarray,
) -> None:
    """
    Register one Stan variable in a mapping destined for an xarray DataSet.

    The entry is stored under ``var.name``. Its value is a pair of
    (dimension labels, values). The labels always begin with 'draw' and
    'chain', followed by one '<name>_dim_<i>' label per entry in
    ``var.dimensions``. The values are whatever
    ``var.extract_reshape(drawset)`` returns. Presumably that is an array
    shaped (draw, chain, *var.dimensions), so that it lines up with the
    labels; TODO confirm against stanio.

    :param data: mapping updated in place; any existing entry for
        ``var.name`` is overwritten
    :param var: Stan variable whose values are extracted
    :param drawset: array of draws handed to ``var.extract_reshape``

    :return: None; the result is written into ``data``
    """
    n_dims = len(var.dimensions)
    labels: Tuple[str, ...] = ('draw', 'chain') + tuple(
        f"{var.name}_dim_{i}" for i in range(n_dims)
    )
    values = var.extract_reshape(drawset)
    data[var.name] = (labels, values)