270 lines
10 KiB
Python
270 lines
10 KiB
Python
"""Container objects for results of CmdStan run(s)."""
|
|
|
|
import glob
|
|
import os
|
|
from typing import Any, Dict, List, Optional, Union
|
|
|
|
from cmdstanpy.cmdstan_args import (
|
|
CmdStanArgs,
|
|
LaplaceArgs,
|
|
OptimizeArgs,
|
|
PathfinderArgs,
|
|
SamplerArgs,
|
|
VariationalArgs,
|
|
)
|
|
from cmdstanpy.utils import check_sampler_csv, get_logger, scan_config
|
|
|
|
from .gq import CmdStanGQ
|
|
from .laplace import CmdStanLaplace
|
|
from .mcmc import CmdStanMCMC
|
|
from .metadata import InferenceMetadata
|
|
from .mle import CmdStanMLE
|
|
from .pathfinder import CmdStanPathfinder
|
|
from .runset import RunSet
|
|
from .vb import CmdStanVB
|
|
|
|
__all__ = [
|
|
"RunSet",
|
|
"InferenceMetadata",
|
|
"CmdStanMCMC",
|
|
"CmdStanMLE",
|
|
"CmdStanVB",
|
|
"CmdStanGQ",
|
|
"CmdStanLaplace",
|
|
"CmdStanPathfinder",
|
|
]
|
|
|
|
|
|
# CmdStan inference methods whose Stan CSV output ``from_csv`` can load.
_VALID_METHODS = ('sample', 'optimize', 'variational', 'laplace', 'pathfinder')


def _find_csv_files(path: Union[str, List[str], os.PathLike]) -> List[str]:
    """
    Resolve a path specification into a list of candidate file names.

    :param path: a list of file names, a glob pattern containing ``*``,
        a directory name, or a single file name
    :return: the file names selected by ``path`` (possibly empty)
    :raises ValueError: if the specification cannot be resolved
    """
    if isinstance(path, list):
        # Copy so that later changes to the caller's list cannot leak into
        # the RunSet that ends up holding these file names.
        return list(path)
    if isinstance(path, str) and '*' in path:
        parent = os.path.dirname(path)
        # os.path.dirname returns '' (never None) for a bare pattern such as
        # '*.csv'; that means "current working directory", so there is no
        # directory component to validate.
        if parent and not os.path.isdir(parent):
            raise ValueError(
                'Invalid path specification, {} '
                ' unknown directory: {}'.format(path, parent)
            )
        return glob.glob(path)
    if isinstance(path, (str, os.PathLike)):
        if os.path.isdir(path):
            return [
                os.path.join(path, file)
                for file in os.listdir(path)
                if os.path.splitext(file)[1] == ".csv"
            ]
        if os.path.exists(path):
            return [str(path)]
    raise ValueError('Invalid path specification: {}'.format(path))


def _build_runset(
    config_dict: Dict[str, Any],
    csvfiles: List[str],
    method_args: Any,
    chains: Optional[int] = None,
) -> "RunSet":
    """
    Create a RunSet that points at existing CSV files, marked as successful.

    :param config_dict: configuration scanned from the first CSV file
    :param csvfiles: the Stan CSV files that make up the run
    :param method_args: the method-specific argument object
    :param chains: number of chains; when None the RunSet default is used
        and no chain ids are passed
    :return: the populated RunSet
    """
    chain_ids = None if chains is None else [x + 1 for x in range(chains)]
    cmdstan_args = CmdStanArgs(
        model_name=config_dict['model'],
        model_exe=config_dict['model'],
        chain_ids=chain_ids,
        method_args=method_args,
    )
    if chains is None:
        runset = RunSet(args=cmdstan_args)
    else:
        runset = RunSet(args=cmdstan_args, chains=chains)
    runset._csv_files = csvfiles
    # Nothing was executed here, so record a zero return code for every
    # slot the RunSet tracks.
    for i in range(len(runset._retcodes)):
        runset._set_retcode(i, 0)
    return runset


def from_csv(
    path: Union[str, List[str], os.PathLike, None] = None,
    method: Optional[str] = None,
) -> Union[
    "CmdStanMCMC",
    "CmdStanMLE",
    "CmdStanVB",
    "CmdStanPathfinder",
    "CmdStanLaplace",
    None,
]:
    """
    Instantiate a CmdStan object from the Stan CSV files from a CmdStan run.
    CSV files are specified from either a list of Stan CSV files or a single
    filepath which can be either a directory name, a Stan CSV filename, or
    a pathname pattern (i.e., a Python glob). The optional argument 'method'
    checks that the CSV files were produced by that method.
    Stan CSV files from CmdStan methods 'sample', 'optimize', 'variational',
    'laplace', and 'pathfinder' result in objects of class CmdStanMCMC,
    CmdStanMLE, CmdStanVB, CmdStanLaplace, and CmdStanPathfinder,
    respectively.

    :param path: list of Stan CSV files, directory name, Stan CSV filename,
        or glob pattern
    :param method: method name (optional)

    :return: a CmdStanMCMC, CmdStanMLE, CmdStanVB, CmdStanLaplace, or
        CmdStanPathfinder object, or None if the files were produced by a
        method that is not handled here
    :raises ValueError: if the arguments are invalid, no CSV files are
        found, or the files cannot be read or are not the expected Stan CSV
        output
    """
    if path is None:
        raise ValueError('Must specify path to Stan CSV files.')
    if method is not None and method not in _VALID_METHODS:
        raise ValueError(
            'Bad method argument {}, must be one of: {}'.format(
                method, ', '.join('"{}"'.format(m) for m in _VALID_METHODS)
            )
        )

    csvfiles = _find_csv_files(path)
    if len(csvfiles) == 0:
        raise ValueError('No CSV files found in directory {}'.format(path))
    for file in csvfiles:
        if not (os.path.exists(file) and os.path.splitext(file)[1] == ".csv"):
            raise ValueError(
                'Bad CSV file path spec,'
                ' includes non-csv file: {}'.format(file)
            )

    # The configuration is scanned from the first file only.
    config_dict: Dict[str, Any] = {}
    try:
        with open(csvfiles[0], 'r') as fd:
            scan_config(fd, config_dict, 0)
    except OSError as e:
        raise ValueError('Cannot read CSV file: {}'.format(csvfiles[0])) from e
    if 'model' not in config_dict or 'method' not in config_dict:
        raise ValueError("File {} is not a Stan CSV file.".format(csvfiles[0]))
    if method is not None and method != config_dict['method']:
        raise ValueError(
            'Expecting Stan CSV output files from method {}, '
            ' found outputs from method {}'.format(
                method, config_dict['method']
            )
        )

    try:
        if config_dict['method'] == 'sample':
            sampler_kwargs = {
                'iter_sampling': config_dict['num_samples'],
                'iter_warmup': config_dict['num_warmup'],
                'thin': config_dict['thin'],
                'save_warmup': config_dict['save_warmup'],
            }
            sampler_args = SamplerArgs(**sampler_kwargs)
            # bugfix 425, check for fixed_params output: if the regular
            # check rejects the file, retry it as fixed_param output.
            try:
                check_sampler_csv(csvfiles[0], **sampler_kwargs)
            except ValueError:
                try:
                    check_sampler_csv(
                        csvfiles[0], is_fixed_param=True, **sampler_kwargs
                    )
                    sampler_args = SamplerArgs(
                        fixed_param=True, **sampler_kwargs
                    )
                except ValueError as e:
                    raise ValueError(
                        'Invalid or corrupt Stan CSV output file: {}'.format(
                            csvfiles[0]
                        )
                    ) from e

            # One chain per CSV file.
            runset = _build_runset(
                config_dict, csvfiles, sampler_args, chains=len(csvfiles)
            )
            fit = CmdStanMCMC(runset)
            # Call draws() before handing the object back, as the original
            # code did (the return value is not needed here).
            fit.draws()
            return fit
        elif config_dict['method'] == 'optimize':
            if 'algorithm' not in config_dict:
                raise ValueError(
                    "Cannot find optimization algorithm"
                    " in file {}.".format(csvfiles[0])
                )
            optimize_args = OptimizeArgs(
                algorithm=config_dict['algorithm'],
                save_iterations=config_dict['save_iterations'],
                jacobian=config_dict.get('jacobian', 0),
            )
            return CmdStanMLE(
                _build_runset(config_dict, csvfiles, optimize_args)
            )
        elif config_dict['method'] == 'variational':
            if 'algorithm' not in config_dict:
                raise ValueError(
                    "Cannot find variational algorithm"
                    " in file {}.".format(csvfiles[0])
                )
            variational_args = VariationalArgs(
                algorithm=config_dict['algorithm'],
                iter=config_dict['iter'],
                grad_samples=config_dict['grad_samples'],
                elbo_samples=config_dict['elbo_samples'],
                eta=config_dict['eta'],
                tol_rel_obj=config_dict['tol_rel_obj'],
                eval_elbo=config_dict['eval_elbo'],
                output_samples=config_dict['output_samples'],
            )
            return CmdStanVB(
                _build_runset(config_dict, csvfiles, variational_args)
            )
        elif config_dict['method'] == 'laplace':
            laplace_args = LaplaceArgs(
                mode=config_dict['mode'],
                draws=config_dict['draws'],
                jacobian=config_dict['jacobian'],
            )
            runset = _build_runset(config_dict, csvfiles, laplace_args)
            # The 'mode' config entry is itself passed to from_csv as a
            # path and loaded as an 'optimize' result.
            mode: CmdStanMLE = from_csv(
                config_dict['mode'],
                method='optimize',
            )  # type: ignore
            return CmdStanLaplace(runset, mode=mode)
        elif config_dict['method'] == 'pathfinder':
            pathfinder_args = PathfinderArgs(
                num_draws=config_dict['num_draws'],
                num_paths=config_dict['num_paths'],
            )
            return CmdStanPathfinder(
                _build_runset(config_dict, csvfiles, pathfinder_args)
            )
        else:
            get_logger().info(
                'Unable to process CSV output files from method %s.',
                (config_dict['method']),
            )
            return None
    except OSError as e:
        raise ValueError(
            'An error occurred processing the CSV files:\n\t{}'.format(str(e))
        ) from e
|