some new features
This commit is contained in:
486
.venv/lib/python3.12/site-packages/cmdstanpy/utils/stancsv.py
Normal file
486
.venv/lib/python3.12/site-packages/cmdstanpy/utils/stancsv.py
Normal file
@ -0,0 +1,486 @@
|
||||
"""
|
||||
Utility functions for reading the Stan CSV format
|
||||
"""
|
||||
import json
|
||||
import math
|
||||
import re
|
||||
from typing import Any, Dict, List, MutableMapping, Optional, TextIO, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from cmdstanpy import _CMDSTAN_SAMPLING, _CMDSTAN_THIN, _CMDSTAN_WARMUP
|
||||
|
||||
|
||||
def check_sampler_csv(
|
||||
path: str,
|
||||
is_fixed_param: bool = False,
|
||||
iter_sampling: Optional[int] = None,
|
||||
iter_warmup: Optional[int] = None,
|
||||
save_warmup: bool = False,
|
||||
thin: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Capture essential config, shape from stan_csv file."""
|
||||
meta = scan_sampler_csv(path, is_fixed_param)
|
||||
if thin is None:
|
||||
thin = _CMDSTAN_THIN
|
||||
elif thin > _CMDSTAN_THIN:
|
||||
if 'thin' not in meta:
|
||||
raise ValueError(
|
||||
'bad Stan CSV file {}, '
|
||||
'config error, expected thin = {}'.format(path, thin)
|
||||
)
|
||||
if meta['thin'] != thin:
|
||||
raise ValueError(
|
||||
'bad Stan CSV file {}, '
|
||||
'config error, expected thin = {}, found {}'.format(
|
||||
path, thin, meta['thin']
|
||||
)
|
||||
)
|
||||
draws_sampling = iter_sampling
|
||||
if draws_sampling is None:
|
||||
draws_sampling = _CMDSTAN_SAMPLING
|
||||
draws_warmup = iter_warmup
|
||||
if draws_warmup is None:
|
||||
draws_warmup = _CMDSTAN_WARMUP
|
||||
draws_warmup = int(math.ceil(draws_warmup / thin))
|
||||
draws_sampling = int(math.ceil(draws_sampling / thin))
|
||||
if meta['draws_sampling'] != draws_sampling:
|
||||
raise ValueError(
|
||||
'bad Stan CSV file {}, expected {} draws, found {}'.format(
|
||||
path, draws_sampling, meta['draws_sampling']
|
||||
)
|
||||
)
|
||||
if save_warmup:
|
||||
if not ('save_warmup' in meta and meta['save_warmup'] in (1, 'true')):
|
||||
raise ValueError(
|
||||
'bad Stan CSV file {}, '
|
||||
'config error, expected save_warmup = 1'.format(path)
|
||||
)
|
||||
if meta['draws_warmup'] != draws_warmup:
|
||||
raise ValueError(
|
||||
'bad Stan CSV file {}, '
|
||||
'expected {} warmup draws, found {}'.format(
|
||||
path, draws_warmup, meta['draws_warmup']
|
||||
)
|
||||
)
|
||||
return meta
|
||||
|
||||
|
||||
def scan_sampler_csv(path: str, is_fixed_param: bool = False) -> Dict[str, Any]:
|
||||
"""Process sampler stan_csv output file line by line."""
|
||||
dict: Dict[str, Any] = {}
|
||||
lineno = 0
|
||||
with open(path, 'r') as fd:
|
||||
try:
|
||||
lineno = scan_config(fd, dict, lineno)
|
||||
lineno = scan_column_names(fd, dict, lineno)
|
||||
if not is_fixed_param:
|
||||
lineno = scan_warmup_iters(fd, dict, lineno)
|
||||
lineno = scan_hmc_params(fd, dict, lineno)
|
||||
lineno = scan_sampling_iters(fd, dict, lineno, is_fixed_param)
|
||||
except ValueError as e:
|
||||
raise ValueError("Error in reading csv file: " + path) from e
|
||||
return dict
|
||||
|
||||
|
||||
def scan_optimize_csv(path: str, save_iters: bool = False) -> Dict[str, Any]:
|
||||
"""Process optimizer stan_csv output file line by line."""
|
||||
dict: Dict[str, Any] = {}
|
||||
lineno = 0
|
||||
# scan to find config, header, num saved iters
|
||||
with open(path, 'r') as fd:
|
||||
lineno = scan_config(fd, dict, lineno)
|
||||
lineno = scan_column_names(fd, dict, lineno)
|
||||
iters = 0
|
||||
for line in fd:
|
||||
iters += 1
|
||||
if save_iters:
|
||||
all_iters: np.ndarray = np.empty(
|
||||
(iters, len(dict['column_names'])), dtype=float, order='F'
|
||||
)
|
||||
# rescan to capture estimates
|
||||
with open(path, 'r') as fd:
|
||||
for i in range(lineno):
|
||||
fd.readline()
|
||||
for i in range(iters):
|
||||
line = fd.readline().strip()
|
||||
if len(line) < 1:
|
||||
raise ValueError(
|
||||
'cannot parse CSV file {}, error at line {}'.format(
|
||||
path, lineno + i
|
||||
)
|
||||
)
|
||||
xs = line.split(',')
|
||||
if save_iters:
|
||||
all_iters[i, :] = [float(x) for x in xs]
|
||||
if i == iters - 1:
|
||||
mle: np.ndarray = np.array(xs, dtype=float)
|
||||
# pylint: disable=possibly-used-before-assignment
|
||||
dict['mle'] = mle
|
||||
if save_iters:
|
||||
dict['all_iters'] = all_iters
|
||||
return dict
|
||||
|
||||
|
||||
def scan_generic_csv(path: str) -> Dict[str, Any]:
|
||||
"""Process laplace stan_csv output file line by line."""
|
||||
dict: Dict[str, Any] = {}
|
||||
lineno = 0
|
||||
with open(path, 'r') as fd:
|
||||
lineno = scan_config(fd, dict, lineno)
|
||||
lineno = scan_column_names(fd, dict, lineno)
|
||||
return dict
|
||||
|
||||
|
||||
def scan_variational_csv(path: str) -> Dict[str, Any]:
|
||||
"""Process advi stan_csv output file line by line."""
|
||||
dict: Dict[str, Any] = {}
|
||||
lineno = 0
|
||||
with open(path, 'r') as fd:
|
||||
lineno = scan_config(fd, dict, lineno)
|
||||
lineno = scan_column_names(fd, dict, lineno)
|
||||
line = fd.readline().lstrip(' #\t').rstrip()
|
||||
lineno += 1
|
||||
if line.startswith('Stepsize adaptation complete.'):
|
||||
line = fd.readline().lstrip(' #\t\n')
|
||||
lineno += 1
|
||||
if not line.startswith('eta'):
|
||||
raise ValueError(
|
||||
'line {}: expecting eta, found:\n\t "{}"'.format(
|
||||
lineno, line
|
||||
)
|
||||
)
|
||||
_, eta = line.split('=')
|
||||
dict['eta'] = float(eta)
|
||||
line = fd.readline().lstrip(' #\t\n')
|
||||
lineno += 1
|
||||
xs = line.split(',')
|
||||
variational_mean = [float(x) for x in xs]
|
||||
dict['variational_mean'] = np.array(variational_mean)
|
||||
dict['variational_sample'] = pd.read_csv(
|
||||
path,
|
||||
comment='#',
|
||||
skiprows=lineno,
|
||||
header=None,
|
||||
float_precision='high',
|
||||
).to_numpy()
|
||||
return dict
|
||||
|
||||
|
||||
def scan_config(fd: TextIO, config_dict: Dict[str, Any], lineno: int) -> int:
|
||||
"""
|
||||
Scan initial stan_csv file comments lines and
|
||||
save non-default configuration information to config_dict.
|
||||
"""
|
||||
cur_pos = fd.tell()
|
||||
line = fd.readline().strip()
|
||||
while len(line) > 0 and line.startswith('#'):
|
||||
lineno += 1
|
||||
if line.endswith('(Default)'):
|
||||
line = line.replace('(Default)', '')
|
||||
line = line.lstrip(' #\t')
|
||||
key_val = line.split('=')
|
||||
if len(key_val) == 2:
|
||||
if key_val[0].strip() == 'file' and not key_val[1].endswith('csv'):
|
||||
config_dict['data_file'] = key_val[1].strip()
|
||||
elif key_val[0].strip() != 'file':
|
||||
raw_val = key_val[1].strip()
|
||||
val: Union[int, float, str]
|
||||
try:
|
||||
val = int(raw_val)
|
||||
except ValueError:
|
||||
try:
|
||||
val = float(raw_val)
|
||||
except ValueError:
|
||||
if raw_val == "true":
|
||||
val = 1
|
||||
elif raw_val == "false":
|
||||
val = 0
|
||||
else:
|
||||
val = raw_val
|
||||
config_dict[key_val[0].strip()] = val
|
||||
cur_pos = fd.tell()
|
||||
line = fd.readline().strip()
|
||||
fd.seek(cur_pos)
|
||||
return lineno
|
||||
|
||||
|
||||
def scan_warmup_iters(
|
||||
fd: TextIO, config_dict: Dict[str, Any], lineno: int
|
||||
) -> int:
|
||||
"""
|
||||
Check warmup iterations, if any.
|
||||
"""
|
||||
if 'save_warmup' not in config_dict:
|
||||
return lineno
|
||||
cur_pos = fd.tell()
|
||||
line = fd.readline().strip()
|
||||
draws_found = 0
|
||||
while len(line) > 0 and not line.startswith('#'):
|
||||
lineno += 1
|
||||
draws_found += 1
|
||||
cur_pos = fd.tell()
|
||||
line = fd.readline().strip()
|
||||
fd.seek(cur_pos)
|
||||
config_dict['draws_warmup'] = draws_found
|
||||
return lineno
|
||||
|
||||
|
||||
def scan_column_names(
|
||||
fd: TextIO, config_dict: MutableMapping[str, Any], lineno: int
|
||||
) -> int:
|
||||
"""
|
||||
Process columns header, add to config_dict as 'column_names'
|
||||
"""
|
||||
line = fd.readline().strip()
|
||||
lineno += 1
|
||||
config_dict['raw_header'] = line.strip()
|
||||
names = line.split(',')
|
||||
config_dict['column_names'] = tuple(munge_varnames(names))
|
||||
return lineno
|
||||
|
||||
|
||||
def munge_varname(name: str) -> str:
|
||||
if '.' not in name and ':' not in name:
|
||||
return name
|
||||
|
||||
tuple_parts = name.split(':')
|
||||
for i, part in enumerate(tuple_parts):
|
||||
if '.' not in part:
|
||||
continue
|
||||
part = part.replace('.', '[', 1)
|
||||
part = part.replace('.', ',')
|
||||
part += ']'
|
||||
tuple_parts[i] = part
|
||||
|
||||
return '.'.join(tuple_parts)
|
||||
|
||||
|
||||
def munge_varnames(names: List[str]) -> List[str]:
|
||||
"""
|
||||
Change formatting for indices of container var elements
|
||||
from use of dot separator to array-like notation, e.g.,
|
||||
rewrite label ``y_forecast.2.4`` to ``y_forecast[2,4]``.
|
||||
"""
|
||||
if names is None:
|
||||
raise ValueError('missing argument "names"')
|
||||
return [munge_varname(name) for name in names]
|
||||
|
||||
|
||||
def scan_hmc_params(
|
||||
fd: TextIO, config_dict: Dict[str, Any], lineno: int
|
||||
) -> int:
|
||||
"""
|
||||
Scan step size, metric from stan_csv file comment lines.
|
||||
"""
|
||||
metric = config_dict['metric']
|
||||
line = fd.readline().strip()
|
||||
lineno += 1
|
||||
if not line == '# Adaptation terminated':
|
||||
raise ValueError(
|
||||
'line {}: expecting metric, found:\n\t "{}"'.format(lineno, line)
|
||||
)
|
||||
line = fd.readline().strip()
|
||||
lineno += 1
|
||||
label, step_size = line.split('=')
|
||||
if not label.startswith('# Step size'):
|
||||
raise ValueError(
|
||||
'line {}: expecting step size, '
|
||||
'found:\n\t "{}"'.format(lineno, line)
|
||||
)
|
||||
try:
|
||||
float(step_size.strip())
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
'line {}: invalid step size: {}'.format(lineno, step_size)
|
||||
) from e
|
||||
before_metric = fd.tell()
|
||||
line = fd.readline().strip()
|
||||
lineno += 1
|
||||
if metric == 'unit_e':
|
||||
if line.startswith("# No free parameters"):
|
||||
return lineno
|
||||
else:
|
||||
fd.seek(before_metric)
|
||||
return lineno - 1
|
||||
|
||||
if not (
|
||||
(
|
||||
metric == 'diag_e'
|
||||
and line == '# Diagonal elements of inverse mass matrix:'
|
||||
)
|
||||
or (
|
||||
metric == 'dense_e' and line == '# Elements of inverse mass matrix:'
|
||||
)
|
||||
):
|
||||
raise ValueError(
|
||||
'line {}: invalid or missing mass matrix '
|
||||
'specification'.format(lineno)
|
||||
)
|
||||
line = fd.readline().lstrip(' #\t')
|
||||
lineno += 1
|
||||
num_unconstrained_params = len(line.split(','))
|
||||
if metric == 'diag_e':
|
||||
return lineno
|
||||
else:
|
||||
for _ in range(1, num_unconstrained_params):
|
||||
line = fd.readline().lstrip(' #\t')
|
||||
lineno += 1
|
||||
if len(line.split(',')) != num_unconstrained_params:
|
||||
raise ValueError(
|
||||
'line {}: invalid or missing mass matrix '
|
||||
'specification'.format(lineno)
|
||||
)
|
||||
return lineno
|
||||
|
||||
|
||||
def scan_sampling_iters(
|
||||
fd: TextIO, config_dict: Dict[str, Any], lineno: int, is_fixed_param: bool
|
||||
) -> int:
|
||||
"""
|
||||
Parse sampling iteration, save number of iterations to config_dict.
|
||||
Also save number of divergences, max_treedepth hits
|
||||
"""
|
||||
draws_found = 0
|
||||
num_cols = len(config_dict['column_names'])
|
||||
if not is_fixed_param:
|
||||
idx_divergent = config_dict['column_names'].index('divergent__')
|
||||
idx_treedepth = config_dict['column_names'].index('treedepth__')
|
||||
max_treedepth = config_dict['max_depth']
|
||||
ct_divergences = 0
|
||||
ct_max_treedepth = 0
|
||||
|
||||
cur_pos = fd.tell()
|
||||
line = fd.readline().strip()
|
||||
while len(line) > 0 and not line.startswith('#'):
|
||||
lineno += 1
|
||||
draws_found += 1
|
||||
data = line.split(',')
|
||||
if len(data) != num_cols:
|
||||
raise ValueError(
|
||||
'line {}: bad draw, expecting {} items, found {}\n'.format(
|
||||
lineno, num_cols, len(line.split(','))
|
||||
)
|
||||
+ 'This error could be caused by running out of disk space.\n'
|
||||
'Try clearing up TEMP or setting output_dir to a path'
|
||||
' on another drive.',
|
||||
)
|
||||
cur_pos = fd.tell()
|
||||
line = fd.readline().strip()
|
||||
if not is_fixed_param:
|
||||
ct_divergences += int(data[idx_divergent]) # type: ignore
|
||||
if int(data[idx_treedepth]) == max_treedepth: # type: ignore
|
||||
ct_max_treedepth += 1
|
||||
|
||||
fd.seek(cur_pos)
|
||||
config_dict['draws_sampling'] = draws_found
|
||||
if not is_fixed_param:
|
||||
config_dict['ct_divergences'] = ct_divergences
|
||||
config_dict['ct_max_treedepth'] = ct_max_treedepth
|
||||
return lineno
|
||||
|
||||
|
||||
def read_metric(path: str) -> List[int]:
|
||||
"""
|
||||
Read metric file in JSON or Rdump format.
|
||||
Return dimensions of entry "inv_metric".
|
||||
"""
|
||||
if path.endswith('.json'):
|
||||
with open(path, 'r') as fd:
|
||||
metric_dict = json.load(fd)
|
||||
if 'inv_metric' in metric_dict:
|
||||
dims_np: np.ndarray = np.asarray(metric_dict['inv_metric'])
|
||||
return list(dims_np.shape)
|
||||
else:
|
||||
raise ValueError(
|
||||
'metric file {}, bad or missing'
|
||||
' entry "inv_metric"'.format(path)
|
||||
)
|
||||
else:
|
||||
dims = list(read_rdump_metric(path))
|
||||
if dims is None:
|
||||
raise ValueError(
|
||||
'metric file {}, bad or missing'
|
||||
' entry "inv_metric"'.format(path)
|
||||
)
|
||||
return dims
|
||||
|
||||
|
||||
def read_rdump_metric(path: str) -> List[int]:
|
||||
"""
|
||||
Find dimensions of variable named 'inv_metric' in Rdump data file.
|
||||
"""
|
||||
metric_dict = rload(path)
|
||||
if metric_dict is None or not (
|
||||
'inv_metric' in metric_dict
|
||||
and isinstance(metric_dict['inv_metric'], np.ndarray)
|
||||
):
|
||||
raise ValueError(
|
||||
'metric file {}, bad or missing entry "inv_metric"'.format(path)
|
||||
)
|
||||
return list(metric_dict['inv_metric'].shape)
|
||||
|
||||
|
||||
def rload(fname: str) -> Optional[Dict[str, Union[int, float, np.ndarray]]]:
|
||||
"""Parse data and parameter variable values from an R dump format file.
|
||||
This parser only supports the subset of R dump data as described
|
||||
in the "Dump Data Format" section of the CmdStan manual, i.e.,
|
||||
scalar, vector, matrix, and array data types.
|
||||
"""
|
||||
data_dict = {}
|
||||
with open(fname, 'r') as fd:
|
||||
lines = fd.readlines()
|
||||
# Variable data may span multiple lines, parse accordingly
|
||||
idx = 0
|
||||
while idx < len(lines) and '<-' not in lines[idx]:
|
||||
idx += 1
|
||||
if idx == len(lines):
|
||||
return None
|
||||
start_idx = idx
|
||||
idx += 1
|
||||
while True:
|
||||
while idx < len(lines) and '<-' not in lines[idx]:
|
||||
idx += 1
|
||||
next_var = idx
|
||||
var_data = ''.join(lines[start_idx:next_var]).replace('\n', '')
|
||||
lhs, rhs = [item.strip() for item in var_data.split('<-')]
|
||||
lhs = lhs.replace('"', '') # strip optional Jags double quotes
|
||||
rhs = rhs.replace('L', '') # strip R long int qualifier
|
||||
data_dict[lhs] = parse_rdump_value(rhs)
|
||||
if idx == len(lines):
|
||||
break
|
||||
start_idx = next_var
|
||||
idx += 1
|
||||
return data_dict
|
||||
|
||||
|
||||
def parse_rdump_value(rhs: str) -> Union[int, float, np.ndarray]:
|
||||
"""Process right hand side of Rdump variable assignment statement.
|
||||
Value is either scalar, vector, or multi-dim structure.
|
||||
Use regex to capture structure values, dimensions.
|
||||
"""
|
||||
pat = re.compile(
|
||||
r'structure\(\s*c\((?P<vals>[^)]*)\)'
|
||||
r'(,\s*\.Dim\s*=\s*c\s*\((?P<dims>[^)]*)\s*\))?\)'
|
||||
)
|
||||
val: Union[int, float, np.ndarray]
|
||||
try:
|
||||
if rhs.startswith('structure'):
|
||||
parse = pat.match(rhs)
|
||||
if parse is None or parse.group('vals') is None:
|
||||
raise ValueError(rhs)
|
||||
vals = [float(v) for v in parse.group('vals').split(',')]
|
||||
val = np.array(vals, order='F')
|
||||
if parse.group('dims') is not None:
|
||||
dims = [int(v) for v in parse.group('dims').split(',')]
|
||||
val = np.array(vals).reshape(dims, order='F')
|
||||
elif rhs.startswith('c(') and rhs.endswith(')'):
|
||||
val = np.array([float(item) for item in rhs[2:-1].split(',')])
|
||||
elif '.' in rhs or 'e' in rhs:
|
||||
val = float(rhs)
|
||||
else:
|
||||
val = int(rhs)
|
||||
except TypeError as e:
|
||||
raise ValueError('bad value in Rdump file: {}'.format(rhs)) from e
|
||||
return val
|
||||
Reference in New Issue
Block a user