some new features
This commit is contained in:
269
.venv/lib/python3.12/site-packages/cmdstanpy/stanfit/__init__.py
Normal file
269
.venv/lib/python3.12/site-packages/cmdstanpy/stanfit/__init__.py
Normal file
@ -0,0 +1,269 @@
|
||||
"""Container objects for results of CmdStan run(s)."""
|
||||
|
||||
import glob
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from cmdstanpy.cmdstan_args import (
|
||||
CmdStanArgs,
|
||||
LaplaceArgs,
|
||||
OptimizeArgs,
|
||||
PathfinderArgs,
|
||||
SamplerArgs,
|
||||
VariationalArgs,
|
||||
)
|
||||
from cmdstanpy.utils import check_sampler_csv, get_logger, scan_config
|
||||
|
||||
from .gq import CmdStanGQ
|
||||
from .laplace import CmdStanLaplace
|
||||
from .mcmc import CmdStanMCMC
|
||||
from .metadata import InferenceMetadata
|
||||
from .mle import CmdStanMLE
|
||||
from .pathfinder import CmdStanPathfinder
|
||||
from .runset import RunSet
|
||||
from .vb import CmdStanVB
|
||||
|
||||
__all__ = [
|
||||
"RunSet",
|
||||
"InferenceMetadata",
|
||||
"CmdStanMCMC",
|
||||
"CmdStanMLE",
|
||||
"CmdStanVB",
|
||||
"CmdStanGQ",
|
||||
"CmdStanLaplace",
|
||||
"CmdStanPathfinder",
|
||||
]
|
||||
|
||||
|
||||
def from_csv(
|
||||
path: Union[str, List[str], os.PathLike, None] = None,
|
||||
method: Optional[str] = None,
|
||||
) -> Union[
|
||||
CmdStanMCMC, CmdStanMLE, CmdStanVB, CmdStanPathfinder, CmdStanLaplace, None
|
||||
]:
|
||||
"""
|
||||
Instantiate a CmdStan object from a the Stan CSV files from a CmdStan run.
|
||||
CSV files are specified from either a list of Stan CSV files or a single
|
||||
filepath which can be either a directory name, a Stan CSV filename, or
|
||||
a pathname pattern (i.e., a Python glob). The optional argument 'method'
|
||||
checks that the CSV files were produced by that method.
|
||||
Stan CSV files from CmdStan methods 'sample', 'optimize', and 'variational'
|
||||
result in objects of class CmdStanMCMC, CmdStanMLE, and CmdStanVB,
|
||||
respectively.
|
||||
|
||||
:param path: directory path
|
||||
:param method: method name (optional)
|
||||
|
||||
:return: either a CmdStanMCMC, CmdStanMLE, or CmdStanVB object
|
||||
"""
|
||||
if path is None:
|
||||
raise ValueError('Must specify path to Stan CSV files.')
|
||||
if method is not None and method not in [
|
||||
'sample',
|
||||
'optimize',
|
||||
'variational',
|
||||
'laplace',
|
||||
'pathfinder',
|
||||
]:
|
||||
raise ValueError(
|
||||
'Bad method argument {}, must be one of: '
|
||||
'"sample", "optimize", "variational"'.format(method)
|
||||
)
|
||||
|
||||
csvfiles = []
|
||||
if isinstance(path, list):
|
||||
csvfiles = path
|
||||
elif isinstance(path, str) and '*' in path:
|
||||
splits = os.path.split(path)
|
||||
if splits[0] is not None:
|
||||
if not (os.path.exists(splits[0]) and os.path.isdir(splits[0])):
|
||||
raise ValueError(
|
||||
'Invalid path specification, {} '
|
||||
' unknown directory: {}'.format(path, splits[0])
|
||||
)
|
||||
csvfiles = glob.glob(path)
|
||||
elif isinstance(path, (str, os.PathLike)):
|
||||
if os.path.exists(path) and os.path.isdir(path):
|
||||
for file in os.listdir(path):
|
||||
if os.path.splitext(file)[1] == ".csv":
|
||||
csvfiles.append(os.path.join(path, file))
|
||||
elif os.path.exists(path):
|
||||
csvfiles.append(str(path))
|
||||
else:
|
||||
raise ValueError('Invalid path specification: {}'.format(path))
|
||||
else:
|
||||
raise ValueError('Invalid path specification: {}'.format(path))
|
||||
|
||||
if len(csvfiles) == 0:
|
||||
raise ValueError('No CSV files found in directory {}'.format(path))
|
||||
for file in csvfiles:
|
||||
if not (os.path.exists(file) and os.path.splitext(file)[1] == ".csv"):
|
||||
raise ValueError(
|
||||
'Bad CSV file path spec,'
|
||||
' includes non-csv file: {}'.format(file)
|
||||
)
|
||||
|
||||
config_dict: Dict[str, Any] = {}
|
||||
try:
|
||||
with open(csvfiles[0], 'r') as fd:
|
||||
scan_config(fd, config_dict, 0)
|
||||
except (IOError, OSError, PermissionError) as e:
|
||||
raise ValueError('Cannot read CSV file: {}'.format(csvfiles[0])) from e
|
||||
if 'model' not in config_dict or 'method' not in config_dict:
|
||||
raise ValueError("File {} is not a Stan CSV file.".format(csvfiles[0]))
|
||||
if method is not None and method != config_dict['method']:
|
||||
raise ValueError(
|
||||
'Expecting Stan CSV output files from method {}, '
|
||||
' found outputs from method {}'.format(
|
||||
method, config_dict['method']
|
||||
)
|
||||
)
|
||||
try:
|
||||
if config_dict['method'] == 'sample':
|
||||
chains = len(csvfiles)
|
||||
sampler_args = SamplerArgs(
|
||||
iter_sampling=config_dict['num_samples'],
|
||||
iter_warmup=config_dict['num_warmup'],
|
||||
thin=config_dict['thin'],
|
||||
save_warmup=config_dict['save_warmup'],
|
||||
)
|
||||
# bugfix 425, check for fixed_params output
|
||||
try:
|
||||
check_sampler_csv(
|
||||
csvfiles[0],
|
||||
iter_sampling=config_dict['num_samples'],
|
||||
iter_warmup=config_dict['num_warmup'],
|
||||
thin=config_dict['thin'],
|
||||
save_warmup=config_dict['save_warmup'],
|
||||
)
|
||||
except ValueError:
|
||||
try:
|
||||
check_sampler_csv(
|
||||
csvfiles[0],
|
||||
is_fixed_param=True,
|
||||
iter_sampling=config_dict['num_samples'],
|
||||
iter_warmup=config_dict['num_warmup'],
|
||||
thin=config_dict['thin'],
|
||||
save_warmup=config_dict['save_warmup'],
|
||||
)
|
||||
sampler_args = SamplerArgs(
|
||||
iter_sampling=config_dict['num_samples'],
|
||||
iter_warmup=config_dict['num_warmup'],
|
||||
thin=config_dict['thin'],
|
||||
save_warmup=config_dict['save_warmup'],
|
||||
fixed_param=True,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
'Invalid or corrupt Stan CSV output file, '
|
||||
) from e
|
||||
|
||||
cmdstan_args = CmdStanArgs(
|
||||
model_name=config_dict['model'],
|
||||
model_exe=config_dict['model'],
|
||||
chain_ids=[x + 1 for x in range(chains)],
|
||||
method_args=sampler_args,
|
||||
)
|
||||
runset = RunSet(args=cmdstan_args, chains=chains)
|
||||
runset._csv_files = csvfiles
|
||||
for i in range(len(runset._retcodes)):
|
||||
runset._set_retcode(i, 0)
|
||||
fit = CmdStanMCMC(runset)
|
||||
fit.draws()
|
||||
return fit
|
||||
elif config_dict['method'] == 'optimize':
|
||||
if 'algorithm' not in config_dict:
|
||||
raise ValueError(
|
||||
"Cannot find optimization algorithm"
|
||||
" in file {}.".format(csvfiles[0])
|
||||
)
|
||||
optimize_args = OptimizeArgs(
|
||||
algorithm=config_dict['algorithm'],
|
||||
save_iterations=config_dict['save_iterations'],
|
||||
jacobian=config_dict.get('jacobian', 0),
|
||||
)
|
||||
cmdstan_args = CmdStanArgs(
|
||||
model_name=config_dict['model'],
|
||||
model_exe=config_dict['model'],
|
||||
chain_ids=None,
|
||||
method_args=optimize_args,
|
||||
)
|
||||
runset = RunSet(args=cmdstan_args)
|
||||
runset._csv_files = csvfiles
|
||||
for i in range(len(runset._retcodes)):
|
||||
runset._set_retcode(i, 0)
|
||||
return CmdStanMLE(runset)
|
||||
elif config_dict['method'] == 'variational':
|
||||
if 'algorithm' not in config_dict:
|
||||
raise ValueError(
|
||||
"Cannot find variational algorithm"
|
||||
" in file {}.".format(csvfiles[0])
|
||||
)
|
||||
variational_args = VariationalArgs(
|
||||
algorithm=config_dict['algorithm'],
|
||||
iter=config_dict['iter'],
|
||||
grad_samples=config_dict['grad_samples'],
|
||||
elbo_samples=config_dict['elbo_samples'],
|
||||
eta=config_dict['eta'],
|
||||
tol_rel_obj=config_dict['tol_rel_obj'],
|
||||
eval_elbo=config_dict['eval_elbo'],
|
||||
output_samples=config_dict['output_samples'],
|
||||
)
|
||||
cmdstan_args = CmdStanArgs(
|
||||
model_name=config_dict['model'],
|
||||
model_exe=config_dict['model'],
|
||||
chain_ids=None,
|
||||
method_args=variational_args,
|
||||
)
|
||||
runset = RunSet(args=cmdstan_args)
|
||||
runset._csv_files = csvfiles
|
||||
for i in range(len(runset._retcodes)):
|
||||
runset._set_retcode(i, 0)
|
||||
return CmdStanVB(runset)
|
||||
elif config_dict['method'] == 'laplace':
|
||||
laplace_args = LaplaceArgs(
|
||||
mode=config_dict['mode'],
|
||||
draws=config_dict['draws'],
|
||||
jacobian=config_dict['jacobian'],
|
||||
)
|
||||
cmdstan_args = CmdStanArgs(
|
||||
model_name=config_dict['model'],
|
||||
model_exe=config_dict['model'],
|
||||
chain_ids=None,
|
||||
method_args=laplace_args,
|
||||
)
|
||||
runset = RunSet(args=cmdstan_args)
|
||||
runset._csv_files = csvfiles
|
||||
for i in range(len(runset._retcodes)):
|
||||
runset._set_retcode(i, 0)
|
||||
mode: CmdStanMLE = from_csv(
|
||||
config_dict['mode'],
|
||||
method='optimize',
|
||||
) # type: ignore
|
||||
return CmdStanLaplace(runset, mode=mode)
|
||||
elif config_dict['method'] == 'pathfinder':
|
||||
pathfinder_args = PathfinderArgs(
|
||||
num_draws=config_dict['num_draws'],
|
||||
num_paths=config_dict['num_paths'],
|
||||
)
|
||||
cmdstan_args = CmdStanArgs(
|
||||
model_name=config_dict['model'],
|
||||
model_exe=config_dict['model'],
|
||||
chain_ids=None,
|
||||
method_args=pathfinder_args,
|
||||
)
|
||||
runset = RunSet(args=cmdstan_args)
|
||||
runset._csv_files = csvfiles
|
||||
for i in range(len(runset._retcodes)):
|
||||
runset._set_retcode(i, 0)
|
||||
return CmdStanPathfinder(runset)
|
||||
else:
|
||||
get_logger().info(
|
||||
'Unable to process CSV output files from method %s.',
|
||||
(config_dict['method']),
|
||||
)
|
||||
return None
|
||||
except (IOError, OSError, PermissionError) as e:
|
||||
raise ValueError(
|
||||
'An error occurred processing the CSV files:\n\t{}'.format(str(e))
|
||||
) from e
|
||||
Reference in New Issue
Block a user