735 lines
25 KiB
Python
735 lines
25 KiB
Python
"""
|
|
Container for the result of running the
|
|
generate quantities (GQ) method
|
|
"""
|
|
|
|
from collections import Counter
|
|
from typing import (
|
|
Any,
|
|
Dict,
|
|
Generic,
|
|
Hashable,
|
|
List,
|
|
MutableMapping,
|
|
NoReturn,
|
|
Optional,
|
|
Tuple,
|
|
TypeVar,
|
|
Union,
|
|
overload,
|
|
)
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
|
|
try:
|
|
import xarray as xr
|
|
|
|
XARRAY_INSTALLED = True
|
|
except ImportError:
|
|
XARRAY_INSTALLED = False
|
|
|
|
|
|
from cmdstanpy.cmdstan_args import Method
|
|
from cmdstanpy.utils import build_xarray_data, flatten_chains, get_logger
|
|
from cmdstanpy.utils.stancsv import scan_generic_csv
|
|
|
|
from .mcmc import CmdStanMCMC
|
|
from .metadata import InferenceMetadata
|
|
from .mle import CmdStanMLE
|
|
from .runset import RunSet
|
|
from .vb import CmdStanVB
|
|
|
|
Fit = TypeVar('Fit', CmdStanMCMC, CmdStanMLE, CmdStanVB)
|
|
|
|
|
|
class CmdStanGQ(Generic[Fit]):
|
|
"""
|
|
Container for outputs from CmdStan generate_quantities run.
|
|
Created by :meth:`CmdStanModel.generate_quantities`.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
runset: RunSet,
|
|
previous_fit: Fit,
|
|
) -> None:
|
|
"""Initialize object."""
|
|
if not runset.method == Method.GENERATE_QUANTITIES:
|
|
raise ValueError(
|
|
'Wrong runset method, expecting generate_quantities runset, '
|
|
'found method {}'.format(runset.method)
|
|
)
|
|
self.runset = runset
|
|
|
|
self.previous_fit: Fit = previous_fit
|
|
|
|
self._draws: np.ndarray = np.array(())
|
|
config = self._validate_csv_files()
|
|
self._metadata = InferenceMetadata(config)
|
|
|
|
def __repr__(self) -> str:
|
|
repr = 'CmdStanGQ: model={} chains={}{}'.format(
|
|
self.runset.model,
|
|
self.chains,
|
|
self.runset._args.method_args.compose(0, cmd=[]),
|
|
)
|
|
repr = '{}\n csv_files:\n\t{}\n output_files:\n\t{}'.format(
|
|
repr,
|
|
'\n\t'.join(self.runset.csv_files),
|
|
'\n\t'.join(self.runset.stdout_files),
|
|
)
|
|
return repr
|
|
|
|
def __getattr__(self, attr: str) -> np.ndarray:
|
|
"""Synonymous with ``fit.stan_variable(attr)"""
|
|
if attr.startswith("_"):
|
|
raise AttributeError(f"Unknown variable name {attr}")
|
|
try:
|
|
return self.stan_variable(attr)
|
|
except ValueError as e:
|
|
# pylint: disable=raise-missing-from
|
|
raise AttributeError(*e.args)
|
|
|
|
def __getstate__(self) -> dict:
|
|
# This function returns the mapping of objects to serialize with pickle.
|
|
# See https://docs.python.org/3/library/pickle.html#object.__getstate__
|
|
# for details. We call _assemble_generated_quantities to ensure
|
|
# the data are loaded prior to serialization.
|
|
self._assemble_generated_quantities()
|
|
return self.__dict__
|
|
|
|
def _validate_csv_files(self) -> Dict[str, Any]:
|
|
"""
|
|
Checks that Stan CSV output files for all chains are consistent
|
|
and returns dict containing config and column names.
|
|
|
|
Raises exception when inconsistencies detected.
|
|
"""
|
|
dzero = {}
|
|
for i in range(self.chains):
|
|
if i == 0:
|
|
dzero = scan_generic_csv(
|
|
path=self.runset.csv_files[i],
|
|
)
|
|
else:
|
|
drest = scan_generic_csv(
|
|
path=self.runset.csv_files[i],
|
|
)
|
|
for key in dzero:
|
|
if (
|
|
key
|
|
not in [
|
|
'id',
|
|
'fitted_params',
|
|
'diagnostic_file',
|
|
'metric_file',
|
|
'profile_file',
|
|
'init',
|
|
'seed',
|
|
'start_datetime',
|
|
]
|
|
and dzero[key] != drest[key]
|
|
):
|
|
raise ValueError(
|
|
'CmdStan config mismatch in Stan CSV file {}: '
|
|
'arg {} is {}, expected {}'.format(
|
|
self.runset.csv_files[i],
|
|
key,
|
|
dzero[key],
|
|
drest[key],
|
|
)
|
|
)
|
|
return dzero
|
|
|
|
@property
|
|
def chains(self) -> int:
|
|
"""Number of chains."""
|
|
return self.runset.chains
|
|
|
|
@property
|
|
def chain_ids(self) -> List[int]:
|
|
"""Chain ids."""
|
|
return self.runset.chain_ids
|
|
|
|
@property
|
|
def column_names(self) -> Tuple[str, ...]:
|
|
"""
|
|
Names of generated quantities of interest.
|
|
"""
|
|
return self._metadata.cmdstan_config['column_names'] # type: ignore
|
|
|
|
@property
|
|
def metadata(self) -> InferenceMetadata:
|
|
"""
|
|
Returns object which contains CmdStan configuration as well as
|
|
information about the names and structure of the inference method
|
|
and model output variables.
|
|
"""
|
|
return self._metadata
|
|
|
|
def draws(
|
|
self,
|
|
*,
|
|
inc_warmup: bool = False,
|
|
inc_iterations: bool = False,
|
|
concat_chains: bool = False,
|
|
inc_sample: bool = False,
|
|
) -> np.ndarray:
|
|
"""
|
|
Returns a numpy.ndarray over the generated quantities draws from
|
|
all chains which is stored column major so that the values
|
|
for a parameter are contiguous in memory, likewise all draws from
|
|
a chain are contiguous. By default, returns a 3D array arranged
|
|
(draws, chains, columns); parameter ``concat_chains=True`` will
|
|
return a 2D array where all chains are flattened into a single column,
|
|
preserving chain order, so that given M chains of N draws,
|
|
the first N draws are from chain 1, ..., and the the last N draws
|
|
are from chain M.
|
|
|
|
:param inc_warmup: When ``True`` and the warmup draws are present in
|
|
the output, i.e., the sampler was run with ``save_warmup=True``,
|
|
then the warmup draws are included. Default value is ``False``.
|
|
|
|
:param concat_chains: When ``True`` return a 2D array flattening all
|
|
all draws from all chains. Default value is ``False``.
|
|
|
|
:param inc_sample: When ``True`` include all columns in the previous_fit
|
|
draws array as well, excepting columns for variables already present
|
|
in the generated quantities drawset. Default value is ``False``.
|
|
|
|
See Also
|
|
--------
|
|
CmdStanGQ.draws_pd
|
|
CmdStanGQ.draws_xr
|
|
CmdStanMCMC.draws
|
|
"""
|
|
self._assemble_generated_quantities()
|
|
inc_warmup |= inc_iterations
|
|
if inc_warmup:
|
|
if (
|
|
isinstance(self.previous_fit, CmdStanMCMC)
|
|
and not self.previous_fit._save_warmup
|
|
):
|
|
get_logger().warning(
|
|
"Sample doesn't contain draws from warmup iterations,"
|
|
' rerun sampler with "save_warmup=True".'
|
|
)
|
|
elif (
|
|
isinstance(self.previous_fit, CmdStanMLE)
|
|
and not self.previous_fit._save_iterations
|
|
):
|
|
get_logger().warning(
|
|
"MLE doesn't contain draws from pre-convergence iterations,"
|
|
' rerun optimization with "save_iterations=True".'
|
|
)
|
|
elif isinstance(self.previous_fit, CmdStanVB):
|
|
get_logger().warning(
|
|
"Variational fit doesn't make sense with argument "
|
|
'"inc_warmup=True"'
|
|
)
|
|
|
|
if inc_sample:
|
|
cols_1 = self.previous_fit.column_names
|
|
cols_2 = self.column_names
|
|
dups = [
|
|
item
|
|
for item, count in Counter(cols_1 + cols_2).items()
|
|
if count > 1
|
|
]
|
|
drop_cols: List[int] = []
|
|
for dup in dups:
|
|
drop_cols.extend(
|
|
self.previous_fit._metadata.stan_vars[dup].columns()
|
|
)
|
|
|
|
start_idx, _ = self._draws_start(inc_warmup)
|
|
previous_draws = self._previous_draws(True)
|
|
if concat_chains and inc_sample:
|
|
return flatten_chains(
|
|
np.dstack(
|
|
(
|
|
np.delete(previous_draws, drop_cols, axis=1),
|
|
self._draws,
|
|
)
|
|
)[start_idx:, :, :]
|
|
)
|
|
if concat_chains:
|
|
return flatten_chains(self._draws[start_idx:, :, :])
|
|
if inc_sample:
|
|
return np.dstack(
|
|
(
|
|
np.delete(previous_draws, drop_cols, axis=1),
|
|
self._draws,
|
|
)
|
|
)[start_idx:, :, :]
|
|
return self._draws[start_idx:, :, :]
|
|
|
|
def draws_pd(
|
|
self,
|
|
vars: Union[List[str], str, None] = None,
|
|
inc_warmup: bool = False,
|
|
inc_sample: bool = False,
|
|
) -> pd.DataFrame:
|
|
"""
|
|
Returns the generated quantities draws as a pandas DataFrame.
|
|
Flattens all chains into single column. Container variables
|
|
(array, vector, matrix) will span multiple columns, one column
|
|
per element. E.g. variable 'matrix[2,2] foo' spans 4 columns:
|
|
'foo[1,1], ... foo[2,2]'.
|
|
|
|
:param vars: optional list of variable names.
|
|
|
|
:param inc_warmup: When ``True`` and the warmup draws are present in
|
|
the output, i.e., the sampler was run with ``save_warmup=True``,
|
|
then the warmup draws are included. Default value is ``False``.
|
|
|
|
See Also
|
|
--------
|
|
CmdStanGQ.draws
|
|
CmdStanGQ.draws_xr
|
|
CmdStanMCMC.draws_pd
|
|
"""
|
|
if vars is not None:
|
|
if isinstance(vars, str):
|
|
vars_list = [vars]
|
|
else:
|
|
vars_list = vars
|
|
|
|
vars_list = list(dict.fromkeys(vars_list))
|
|
|
|
if inc_warmup:
|
|
if (
|
|
isinstance(self.previous_fit, CmdStanMCMC)
|
|
and not self.previous_fit._save_warmup
|
|
):
|
|
get_logger().warning(
|
|
"Sample doesn't contain draws from warmup iterations,"
|
|
' rerun sampler with "save_warmup=True".'
|
|
)
|
|
elif (
|
|
isinstance(self.previous_fit, CmdStanMLE)
|
|
and not self.previous_fit._save_iterations
|
|
):
|
|
get_logger().warning(
|
|
"MLE doesn't contain draws from pre-convergence iterations,"
|
|
' rerun optimization with "save_iterations=True".'
|
|
)
|
|
elif isinstance(self.previous_fit, CmdStanVB):
|
|
get_logger().warning(
|
|
"Variational fit doesn't make sense with argument "
|
|
'"inc_warmup=True"'
|
|
)
|
|
|
|
self._assemble_generated_quantities()
|
|
|
|
all_columns = ['chain__', 'iter__', 'draw__'] + list(self.column_names)
|
|
|
|
gq_cols: List[str] = []
|
|
mcmc_vars: List[str] = []
|
|
if vars is not None:
|
|
for var in vars_list:
|
|
if var in self._metadata.stan_vars:
|
|
info = self._metadata.stan_vars[var]
|
|
gq_cols.extend(
|
|
self.column_names[info.start_idx : info.end_idx]
|
|
)
|
|
elif (
|
|
inc_sample and var in self.previous_fit._metadata.stan_vars
|
|
):
|
|
info = self.previous_fit._metadata.stan_vars[var]
|
|
mcmc_vars.extend(
|
|
self.previous_fit.column_names[
|
|
info.start_idx : info.end_idx
|
|
]
|
|
)
|
|
elif var in ['chain__', 'iter__', 'draw__']:
|
|
gq_cols.append(var)
|
|
else:
|
|
raise ValueError('Unknown variable: {}'.format(var))
|
|
else:
|
|
gq_cols = all_columns
|
|
vars_list = gq_cols
|
|
|
|
previous_draws_pd = self._previous_draws_pd(mcmc_vars, inc_warmup)
|
|
|
|
draws = self.draws(inc_warmup=inc_warmup)
|
|
# add long-form columns for chain, iteration, draw
|
|
n_draws, n_chains, _ = draws.shape
|
|
chains_col = (
|
|
np.repeat(np.arange(1, n_chains + 1), n_draws)
|
|
.reshape(1, n_chains, n_draws)
|
|
.T
|
|
)
|
|
iter_col = (
|
|
np.tile(np.arange(1, n_draws + 1), n_chains)
|
|
.reshape(1, n_chains, n_draws)
|
|
.T
|
|
)
|
|
draw_col = (
|
|
np.arange(1, (n_draws * n_chains) + 1)
|
|
.reshape(1, n_chains, n_draws)
|
|
.T
|
|
)
|
|
draws = np.concatenate([chains_col, iter_col, draw_col, draws], axis=2)
|
|
|
|
draws_pd = pd.DataFrame(
|
|
data=flatten_chains(draws),
|
|
columns=all_columns,
|
|
)
|
|
|
|
if inc_sample and mcmc_vars:
|
|
if gq_cols:
|
|
return pd.concat(
|
|
[
|
|
previous_draws_pd,
|
|
draws_pd[gq_cols],
|
|
],
|
|
axis='columns',
|
|
)[vars_list]
|
|
else:
|
|
return previous_draws_pd
|
|
elif inc_sample and vars is None:
|
|
cols_1 = list(previous_draws_pd.columns)
|
|
cols_2 = list(draws_pd.columns)
|
|
dups = [
|
|
item
|
|
for item, count in Counter(cols_1 + cols_2).items()
|
|
if count > 1
|
|
]
|
|
return pd.concat(
|
|
[
|
|
previous_draws_pd.drop(columns=dups).reset_index(drop=True),
|
|
draws_pd,
|
|
],
|
|
axis=1,
|
|
)
|
|
elif gq_cols:
|
|
return draws_pd[gq_cols]
|
|
|
|
return draws_pd
|
|
|
|
@overload
|
|
def draws_xr(
|
|
self: Union["CmdStanGQ[CmdStanMLE]", "CmdStanGQ[CmdStanVB]"],
|
|
vars: Union[str, List[str], None] = None,
|
|
inc_warmup: bool = False,
|
|
inc_sample: bool = False,
|
|
) -> NoReturn:
|
|
...
|
|
|
|
@overload
|
|
def draws_xr(
|
|
self: "CmdStanGQ[CmdStanMCMC]",
|
|
vars: Union[str, List[str], None] = None,
|
|
inc_warmup: bool = False,
|
|
inc_sample: bool = False,
|
|
) -> "xr.Dataset":
|
|
...
|
|
|
|
def draws_xr(
|
|
self,
|
|
vars: Union[str, List[str], None] = None,
|
|
inc_warmup: bool = False,
|
|
inc_sample: bool = False,
|
|
) -> "xr.Dataset":
|
|
"""
|
|
Returns the generated quantities draws as a xarray Dataset.
|
|
|
|
This method can only be called when the underlying fit was made
|
|
through sampling, it cannot be used on MLE or VB outputs.
|
|
|
|
:param vars: optional list of variable names.
|
|
|
|
:param inc_warmup: When ``True`` and the warmup draws are present in
|
|
the MCMC sample, then the warmup draws are included.
|
|
Default value is ``False``.
|
|
|
|
See Also
|
|
--------
|
|
CmdStanGQ.draws
|
|
CmdStanGQ.draws_pd
|
|
CmdStanMCMC.draws_xr
|
|
"""
|
|
if not XARRAY_INSTALLED:
|
|
raise RuntimeError(
|
|
'Package "xarray" is not installed, cannot produce draws array.'
|
|
)
|
|
if not isinstance(self.previous_fit, CmdStanMCMC):
|
|
raise RuntimeError(
|
|
'Method "draws_xr" is only available when '
|
|
'original fit is done via Sampling.'
|
|
)
|
|
mcmc_vars_list = []
|
|
dup_vars = []
|
|
if vars is not None:
|
|
if isinstance(vars, str):
|
|
vars_list = [vars]
|
|
else:
|
|
vars_list = vars
|
|
for var in vars_list:
|
|
if var not in self._metadata.stan_vars:
|
|
if inc_sample and (
|
|
var in self.previous_fit._metadata.stan_vars
|
|
):
|
|
mcmc_vars_list.append(var)
|
|
dup_vars.append(var)
|
|
else:
|
|
raise ValueError('Unknown variable: {}'.format(var))
|
|
else:
|
|
vars_list = list(self._metadata.stan_vars.keys())
|
|
if inc_sample:
|
|
for var in self.previous_fit._metadata.stan_vars.keys():
|
|
if var not in vars_list and var not in mcmc_vars_list:
|
|
mcmc_vars_list.append(var)
|
|
for var in dup_vars:
|
|
vars_list.remove(var)
|
|
|
|
self._assemble_generated_quantities()
|
|
|
|
num_draws = self.previous_fit.num_draws_sampling
|
|
sample_config = self.previous_fit._metadata.cmdstan_config
|
|
attrs: MutableMapping[Hashable, Any] = {
|
|
"stan_version": f"{sample_config['stan_version_major']}."
|
|
f"{sample_config['stan_version_minor']}."
|
|
f"{sample_config['stan_version_patch']}",
|
|
"model": sample_config["model"],
|
|
"num_draws_sampling": num_draws,
|
|
}
|
|
if inc_warmup and sample_config['save_warmup']:
|
|
num_draws += self.previous_fit.num_draws_warmup
|
|
attrs["num_draws_warmup"] = self.previous_fit.num_draws_warmup
|
|
|
|
data: MutableMapping[Hashable, Any] = {}
|
|
coordinates: MutableMapping[Hashable, Any] = {
|
|
"chain": self.chain_ids,
|
|
"draw": np.arange(num_draws),
|
|
}
|
|
|
|
for var in vars_list:
|
|
build_xarray_data(
|
|
data,
|
|
self._metadata.stan_vars[var],
|
|
self.draws(inc_warmup=inc_warmup),
|
|
)
|
|
if inc_sample:
|
|
for var in mcmc_vars_list:
|
|
build_xarray_data(
|
|
data,
|
|
self.previous_fit._metadata.stan_vars[var],
|
|
self.previous_fit.draws(inc_warmup=inc_warmup),
|
|
)
|
|
|
|
return xr.Dataset(data, coords=coordinates, attrs=attrs).transpose(
|
|
'chain', 'draw', ...
|
|
)
|
|
|
|
def stan_variable(self, var: str, **kwargs: bool) -> np.ndarray:
|
|
"""
|
|
Return a numpy.ndarray which contains the set of draws
|
|
for the named Stan program variable. Flattens the chains,
|
|
leaving the draws in chain order. The first array dimension,
|
|
corresponds to number of draws in the sample.
|
|
The remaining dimensions correspond to
|
|
the shape of the Stan program variable.
|
|
|
|
Underlyingly draws are in chain order, i.e., for a sample with
|
|
N chains of M draws each, the first M array elements are from chain 1,
|
|
the next M are from chain 2, and the last M elements are from chain N.
|
|
|
|
* If the variable is a scalar variable, the return array has shape
|
|
( draws * chains, 1).
|
|
* If the variable is a vector, the return array has shape
|
|
( draws * chains, len(vector))
|
|
* If the variable is a matrix, the return array has shape
|
|
( draws * chains, size(dim 1), size(dim 2) )
|
|
* If the variable is an array with N dimensions, the return array
|
|
has shape ( draws * chains, size(dim 1), ..., size(dim N))
|
|
|
|
For example, if the Stan program variable ``theta`` is a 3x3 matrix,
|
|
and the sample consists of 4 chains with 1000 post-warmup draws,
|
|
this function will return a numpy.ndarray with shape (4000,3,3).
|
|
|
|
This functionaltiy is also available via a shortcut using ``.`` -
|
|
writing ``fit.a`` is a synonym for ``fit.stan_variable("a")``
|
|
|
|
:param var: variable name
|
|
|
|
:param kwargs: Additional keyword arguments are passed to the underlying
|
|
fit's ``stan_variable`` method if the variable is not a generated
|
|
quantity.
|
|
|
|
See Also
|
|
--------
|
|
CmdStanGQ.stan_variables
|
|
CmdStanMCMC.stan_variable
|
|
CmdStanMLE.stan_variable
|
|
CmdStanPathfinder.stan_variable
|
|
CmdStanVB.stan_variable
|
|
CmdStanLaplace.stan_variable
|
|
"""
|
|
model_var_names = self.previous_fit._metadata.stan_vars.keys()
|
|
gq_var_names = self._metadata.stan_vars.keys()
|
|
if not (var in model_var_names or var in gq_var_names):
|
|
raise ValueError(
|
|
f'Unknown variable name: {var}\n'
|
|
'Available variables are '
|
|
+ ", ".join(model_var_names | gq_var_names)
|
|
)
|
|
if var not in gq_var_names:
|
|
# TODO(2.0) atleast1d may not be needed
|
|
return np.atleast_1d( # type: ignore
|
|
self.previous_fit.stan_variable(var, **kwargs)
|
|
)
|
|
|
|
# is gq variable
|
|
self._assemble_generated_quantities()
|
|
|
|
draw1, _ = self._draws_start(
|
|
inc_warmup=kwargs.get('inc_warmup', False)
|
|
or kwargs.get('inc_iterations', False)
|
|
)
|
|
draws = flatten_chains(self._draws[draw1:])
|
|
out: np.ndarray = self._metadata.stan_vars[var].extract_reshape(draws)
|
|
return out
|
|
|
|
def stan_variables(self, **kwargs: bool) -> Dict[str, np.ndarray]:
|
|
"""
|
|
Return a dictionary mapping Stan program variables names
|
|
to the corresponding numpy.ndarray containing the inferred values.
|
|
|
|
:param kwargs: Additional keyword arguments are passed to the underlying
|
|
fit's ``stan_variable`` method if the variable is not a generated
|
|
quantity.
|
|
|
|
See Also
|
|
--------
|
|
CmdStanGQ.stan_variable
|
|
CmdStanMCMC.stan_variables
|
|
CmdStanMLE.stan_variables
|
|
CmdStanPathfinder.stan_variables
|
|
CmdStanVB.stan_variables
|
|
CmdStanLaplace.stan_variables
|
|
"""
|
|
result = {}
|
|
sample_var_names = self.previous_fit._metadata.stan_vars.keys()
|
|
gq_var_names = self._metadata.stan_vars.keys()
|
|
for name in gq_var_names:
|
|
result[name] = self.stan_variable(name, **kwargs)
|
|
for name in sample_var_names:
|
|
if name not in gq_var_names:
|
|
result[name] = self.stan_variable(name, **kwargs)
|
|
return result
|
|
|
|
def _assemble_generated_quantities(self) -> None:
|
|
if self._draws.shape != (0,):
|
|
return
|
|
# use numpy loadtxt
|
|
_, num_draws = self._draws_start(inc_warmup=True)
|
|
|
|
gq_sample: np.ndarray = np.empty(
|
|
(num_draws, self.chains, len(self.column_names)),
|
|
dtype=float,
|
|
order='F',
|
|
)
|
|
for chain in range(self.chains):
|
|
with open(self.runset.csv_files[chain], 'r') as fd:
|
|
lines = (line for line in fd if not line.startswith('#'))
|
|
gq_sample[:, chain, :] = np.loadtxt(
|
|
lines, dtype=np.ndarray, ndmin=2, skiprows=1, delimiter=','
|
|
)
|
|
self._draws = gq_sample
|
|
|
|
def _draws_start(self, inc_warmup: bool) -> Tuple[int, int]:
|
|
draw1 = 0
|
|
p_fit = self.previous_fit
|
|
if isinstance(p_fit, CmdStanMCMC):
|
|
num_draws = p_fit.num_draws_sampling
|
|
if p_fit._save_warmup:
|
|
if inc_warmup:
|
|
num_draws += p_fit.num_draws_warmup
|
|
else:
|
|
draw1 = p_fit.num_draws_warmup
|
|
|
|
elif isinstance(p_fit, CmdStanMLE):
|
|
num_draws = 1
|
|
if p_fit._save_iterations:
|
|
opt_iters = len(p_fit.optimized_iterations_np) # type: ignore
|
|
if inc_warmup:
|
|
num_draws = opt_iters
|
|
else:
|
|
draw1 = opt_iters - 1
|
|
else: # CmdStanVB:
|
|
draw1 = 1 # skip mean
|
|
num_draws = p_fit.variational_sample.shape[0]
|
|
if inc_warmup:
|
|
num_draws += 1
|
|
|
|
return draw1, num_draws
|
|
|
|
def _previous_draws(self, inc_warmup: bool) -> np.ndarray:
|
|
"""
|
|
Extract the draws from self.previous_fit.
|
|
Return is always 3-d
|
|
"""
|
|
p_fit = self.previous_fit
|
|
if isinstance(p_fit, CmdStanMCMC):
|
|
return p_fit.draws(inc_warmup=inc_warmup)
|
|
elif isinstance(p_fit, CmdStanMLE):
|
|
if inc_warmup and p_fit._save_iterations:
|
|
return p_fit.optimized_iterations_np[:, None] # type: ignore
|
|
|
|
return np.atleast_2d( # type: ignore
|
|
p_fit.optimized_params_np,
|
|
)[:, None]
|
|
else: # CmdStanVB:
|
|
if inc_warmup:
|
|
return np.vstack(
|
|
[p_fit.variational_params_np, p_fit.variational_sample]
|
|
)[:, None]
|
|
return p_fit.variational_sample[:, None]
|
|
|
|
def _previous_draws_pd(
|
|
self, vars: List[str], inc_warmup: bool
|
|
) -> pd.DataFrame:
|
|
if vars:
|
|
sel: Union[List[str], slice] = vars
|
|
else:
|
|
sel = slice(None, None)
|
|
|
|
p_fit = self.previous_fit
|
|
if isinstance(p_fit, CmdStanMCMC):
|
|
return p_fit.draws_pd(vars or None, inc_warmup=inc_warmup)
|
|
|
|
elif isinstance(p_fit, CmdStanMLE):
|
|
if inc_warmup and p_fit._save_iterations:
|
|
return p_fit.optimized_iterations_pd[sel] # type: ignore
|
|
else:
|
|
return p_fit.optimized_params_pd[sel]
|
|
else: # CmdStanVB:
|
|
return p_fit.variational_sample_pd[sel]
|
|
|
|
def save_csvfiles(self, dir: Optional[str] = None) -> None:
|
|
"""
|
|
Move output CSV files to specified directory. If files were
|
|
written to the temporary session directory, clean filename.
|
|
E.g., save 'bernoulli-201912081451-1-5nm6as7u.csv' as
|
|
'bernoulli-201912081451-1.csv'.
|
|
|
|
:param dir: directory path
|
|
|
|
See Also
|
|
--------
|
|
stanfit.RunSet.save_csvfiles
|
|
cmdstanpy.from_csv
|
|
"""
|
|
self.runset.save_csvfiles(dir)
|
|
|
|
# TODO(2.0): remove
|
|
@property
|
|
def mcmc_sample(self) -> Union[CmdStanMCMC, CmdStanMLE, CmdStanVB]:
|
|
get_logger().warning(
|
|
"Property `mcmc_sample` is deprecated, use `previous_fit` instead"
|
|
)
|
|
return self.previous_fit
|