487 lines
16 KiB
Python
487 lines
16 KiB
Python
"""
|
|
Utility functions for reading the Stan CSV format
|
|
"""
|
|
import json
|
|
import math
|
|
import re
|
|
from typing import Any, Dict, List, MutableMapping, Optional, TextIO, Union
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
|
|
from cmdstanpy import _CMDSTAN_SAMPLING, _CMDSTAN_THIN, _CMDSTAN_WARMUP
|
|
|
|
|
|
def check_sampler_csv(
    path: str,
    is_fixed_param: bool = False,
    iter_sampling: Optional[int] = None,
    iter_warmup: Optional[int] = None,
    save_warmup: bool = False,
    thin: Optional[int] = None,
) -> Dict[str, Any]:
    """Capture essential config, shape from stan_csv file.

    Scans the file, then validates the recorded thinning interval and
    the number of warmup/sampling draws against the expected values
    (falling back to CmdStan defaults when not specified).

    :raises ValueError: when the file's config or draw counts disagree
        with the expected values
    """
    meta = scan_sampler_csv(path, is_fixed_param)

    # validate thinning: only checked when caller requested thin > default
    if thin is None:
        thin = _CMDSTAN_THIN
    elif thin > _CMDSTAN_THIN:
        if 'thin' not in meta:
            raise ValueError(
                f'bad Stan CSV file {path}, '
                f'config error, expected thin = {thin}'
            )
        if meta['thin'] != thin:
            raise ValueError(
                f'bad Stan CSV file {path}, '
                f'config error, expected thin = {thin}, '
                f'found {meta["thin"]}'
            )

    # expected draw counts are post-thinning (ceiling division)
    expected_sampling = (
        _CMDSTAN_SAMPLING if iter_sampling is None else iter_sampling
    )
    expected_warmup = _CMDSTAN_WARMUP if iter_warmup is None else iter_warmup
    expected_warmup = int(math.ceil(expected_warmup / thin))
    expected_sampling = int(math.ceil(expected_sampling / thin))

    if meta['draws_sampling'] != expected_sampling:
        raise ValueError(
            f'bad Stan CSV file {path}, expected {expected_sampling} draws, '
            f'found {meta["draws_sampling"]}'
        )
    if save_warmup:
        # scan_config normalizes "true" to 1, but accept either form
        if meta.get('save_warmup') not in (1, 'true'):
            raise ValueError(
                f'bad Stan CSV file {path}, '
                'config error, expected save_warmup = 1'
            )
        if meta['draws_warmup'] != expected_warmup:
            raise ValueError(
                f'bad Stan CSV file {path}, '
                f'expected {expected_warmup} warmup draws, '
                f'found {meta["draws_warmup"]}'
            )
    return meta
|
|
|
|
|
|
def scan_sampler_csv(path: str, is_fixed_param: bool = False) -> Dict[str, Any]:
    """Process sampler stan_csv output file line by line.

    :param path: stan_csv file to scan
    :param is_fixed_param: when True, skip warmup/HMC adaptation sections
    :return: dict of config values, column names, and draw counts
    :raises ValueError: when any section of the file cannot be parsed
    """
    # renamed local from 'dict' which shadowed the builtin
    meta: Dict[str, Any] = {}
    lineno = 0
    with open(path, 'r') as fd:
        try:
            lineno = scan_config(fd, meta, lineno)
            lineno = scan_column_names(fd, meta, lineno)
            if not is_fixed_param:
                lineno = scan_warmup_iters(fd, meta, lineno)
                lineno = scan_hmc_params(fd, meta, lineno)
            lineno = scan_sampling_iters(fd, meta, lineno, is_fixed_param)
        except ValueError as e:
            # wrap parse errors with the offending file's path
            raise ValueError("Error in reading csv file: " + path) from e
    return meta
|
|
|
|
|
|
def scan_optimize_csv(path: str, save_iters: bool = False) -> Dict[str, Any]:
    """Process optimizer stan_csv output file line by line.

    :param path: optimizer stan_csv output file
    :param save_iters: when True, also capture every saved iteration
        under key 'all_iters'
    :return: dict with config, column names, final estimate 'mle', and
        optionally 'all_iters'
    :raises ValueError: on malformed or empty output
    """
    # renamed local from 'dict' which shadowed the builtin
    result: Dict[str, Any] = {}
    lineno = 0
    # scan to find config, header, num saved iters
    with open(path, 'r') as fd:
        lineno = scan_config(fd, result, lineno)
        lineno = scan_column_names(fd, result, lineno)
        iters = 0
        for _ in fd:
            iters += 1
    if iters == 0:
        # previously fell through to a NameError on unbound 'mle'
        raise ValueError(
            'cannot parse CSV file {}, no estimates found'.format(path)
        )
    if save_iters:
        all_iters: np.ndarray = np.empty(
            (iters, len(result['column_names'])), dtype=float, order='F'
        )
    # rescan to capture estimates
    with open(path, 'r') as fd:
        for _ in range(lineno):
            fd.readline()
        for i in range(iters):
            line = fd.readline().strip()
            if len(line) < 1:
                raise ValueError(
                    'cannot parse CSV file {}, error at line {}'.format(
                        path, lineno + i
                    )
                )
            xs = line.split(',')
            if save_iters:
                all_iters[i, :] = [float(x) for x in xs]
            if i == iters - 1:
                # last row holds the final (penalized) MLE
                mle: np.ndarray = np.array(xs, dtype=float)
    result['mle'] = mle
    if save_iters:
        result['all_iters'] = all_iters
    return result
|
|
|
|
|
|
def scan_generic_csv(path: str) -> Dict[str, Any]:
    """Process laplace stan_csv output file line by line.

    Captures only the config comment block and the column header.

    :param path: stan_csv file to scan
    :return: dict of config values and column names
    """
    # renamed local from 'dict' which shadowed the builtin
    meta: Dict[str, Any] = {}
    lineno = 0
    with open(path, 'r') as fd:
        lineno = scan_config(fd, meta, lineno)
        lineno = scan_column_names(fd, meta, lineno)
    return meta
|
|
|
|
|
|
def scan_variational_csv(path: str) -> Dict[str, Any]:
    """Process advi stan_csv output file line by line.

    :param path: ADVI stan_csv output file
    :return: dict with config, column names, optional 'eta',
        'variational_mean' row, and 'variational_sample' draws
    :raises ValueError: when the stepsize adaptation block is malformed
    """
    # renamed local from 'dict' which shadowed the builtin
    result: Dict[str, Any] = {}
    lineno = 0
    with open(path, 'r') as fd:
        lineno = scan_config(fd, result, lineno)
        lineno = scan_column_names(fd, result, lineno)
        line = fd.readline().lstrip(' #\t').rstrip()
        lineno += 1
        # optional adaptation block reports the stepsize scale eta
        if line.startswith('Stepsize adaptation complete.'):
            line = fd.readline().lstrip(' #\t\n')
            lineno += 1
            if not line.startswith('eta'):
                raise ValueError(
                    'line {}: expecting eta, found:\n\t "{}"'.format(
                        lineno, line
                    )
                )
            _, eta = line.split('=')
            result['eta'] = float(eta)
            line = fd.readline().lstrip(' #\t\n')
            lineno += 1
        # first data row is the mean of the variational approximation
        xs = line.split(',')
        variational_mean = [float(x) for x in xs]
        result['variational_mean'] = np.array(variational_mean)
        # remaining rows are draws from the approximate posterior
        result['variational_sample'] = pd.read_csv(
            path,
            comment='#',
            skiprows=lineno,
            header=None,
            float_precision='high',
        ).to_numpy()
    return result
|
|
|
|
|
|
def scan_config(fd: TextIO, config_dict: Dict[str, Any], lineno: int) -> int:
    """
    Scan initial stan_csv file comments lines and
    save non-default configuration information to config_dict.

    Leaves the file positioned at the first non-comment line and
    returns the updated line count.
    """

    def _coerce(raw: str) -> Union[int, float, str]:
        # config values arrive as text; prefer int, then float,
        # then map CmdStan's "true"/"false" words onto 1/0
        try:
            return int(raw)
        except ValueError:
            pass
        try:
            return float(raw)
        except ValueError:
            pass
        if raw == "true":
            return 1
        if raw == "false":
            return 0
        return raw

    cur_pos = fd.tell()
    line = fd.readline().strip()
    while len(line) > 0 and line.startswith('#'):
        lineno += 1
        if line.endswith('(Default)'):
            line = line.replace('(Default)', '')
        line = line.lstrip(' #\t')
        key_val = line.split('=')
        if len(key_val) == 2:
            key = key_val[0].strip()
            if key == 'file':
                # input data file; skip references to other csv outputs
                if not key_val[1].endswith('csv'):
                    config_dict['data_file'] = key_val[1].strip()
            else:
                config_dict[key] = _coerce(key_val[1].strip())
        cur_pos = fd.tell()
        line = fd.readline().strip()
    # rewind so the caller sees the first non-comment line
    fd.seek(cur_pos)
    return lineno
|
|
|
|
|
|
def scan_warmup_iters(
    fd: TextIO, config_dict: Dict[str, Any], lineno: int
) -> int:
    """
    Check warmup iterations, if any.

    Counts data rows up to the next comment block (or EOF) and records
    the count in ``config_dict['draws_warmup']``; no-op when the run
    did not save warmup draws.
    """
    if 'save_warmup' not in config_dict:
        return lineno
    draws = 0
    pos = fd.tell()
    line = fd.readline().strip()
    while line and not line.startswith('#'):
        lineno += 1
        draws += 1
        pos = fd.tell()
        line = fd.readline().strip()
    # rewind to the start of the comment block / EOF boundary
    fd.seek(pos)
    config_dict['draws_warmup'] = draws
    return lineno
|
|
|
|
|
|
def scan_column_names(
    fd: TextIO, config_dict: MutableMapping[str, Any], lineno: int
) -> int:
    """
    Process columns header, add to config_dict as 'column_names'

    Also keeps the untouched header line under 'raw_header'.
    """
    header = fd.readline().strip()
    lineno += 1
    config_dict['raw_header'] = header.strip()
    # normalize dotted index labels to bracketed notation
    config_dict['column_names'] = tuple(munge_varnames(header.split(',')))
    return lineno
|
|
|
|
|
|
def munge_varname(name: str) -> str:
    """Convert dot-separated indices in a Stan CSV column label to
    bracketed array notation, e.g. ``y.2.4`` -> ``y[2,4]``.

    Tuple element labels use ':' to separate tuple slots; each slot's
    indices are converted independently and the ':' separator is
    preserved, e.g. ``a:1.2`` -> ``a:1[2]``.
    """
    if '.' not in name and ':' not in name:
        return name

    tuple_parts = name.split(':')
    for i, part in enumerate(tuple_parts):
        if '.' not in part:
            continue
        # first '.' opens the index list, remaining dots separate indices
        part = part.replace('.', '[', 1)
        part = part.replace('.', ',')
        part += ']'
        tuple_parts[i] = part

    # bug fix: rejoin with ':' (was '.'), which corrupted tuple labels
    return ':'.join(tuple_parts)
|
|
|
|
|
|
def munge_varnames(names: List[str]) -> List[str]:
    """
    Change formatting for indices of container var elements
    from use of dot separator to array-like notation, e.g.,
    rewrite label ``y_forecast.2.4`` to ``y_forecast[2,4]``.
    """
    if names is None:
        raise ValueError('missing argument "names"')
    return list(map(munge_varname, names))
|
|
|
|
|
|
def scan_hmc_params(
    fd: TextIO, config_dict: Dict[str, Any], lineno: int
) -> int:
    """
    Scan step size, metric from stan_csv file comment lines.

    Expects '# Adaptation terminated', a '# Step size = <num>' line,
    then the mass-matrix block appropriate to the configured metric.
    """
    metric = config_dict['metric']

    line = fd.readline().strip()
    lineno += 1
    if line != '# Adaptation terminated':
        raise ValueError(
            'line {}: expecting metric, found:\n\t "{}"'.format(lineno, line)
        )

    line = fd.readline().strip()
    lineno += 1
    label, step_size = line.split('=')
    if not label.startswith('# Step size'):
        raise ValueError(
            'line {}: expecting step size, '
            'found:\n\t "{}"'.format(lineno, line)
        )
    try:
        # value is only validated, not stored
        float(step_size.strip())
    except ValueError as e:
        raise ValueError(
            'line {}: invalid step size: {}'.format(lineno, step_size)
        ) from e

    before_metric = fd.tell()
    line = fd.readline().strip()
    lineno += 1
    if metric == 'unit_e':
        # unit metric writes no matrix; only a marker when the model
        # has no free parameters
        if line.startswith("# No free parameters"):
            return lineno
        fd.seek(before_metric)
        return lineno - 1

    diag_ok = (
        metric == 'diag_e'
        and line == '# Diagonal elements of inverse mass matrix:'
    )
    dense_ok = (
        metric == 'dense_e' and line == '# Elements of inverse mass matrix:'
    )
    if not (diag_ok or dense_ok):
        raise ValueError(
            'line {}: invalid or missing mass matrix '
            'specification'.format(lineno)
        )

    line = fd.readline().lstrip(' #\t')
    lineno += 1
    num_unconstrained_params = len(line.split(','))
    if metric == 'diag_e':
        # diagonal metric is a single row
        return lineno
    # dense metric: one full row per unconstrained parameter
    for _ in range(1, num_unconstrained_params):
        line = fd.readline().lstrip(' #\t')
        lineno += 1
        if len(line.split(',')) != num_unconstrained_params:
            raise ValueError(
                'line {}: invalid or missing mass matrix '
                'specification'.format(lineno)
            )
    return lineno
|
|
|
|
|
|
def scan_sampling_iters(
    fd: TextIO, config_dict: Dict[str, Any], lineno: int, is_fixed_param: bool
) -> int:
    """
    Parse sampling iteration, save number of iterations to config_dict.
    Also save number of divergences, max_treedepth hits
    """
    num_cols = len(config_dict['column_names'])
    if not is_fixed_param:
        idx_divergent = config_dict['column_names'].index('divergent__')
        idx_treedepth = config_dict['column_names'].index('treedepth__')
        max_treedepth = config_dict['max_depth']
    draws = 0
    divergences = 0
    treedepth_hits = 0

    pos = fd.tell()
    line = fd.readline().strip()
    while line and not line.startswith('#'):
        lineno += 1
        draws += 1
        fields = line.split(',')
        if len(fields) != num_cols:
            raise ValueError(
                'line {}: bad draw, expecting {} items, found {}\n'.format(
                    lineno, num_cols, len(line.split(','))
                )
                + 'This error could be caused by running out of disk space.\n'
                'Try clearing up TEMP or setting output_dir to a path'
                ' on another drive.',
            )
        pos = fd.tell()
        line = fd.readline().strip()
        if not is_fixed_param:
            divergences += int(fields[idx_divergent])
            if int(fields[idx_treedepth]) == max_treedepth:
                treedepth_hits += 1

    # rewind to the start of the trailing comment block / EOF boundary
    fd.seek(pos)
    config_dict['draws_sampling'] = draws
    if not is_fixed_param:
        config_dict['ct_divergences'] = divergences
        config_dict['ct_max_treedepth'] = treedepth_hits
    return lineno
|
|
|
|
|
|
def read_metric(path: str) -> List[int]:
    """
    Read metric file in JSON or Rdump format.
    Return dimensions of entry "inv_metric".

    :param path: metric file; treated as JSON when the suffix is
        '.json', otherwise as Rdump
    :raises ValueError: when entry "inv_metric" is bad or missing
    """
    if path.endswith('.json'):
        with open(path, 'r') as fd:
            metric_dict = json.load(fd)
        if 'inv_metric' in metric_dict:
            dims_np: np.ndarray = np.asarray(metric_dict['inv_metric'])
            return list(dims_np.shape)
        raise ValueError(
            'metric file {}, bad or missing'
            ' entry "inv_metric"'.format(path)
        )
    # Rdump path: read_rdump_metric itself raises on a bad/missing
    # entry; the old 'dims is None' check was dead code (list() is
    # never None)
    return list(read_rdump_metric(path))
|
|
|
|
|
|
def read_rdump_metric(path: str) -> List[int]:
    """
    Find dimensions of variable named 'inv_metric' in Rdump data file.
    """
    data = rload(path)
    if (
        data is not None
        and 'inv_metric' in data
        and isinstance(data['inv_metric'], np.ndarray)
    ):
        return list(data['inv_metric'].shape)
    raise ValueError(
        'metric file {}, bad or missing entry "inv_metric"'.format(path)
    )
|
|
|
|
|
|
def rload(fname: str) -> Optional[Dict[str, Union[int, float, np.ndarray]]]:
    """Parse data and parameter variable values from an R dump format file.
    This parser only supports the subset of R dump data as described
    in the "Dump Data Format" section of the CmdStan manual, i.e.,
    scalar, vector, matrix, and array data types.

    Returns a dict of variable name -> parsed value, or ``None`` when
    the file contains no assignment ('<-') statements at all.
    """
    data_dict = {}
    with open(fname, 'r') as fd:
        lines = fd.readlines()
    # Variable data may span multiple lines, parse accordingly
    idx = 0
    # skip ahead to the first assignment statement
    while idx < len(lines) and '<-' not in lines[idx]:
        idx += 1
    if idx == len(lines):
        # no assignments anywhere in the file
        return None
    start_idx = idx
    idx += 1
    while True:
        # advance idx to the next assignment (or EOF); all lines in
        # between belong to the variable starting at start_idx
        while idx < len(lines) and '<-' not in lines[idx]:
            idx += 1
        next_var = idx
        # collapse the multi-line statement into one string
        var_data = ''.join(lines[start_idx:next_var]).replace('\n', '')
        lhs, rhs = [item.strip() for item in var_data.split('<-')]
        lhs = lhs.replace('"', '')  # strip optional Jags double quotes
        rhs = rhs.replace('L', '')  # strip R long int qualifier
        data_dict[lhs] = parse_rdump_value(rhs)
        if idx == len(lines):
            break
        start_idx = next_var
        idx += 1
    return data_dict
|
|
|
|
|
|
def parse_rdump_value(rhs: str) -> Union[int, float, np.ndarray]:
    """Process right hand side of Rdump variable assignment statement.
    Value is either scalar, vector, or multi-dim structure.
    Use regex to capture structure values, dimensions.
    """
    struct_pat = re.compile(
        r'structure\(\s*c\((?P<vals>[^)]*)\)'
        r'(,\s*\.Dim\s*=\s*c\s*\((?P<dims>[^)]*)\s*\))?\)'
    )
    val: Union[int, float, np.ndarray]
    try:
        if rhs.startswith('structure'):
            match = struct_pat.match(rhs)
            if match is None or match.group('vals') is None:
                raise ValueError(rhs)
            vals = [float(v) for v in match.group('vals').split(',')]
            val = np.array(vals, order='F')
            dims_text = match.group('dims')
            if dims_text is not None:
                # .Dim present: reshape column-major, matching R layout
                dims = [int(v) for v in dims_text.split(',')]
                val = np.array(vals).reshape(dims, order='F')
        elif rhs.startswith('c(') and rhs.endswith(')'):
            # plain vector literal
            val = np.array([float(item) for item in rhs[2:-1].split(',')])
        elif '.' in rhs or 'e' in rhs:
            # decimal point or exponent marks a float scalar
            val = float(rhs)
        else:
            val = int(rhs)
    except TypeError as e:
        raise ValueError('bad value in Rdump file: {}'.format(rhs)) from e
    return val
|