some new features
This commit is contained in:
@ -0,0 +1,44 @@
|
||||
"""
|
||||
Common functions for reshaping numpy arrays
|
||||
"""
|
||||
from typing import Hashable, MutableMapping, Tuple
|
||||
|
||||
import numpy as np
|
||||
import stanio
|
||||
|
||||
|
||||
def flatten_chains(draws_array: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Flatten a 3D array of draws X chains X variable into 2D array
|
||||
where all chains are concatenated into a single column.
|
||||
|
||||
:param draws_array: 3D array of draws
|
||||
"""
|
||||
if len(draws_array.shape) != 3:
|
||||
raise ValueError(
|
||||
'Expecting 3D array, found array with {} dims'.format(
|
||||
len(draws_array.shape)
|
||||
)
|
||||
)
|
||||
|
||||
num_rows = draws_array.shape[0] * draws_array.shape[1]
|
||||
num_cols = draws_array.shape[2]
|
||||
return draws_array.reshape((num_rows, num_cols), order='F')
|
||||
|
||||
|
||||
def build_xarray_data(
|
||||
data: MutableMapping[Hashable, Tuple[Tuple[str, ...], np.ndarray]],
|
||||
var: stanio.Variable,
|
||||
drawset: np.ndarray,
|
||||
) -> None:
|
||||
"""
|
||||
Adds Stan variable name, labels, and values to a dictionary
|
||||
that will be used to construct an xarray DataSet.
|
||||
"""
|
||||
var_dims: Tuple[str, ...] = ('draw', 'chain')
|
||||
var_dims += tuple(f"{var.name}_dim_{i}" for i in range(len(var.dimensions)))
|
||||
|
||||
data[var.name] = (
|
||||
var_dims,
|
||||
var.extract_reshape(drawset),
|
||||
)
|
||||
Reference in New Issue
Block a user