"""Container objects for results of CmdStan run(s)."""

import glob
import os
from typing import Any, Dict, List, Optional, Union

from cmdstanpy.cmdstan_args import (
    CmdStanArgs,
    LaplaceArgs,
    OptimizeArgs,
    PathfinderArgs,
    SamplerArgs,
    VariationalArgs,
)
from cmdstanpy.utils import check_sampler_csv, get_logger, scan_config

from .gq import CmdStanGQ
from .laplace import CmdStanLaplace
from .mcmc import CmdStanMCMC
from .metadata import InferenceMetadata
from .mle import CmdStanMLE
from .pathfinder import CmdStanPathfinder
from .runset import RunSet
from .vb import CmdStanVB

__all__ = [
    "RunSet",
    "InferenceMetadata",
    "CmdStanMCMC",
    "CmdStanMLE",
    "CmdStanVB",
    "CmdStanGQ",
    "CmdStanLaplace",
    "CmdStanPathfinder",
]

# Values accepted for the ``method`` argument of :func:`from_csv`.
_SUPPORTED_METHODS = (
    'sample',
    'optimize',
    'variational',
    'laplace',
    'pathfinder',
)


def _find_csv_files(
    path: Union[str, List[str], os.PathLike],
) -> List[str]:
    """
    Resolve a path specification to a non-empty list of existing CSV files.

    :param path: a list of file names, a glob pattern (a ``str`` containing
        ``*``), a directory name, or a single file name
    :return: the list of CSV file paths; a list argument is returned as-is
    :raises ValueError: if the specification is of an unsupported type,
        names a missing file or directory, matches no files, or includes a
        file that does not exist or does not end in ``.csv``
    """
    csvfiles: List[str] = []
    if isinstance(path, list):
        csvfiles = path
    elif isinstance(path, str) and '*' in path:
        directory = os.path.split(path)[0]
        # os.path.split yields '' (never None) when the pattern has no
        # directory component, e.g. "*.csv"; such a pattern is relative to
        # the current directory and must not be rejected.
        if directory and not os.path.isdir(directory):
            raise ValueError(
                'Invalid path specification, {} '
                ' unknown directory: {}'.format(path, directory)
            )
        csvfiles = glob.glob(path)
    elif isinstance(path, (str, os.PathLike)):
        if os.path.isdir(path):
            csvfiles = [
                os.path.join(path, entry)
                for entry in os.listdir(path)
                if os.path.splitext(entry)[1] == ".csv"
            ]
        elif os.path.exists(path):
            csvfiles.append(str(path))
        else:
            raise ValueError('Invalid path specification: {}'.format(path))
    else:
        raise ValueError('Invalid path specification: {}'.format(path))

    if not csvfiles:
        raise ValueError('No CSV files found in directory {}'.format(path))
    for csvfile in csvfiles:
        if not (
            os.path.exists(csvfile)
            and os.path.splitext(csvfile)[1] == ".csv"
        ):
            raise ValueError(
                'Bad CSV file path spec,'
                ' includes non-csv file: {}'.format(csvfile)
            )
    return csvfiles


def _completed_runset(
    cmdstan_args: CmdStanArgs,
    csvfiles: List[str],
    chains: Optional[int] = None,
) -> RunSet:
    """
    Build a RunSet that points at existing CSV files.

    The RunSet's CSV file list is replaced by ``csvfiles`` and return code 0
    is recorded for every entry in its return-code list.

    :param cmdstan_args: arguments reconstructed from the CSV header
    :param csvfiles: the Stan CSV files backing the run
    :param chains: number of chains; when ``None`` the RunSet default is used
    :return: the populated RunSet
    """
    if chains is None:
        runset = RunSet(args=cmdstan_args)
    else:
        runset = RunSet(args=cmdstan_args, chains=chains)
    runset._csv_files = csvfiles
    for i in range(len(runset._retcodes)):
        runset._set_retcode(i, 0)
    return runset


def from_csv(
    path: Union[str, List[str], os.PathLike, None] = None,
    method: Optional[str] = None,
) -> Union[
    CmdStanMCMC, CmdStanMLE, CmdStanVB, CmdStanPathfinder, CmdStanLaplace, None
]:
    """
    Instantiate a CmdStan object from the Stan CSV files from a CmdStan run.
    CSV files are specified from either a list of Stan CSV files or a single
    filepath which can be either a directory name, a Stan CSV filename, or a
    pathname pattern (i.e., a Python glob).  The optional argument 'method'
    checks that the CSV files were produced by that method.
    Stan CSV files from CmdStan methods 'sample', 'optimize', 'variational',
    'laplace', and 'pathfinder' result in objects of class CmdStanMCMC,
    CmdStanMLE, CmdStanVB, CmdStanLaplace, and CmdStanPathfinder,
    respectively.

    The method and its arguments are read from the header of the first
    CSV file only.

    :param path: list of Stan CSV files, directory name, Stan CSV filename,
        or glob pattern
    :param method: method name (optional)
    :return: a CmdStanMCMC, CmdStanMLE, CmdStanVB, CmdStanLaplace, or
        CmdStanPathfinder object, or None if the CSV files were produced by
        a method this function does not handle
    :raises ValueError: if ``path`` or ``method`` is invalid, the files
        cannot be read or are not Stan CSV files, or the files were not
        produced by the requested method
    """
    if path is None:
        raise ValueError('Must specify path to Stan CSV files.')
    if method is not None and method not in _SUPPORTED_METHODS:
        # List every accepted method so the message matches the check.
        raise ValueError(
            'Bad method argument {}, must be one of: {}'.format(
                method,
                ', '.join('"{}"'.format(m) for m in _SUPPORTED_METHODS),
            )
        )

    csvfiles = _find_csv_files(path)

    config_dict: Dict[str, Any] = {}
    try:
        with open(csvfiles[0], 'r') as fd:
            scan_config(fd, config_dict, 0)
    except OSError as e:  # IOError and PermissionError are OSError
        raise ValueError('Cannot read CSV file: {}'.format(csvfiles[0])) from e
    if 'model' not in config_dict or 'method' not in config_dict:
        raise ValueError("File {} is not a Stan CSV file.".format(csvfiles[0]))
    if method is not None and method != config_dict['method']:
        raise ValueError(
            'Expecting Stan CSV output files from method {}, '
            ' found outputs from method {}'.format(
                method, config_dict['method']
            )
        )
    try:
        if config_dict['method'] == 'sample':
            chains = len(csvfiles)
            sampler_kwargs = {
                'iter_sampling': config_dict['num_samples'],
                'iter_warmup': config_dict['num_warmup'],
                'thin': config_dict['thin'],
                'save_warmup': config_dict['save_warmup'],
            }
            sampler_args = SamplerArgs(**sampler_kwargs)
            # bugfix 425, check for fixed_params output
            try:
                check_sampler_csv(csvfiles[0], **sampler_kwargs)
            except ValueError:
                # Retry the check treating the file as fixed_param output.
                try:
                    check_sampler_csv(
                        csvfiles[0],
                        is_fixed_param=True,
                        **sampler_kwargs,
                    )
                    sampler_args = SamplerArgs(
                        fixed_param=True, **sampler_kwargs
                    )
                except ValueError as e:
                    raise ValueError(
                        'Invalid or corrupt Stan CSV output file, '
                    ) from e

            cmdstan_args = CmdStanArgs(
                model_name=config_dict['model'],
                model_exe=config_dict['model'],
                chain_ids=[x + 1 for x in range(chains)],
                method_args=sampler_args,
            )
            runset = _completed_runset(cmdstan_args, csvfiles, chains=chains)
            fit = CmdStanMCMC(runset)
            # Call draws() before returning so that problems with the
            # draws surface here rather than on first use.
            fit.draws()
            return fit
        elif config_dict['method'] == 'optimize':
            if 'algorithm' not in config_dict:
                raise ValueError(
                    "Cannot find optimization algorithm"
                    " in file {}.".format(csvfiles[0])
                )
            optimize_args = OptimizeArgs(
                algorithm=config_dict['algorithm'],
                save_iterations=config_dict['save_iterations'],
                jacobian=config_dict.get('jacobian', 0),
            )
            cmdstan_args = CmdStanArgs(
                model_name=config_dict['model'],
                model_exe=config_dict['model'],
                chain_ids=None,
                method_args=optimize_args,
            )
            return CmdStanMLE(_completed_runset(cmdstan_args, csvfiles))
        elif config_dict['method'] == 'variational':
            if 'algorithm' not in config_dict:
                raise ValueError(
                    "Cannot find variational algorithm"
                    " in file {}.".format(csvfiles[0])
                )
            variational_args = VariationalArgs(
                algorithm=config_dict['algorithm'],
                iter=config_dict['iter'],
                grad_samples=config_dict['grad_samples'],
                elbo_samples=config_dict['elbo_samples'],
                eta=config_dict['eta'],
                tol_rel_obj=config_dict['tol_rel_obj'],
                eval_elbo=config_dict['eval_elbo'],
                output_samples=config_dict['output_samples'],
            )
            cmdstan_args = CmdStanArgs(
                model_name=config_dict['model'],
                model_exe=config_dict['model'],
                chain_ids=None,
                method_args=variational_args,
            )
            return CmdStanVB(_completed_runset(cmdstan_args, csvfiles))
        elif config_dict['method'] == 'laplace':
            laplace_args = LaplaceArgs(
                mode=config_dict['mode'],
                draws=config_dict['draws'],
                jacobian=config_dict['jacobian'],
            )
            cmdstan_args = CmdStanArgs(
                model_name=config_dict['model'],
                model_exe=config_dict['model'],
                chain_ids=None,
                method_args=laplace_args,
            )
            runset = _completed_runset(cmdstan_args, csvfiles)
            # The 'mode' header entry is itself passed to from_csv as the
            # path of the optimization output the approximation is built on.
            mode: CmdStanMLE = from_csv(
                config_dict['mode'],
                method='optimize',
            )  # type: ignore
            return CmdStanLaplace(runset, mode=mode)
        elif config_dict['method'] == 'pathfinder':
            pathfinder_args = PathfinderArgs(
                num_draws=config_dict['num_draws'],
                num_paths=config_dict['num_paths'],
            )
            cmdstan_args = CmdStanArgs(
                model_name=config_dict['model'],
                model_exe=config_dict['model'],
                chain_ids=None,
                method_args=pathfinder_args,
            )
            return CmdStanPathfinder(_completed_runset(cmdstan_args, csvfiles))
        else:
            get_logger().info(
                'Unable to process CSV output files from method %s.',
                (config_dict['method']),
            )
            return None
    except OSError as e:  # IOError and PermissionError are OSError
        raise ValueError(
            'An error occurred processing the CSV files:\n\t{}'.format(str(e))
        ) from e