"""
Common functions for reshaping numpy arrays
"""
from typing import Hashable, MutableMapping, Tuple

import numpy as np
import stanio


def flatten_chains(draws_array: np.ndarray) -> np.ndarray:
    """
    Flatten a 3D array of draws X chains X variable into a 2D array of
    (draws * chains) X variable.

    The chains are stacked one after another along the rows, and each
    variable keeps its own column. Rows are ordered chain-major: all draws
    of chain 0, then all draws of chain 1, and so on, i.e.
    ``result[c * num_draws + d, v] == draws_array[d, c, v]``.

    :param draws_array: 3D array of draws, indexed (draw, chain, variable)
    :return: 2D array of shape (num_draws * num_chains, num_variables)
    :raises ValueError: if ``draws_array`` is not 3-dimensional
    """
    if draws_array.ndim != 3:
        raise ValueError(
            'Expecting 3D array, found array with {} dims'.format(
                draws_array.ndim
            )
        )
    num_draws, num_chains, num_vars = draws_array.shape
    # Fortran (column-major) order varies the draw index fastest, so the
    # draws of each chain stay contiguous and the chains are stacked in
    # order, while the variable axis is preserved as the column axis.
    return draws_array.reshape((num_draws * num_chains, num_vars), order='F')


def build_xarray_data(
    data: MutableMapping[Hashable, Tuple[Tuple[str, ...], np.ndarray]],
    var: stanio.Variable,
    drawset: np.ndarray,
) -> None:
    """
    Adds Stan variable name, labels, and values to a dictionary that will
    be used to construct an xarray Dataset.

    The entry is written in place under the key ``var.name`` as a
    ``(dims, values)`` tuple. ``dims`` is ``('draw', 'chain')`` followed by
    one ``f"{var.name}_dim_{i}"`` label per entry of ``var.dimensions``.
    ``values`` is ``var.extract_reshape(drawset)``.

    NOTE(review): assumes ``extract_reshape`` returns an array whose axes
    are (draw, chain, *var.dimensions), matching the labels above.
    Confirm this against stanio.

    :param data: mapping of variable name to ``(dims, values)``; mutated
    :param var: stanio variable descriptor (name, dimensions, extractor)
    :param drawset: raw draws array the variable's values are taken from
    """
    # One label per axis of the variable itself, prefixed by the variable
    # name so that dimensions of different variables do not collide.
    var_dims: Tuple[str, ...] = ('draw', 'chain') + tuple(
        f"{var.name}_dim_{i}" for i in range(len(var.dimensions))
    )
    data[var.name] = (var_dims, var.extract_reshape(drawset))