some new features
This commit is contained in:
8
.venv/bin/install_cmdstan
Executable file
8
.venv/bin/install_cmdstan
Executable file
@ -0,0 +1,8 @@
|
||||
#!/home/ilgazc/PycharmProjects/TimeSeriesAnalysis/.venv/bin/python3.12
|
||||
# -*- coding: utf-8 -*-
|
||||
import re
|
||||
import sys
|
||||
from cmdstanpy.install_cmdstan import __main__
|
||||
if __name__ == '__main__':
|
||||
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
|
||||
sys.exit(__main__())
|
||||
8
.venv/bin/install_cxx_toolchain
Executable file
8
.venv/bin/install_cxx_toolchain
Executable file
@ -0,0 +1,8 @@
|
||||
#!/home/ilgazc/PycharmProjects/TimeSeriesAnalysis/.venv/bin/python3.12
|
||||
# -*- coding: utf-8 -*-
|
||||
import re
|
||||
import sys
|
||||
from cmdstanpy.install_cxx_toolchain import __main__
|
||||
if __name__ == '__main__':
|
||||
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
|
||||
sys.exit(__main__())
|
||||
8
.venv/bin/tqdm
Executable file
8
.venv/bin/tqdm
Executable file
@ -0,0 +1,8 @@
|
||||
#!/home/ilgazc/PycharmProjects/TimeSeriesAnalysis/.venv/bin/python3.12
|
||||
# -*- coding: utf-8 -*-
|
||||
import re
|
||||
import sys
|
||||
from tqdm.cli import main
|
||||
if __name__ == '__main__':
|
||||
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
|
||||
sys.exit(main())
|
||||
@ -0,0 +1 @@
|
||||
pip
|
||||
@ -0,0 +1,15 @@
|
||||
BSD 3-Clause License
|
||||
|
||||
Copyright (c) 2019, Stan Developers and their Assignees
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
||||
|
||||
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
||||
|
||||
* Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
@ -0,0 +1,102 @@
|
||||
Metadata-Version: 2.1
|
||||
Name: cmdstanpy
|
||||
Version: 1.2.5
|
||||
Summary: Python interface to CmdStan
|
||||
Author: Stan Dev Team
|
||||
License: BSD-3-Clause
|
||||
Project-URL: Homepage, https://github.com/stan-dev/cmdstanpy
|
||||
Project-URL: Bug Tracker, https://github.com/stan-dev/cmdstanpy/issues
|
||||
Classifier: Programming Language :: Python :: 3
|
||||
Classifier: License :: OSI Approved :: BSD License
|
||||
Classifier: Operating System :: OS Independent
|
||||
Classifier: Development Status :: 5 - Production/Stable
|
||||
Classifier: Intended Audience :: Science/Research
|
||||
Classifier: Natural Language :: English
|
||||
Classifier: Programming Language :: Python
|
||||
Classifier: Topic :: Scientific/Engineering :: Information Analysis
|
||||
Requires-Python: >=3.8
|
||||
Description-Content-Type: text/markdown
|
||||
License-File: LICENSE.md
|
||||
Requires-Dist: pandas
|
||||
Requires-Dist: numpy>=1.21
|
||||
Requires-Dist: tqdm
|
||||
Requires-Dist: stanio<2.0.0,>=0.4.0
|
||||
Provides-Extra: all
|
||||
Requires-Dist: xarray; extra == "all"
|
||||
Provides-Extra: test
|
||||
Requires-Dist: flake8; extra == "test"
|
||||
Requires-Dist: pylint; extra == "test"
|
||||
Requires-Dist: pytest; extra == "test"
|
||||
Requires-Dist: pytest-cov; extra == "test"
|
||||
Requires-Dist: pytest-order; extra == "test"
|
||||
Requires-Dist: mypy; extra == "test"
|
||||
Requires-Dist: xarray; extra == "test"
|
||||
Provides-Extra: docs
|
||||
Requires-Dist: sphinx<6,>5; extra == "docs"
|
||||
Requires-Dist: pydata-sphinx-theme<0.9; extra == "docs"
|
||||
Requires-Dist: nbsphinx; extra == "docs"
|
||||
Requires-Dist: ipython; extra == "docs"
|
||||
Requires-Dist: ipykernel; extra == "docs"
|
||||
Requires-Dist: ipywidgets; extra == "docs"
|
||||
Requires-Dist: sphinx-copybutton; extra == "docs"
|
||||
Requires-Dist: xarray; extra == "docs"
|
||||
Requires-Dist: matplotlib; extra == "docs"
|
||||
|
||||
# CmdStanPy
|
||||
|
||||
[](https://codecov.io/gh/stan-dev/cmdstanpy)
|
||||
|
||||
|
||||
CmdStanPy is a lightweight pure-Python interface to CmdStan which provides access to the Stan compiler and all inference algorithms. It supports both development and production workflows. Because model development and testing may require many iterations, the defaults favor development mode and therefore output files are stored on a temporary filesystem. Non-default options allow all aspects of a run to be specified so that scripts can be used to distributed analysis jobs across nodes and machines.
|
||||
|
||||
CmdStanPy is distributed via PyPi: https://pypi.org/project/cmdstanpy/
|
||||
|
||||
or Conda Forge: https://anaconda.org/conda-forge/cmdstanpy
|
||||
|
||||
### Goals
|
||||
|
||||
- Clean interface to Stan services so that CmdStanPy can keep up with Stan releases.
|
||||
|
||||
- Provide access to all CmdStan inference methods.
|
||||
|
||||
- Easy to install,
|
||||
+ minimal Python library dependencies: numpy, pandas
|
||||
+ Python code doesn't interface directly with c++, only calls compiled executables
|
||||
|
||||
- Modular - CmdStanPy produces a MCMC sample (or point estimate) from the posterior; other packages do analysis and visualization.
|
||||
|
||||
- Low memory overhead - by default, minimal memory used above that required by CmdStanPy; objects run CmdStan programs and track CmdStan input and output files.
|
||||
|
||||
|
||||
### Source Repository
|
||||
|
||||
CmdStanPy and CmdStan are available from GitHub: https://github.com/stan-dev/cmdstanpy and https://github.com/stan-dev/cmdstan
|
||||
|
||||
|
||||
### Docs
|
||||
|
||||
The latest release documentation is hosted on https://mc-stan.org/cmdstanpy, older release versions are available from readthedocs: https://cmdstanpy.readthedocs.io
|
||||
|
||||
### Licensing
|
||||
|
||||
The CmdStanPy, CmdStan, and the core Stan C++ code are licensed under new BSD.
|
||||
|
||||
### Example
|
||||
|
||||
```python
|
||||
import os
|
||||
from cmdstanpy import cmdstan_path, CmdStanModel
|
||||
|
||||
# specify locations of Stan program file and data
|
||||
stan_file = os.path.join(cmdstan_path(), 'examples', 'bernoulli', 'bernoulli.stan')
|
||||
data_file = os.path.join(cmdstan_path(), 'examples', 'bernoulli', 'bernoulli.data.json')
|
||||
|
||||
# instantiate a model; compiles the Stan program by default
|
||||
model = CmdStanModel(stan_file=stan_file)
|
||||
|
||||
# obtain a posterior sample from the model conditioned on the data
|
||||
fit = model.sample(chains=4, data=data_file)
|
||||
|
||||
# summarize the results (wraps CmdStan `bin/stansummary`):
|
||||
fit.summary()
|
||||
```
|
||||
@ -0,0 +1,60 @@
|
||||
../../../bin/install_cmdstan,sha256=wXQ7YJaxaA94uHRfLPrP1r4l4y3WbOu-Oq8hesC6GRs,284
|
||||
../../../bin/install_cxx_toolchain,sha256=OQqGRHSWnIewU1OYztVUdB9b1p71QR_jvOPdIDCi_uA,290
|
||||
cmdstanpy-1.2.5.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
||||
cmdstanpy-1.2.5.dist-info/LICENSE.md,sha256=idyB5B6vvYuD6l6IVtMLDi0q-IUMmhKkS5MnM8XgTvI,1526
|
||||
cmdstanpy-1.2.5.dist-info/METADATA,sha256=0ScvA_iMXhi1MwJj_3dH1Dcm2w8unLHbv8K1Y7RtfV4,4050
|
||||
cmdstanpy-1.2.5.dist-info/RECORD,,
|
||||
cmdstanpy-1.2.5.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
|
||||
cmdstanpy-1.2.5.dist-info/entry_points.txt,sha256=jMCL_dUqeodJmm8BtARyRkIvkL2m8AvIwpoduddo01Y,136
|
||||
cmdstanpy-1.2.5.dist-info/top_level.txt,sha256=DymymE6zsANoee61D6GcZ9I2c9H2zrrgv12IP1Ob_nA,10
|
||||
cmdstanpy/__init__.py,sha256=VA2CIWvIFkE9FXFUi0DXnAlIDpzuDXGn66lMOX13tv4,1290
|
||||
cmdstanpy/__pycache__/__init__.cpython-312.pyc,,
|
||||
cmdstanpy/__pycache__/_version.cpython-312.pyc,,
|
||||
cmdstanpy/__pycache__/cmdstan_args.cpython-312.pyc,,
|
||||
cmdstanpy/__pycache__/compilation.cpython-312.pyc,,
|
||||
cmdstanpy/__pycache__/install_cmdstan.cpython-312.pyc,,
|
||||
cmdstanpy/__pycache__/install_cxx_toolchain.cpython-312.pyc,,
|
||||
cmdstanpy/__pycache__/model.cpython-312.pyc,,
|
||||
cmdstanpy/__pycache__/progress.cpython-312.pyc,,
|
||||
cmdstanpy/_version.py,sha256=zN4cqI6KhD4VXPuuWSYE-PKGpLSwPmM8qkOM-67JsyQ,42
|
||||
cmdstanpy/cmdstan_args.py,sha256=rKpQrym_ltoa3Qa0mUJMR6jI1v4_mhjmtTd-jlX24ig,39752
|
||||
cmdstanpy/compilation.py,sha256=rBTbFiTxautFRx_cWWdZN30h8W4iKsaVreBhBp3QYC4,20806
|
||||
cmdstanpy/install_cmdstan.py,sha256=61mVVIC7Bggl1oC7EZvi914uc9tJmqfSTPYWUSE0X54,23553
|
||||
cmdstanpy/install_cxx_toolchain.py,sha256=xNyxwUeYyM5vJjnxporefK77SsQUvvM78JeZnFQtKdI,11704
|
||||
cmdstanpy/model.py,sha256=5HtQvSnmkVYOJVH7aekF9NKZQmUz8VzthqoxJYGkCyY,89664
|
||||
cmdstanpy/progress.py,sha256=k5OQgEpUgh8p7VfMNX23uQcaoKuHA_mJ2XGOAKjUJyY,1317
|
||||
cmdstanpy/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
||||
cmdstanpy/stanfit/__init__.py,sha256=WxvdlTPU1XvxGOIqIF5OKR71O538TDGcBFBE_NLrgAg,10356
|
||||
cmdstanpy/stanfit/__pycache__/__init__.cpython-312.pyc,,
|
||||
cmdstanpy/stanfit/__pycache__/gq.cpython-312.pyc,,
|
||||
cmdstanpy/stanfit/__pycache__/laplace.cpython-312.pyc,,
|
||||
cmdstanpy/stanfit/__pycache__/mcmc.cpython-312.pyc,,
|
||||
cmdstanpy/stanfit/__pycache__/metadata.cpython-312.pyc,,
|
||||
cmdstanpy/stanfit/__pycache__/mle.cpython-312.pyc,,
|
||||
cmdstanpy/stanfit/__pycache__/pathfinder.cpython-312.pyc,,
|
||||
cmdstanpy/stanfit/__pycache__/runset.cpython-312.pyc,,
|
||||
cmdstanpy/stanfit/__pycache__/vb.cpython-312.pyc,,
|
||||
cmdstanpy/stanfit/gq.py,sha256=ppUEuOqA7bUnxuahKSqvJ80NB52dBoYVdetqIRpeloM,26049
|
||||
cmdstanpy/stanfit/laplace.py,sha256=8St71_dKSQTFw7LKCUGxviGufiMAzlb4IZ2jZ1lcQBI,9618
|
||||
cmdstanpy/stanfit/mcmc.py,sha256=l2sc4gjLm7cvCZNNgN4EPzMEVbbYc7Gu6vyJ1FtyMqg,31735
|
||||
cmdstanpy/stanfit/metadata.py,sha256=kcQ5-shQ_F3RYAbguc5V_SzWP8KWi2pWrObxFLWpPjg,1596
|
||||
cmdstanpy/stanfit/mle.py,sha256=aAk7M1cDhRI-6GRbdqqAcsVVReFhKxReUWSEllTADuM,10458
|
||||
cmdstanpy/stanfit/pathfinder.py,sha256=Iz-JjZyDkJ8IGbZDr0_Tnjh0iEGlClNiB89sqFRVXn8,8322
|
||||
cmdstanpy/stanfit/runset.py,sha256=0nD_Aq5Jyl6cID-MNWHS-ZFJ786osJR9NmJBqs7JcS0,10851
|
||||
cmdstanpy/stanfit/vb.py,sha256=kJvjsUqy9O7pKWQoVaS0R4bh7DpBKk-aD1iq_QpSktA,8502
|
||||
cmdstanpy/utils/__init__.py,sha256=irzAF4bKE-NNWoAJKKH0VNKPnFjUBsWkylfdzX6YYpo,3740
|
||||
cmdstanpy/utils/__pycache__/__init__.cpython-312.pyc,,
|
||||
cmdstanpy/utils/__pycache__/cmdstan.cpython-312.pyc,,
|
||||
cmdstanpy/utils/__pycache__/command.cpython-312.pyc,,
|
||||
cmdstanpy/utils/__pycache__/data_munging.cpython-312.pyc,,
|
||||
cmdstanpy/utils/__pycache__/filesystem.cpython-312.pyc,,
|
||||
cmdstanpy/utils/__pycache__/json.cpython-312.pyc,,
|
||||
cmdstanpy/utils/__pycache__/logging.cpython-312.pyc,,
|
||||
cmdstanpy/utils/__pycache__/stancsv.cpython-312.pyc,,
|
||||
cmdstanpy/utils/cmdstan.py,sha256=uRHFVB955k_nFbR02DUpX-q9xVZzYqnKjjmMJD6L9xk,19196
|
||||
cmdstanpy/utils/command.py,sha256=1nPeOI8Gn6r-WAb3TAe1mqweK-1K0aJRmMulYwdWxNk,3229
|
||||
cmdstanpy/utils/data_munging.py,sha256=Gw764AKLIzWu9LvO-N7CgUjcIzUOnVANX6khkb-m1Gk,1245
|
||||
cmdstanpy/utils/filesystem.py,sha256=t4HHG0IESmUVraF_WlSQs8JRzkGBKJUsj2gSCdt4UHQ,7099
|
||||
cmdstanpy/utils/json.py,sha256=rFQwxTr4OCTUMhBAuCNLQY4rRrwY5m1c9ufw0081_XM,132
|
||||
cmdstanpy/utils/logging.py,sha256=PipH_4YiZdoe5C9bTf5GEPhoI7zPHRx4VPe4AUCbHvw,699
|
||||
cmdstanpy/utils/stancsv.py,sha256=oB1V4dvDgCywoADew39wrbTrSZVXZXvoq3xFUvsiOZ4,16455
|
||||
@ -0,0 +1,5 @@
|
||||
Wheel-Version: 1.0
|
||||
Generator: setuptools (75.6.0)
|
||||
Root-Is-Purelib: true
|
||||
Tag: py3-none-any
|
||||
|
||||
@ -0,0 +1,3 @@
|
||||
[console_scripts]
|
||||
install_cmdstan = cmdstanpy.install_cmdstan:__main__
|
||||
install_cxx_toolchain = cmdstanpy.install_cxx_toolchain:__main__
|
||||
@ -0,0 +1 @@
|
||||
cmdstanpy
|
||||
66
.venv/lib/python3.12/site-packages/cmdstanpy/__init__.py
Normal file
66
.venv/lib/python3.12/site-packages/cmdstanpy/__init__.py
Normal file
@ -0,0 +1,66 @@
|
||||
# pylint: disable=wrong-import-position
|
||||
"""CmdStanPy Module"""
|
||||
|
||||
import atexit
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
_TMPDIR = tempfile.mkdtemp()
|
||||
_CMDSTAN_WARMUP = 1000
|
||||
_CMDSTAN_SAMPLING = 1000
|
||||
_CMDSTAN_THIN = 1
|
||||
_CMDSTAN_REFRESH = 100
|
||||
_DOT_CMDSTAN = '.cmdstan'
|
||||
|
||||
|
||||
def _cleanup_tmpdir() -> None:
|
||||
"""Force deletion of _TMPDIR."""
|
||||
shutil.rmtree(_TMPDIR, ignore_errors=True)
|
||||
|
||||
|
||||
atexit.register(_cleanup_tmpdir)
|
||||
|
||||
|
||||
from ._version import __version__ # noqa
|
||||
from .compilation import compile_stan_file, format_stan_file
|
||||
from .install_cmdstan import rebuild_cmdstan
|
||||
from .model import CmdStanModel
|
||||
from .stanfit import (
|
||||
CmdStanGQ,
|
||||
CmdStanLaplace,
|
||||
CmdStanMCMC,
|
||||
CmdStanMLE,
|
||||
CmdStanPathfinder,
|
||||
CmdStanVB,
|
||||
from_csv,
|
||||
)
|
||||
from .utils import (
|
||||
cmdstan_path,
|
||||
cmdstan_version,
|
||||
install_cmdstan,
|
||||
set_cmdstan_path,
|
||||
set_make_env,
|
||||
show_versions,
|
||||
write_stan_json,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'set_cmdstan_path',
|
||||
'cmdstan_path',
|
||||
'set_make_env',
|
||||
'install_cmdstan',
|
||||
'compile_stan_file',
|
||||
'format_stan_file',
|
||||
'CmdStanMCMC',
|
||||
'CmdStanMLE',
|
||||
'CmdStanGQ',
|
||||
'CmdStanVB',
|
||||
'CmdStanLaplace',
|
||||
'CmdStanPathfinder',
|
||||
'CmdStanModel',
|
||||
'from_csv',
|
||||
'write_stan_json',
|
||||
'show_versions',
|
||||
'rebuild_cmdstan',
|
||||
'cmdstan_version',
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
3
.venv/lib/python3.12/site-packages/cmdstanpy/_version.py
Normal file
3
.venv/lib/python3.12/site-packages/cmdstanpy/_version.py
Normal file
@ -0,0 +1,3 @@
|
||||
"""PyPi Version"""
|
||||
|
||||
__version__ = '1.2.5'
|
||||
1005
.venv/lib/python3.12/site-packages/cmdstanpy/cmdstan_args.py
Normal file
1005
.venv/lib/python3.12/site-packages/cmdstanpy/cmdstan_args.py
Normal file
File diff suppressed because it is too large
Load Diff
579
.venv/lib/python3.12/site-packages/cmdstanpy/compilation.py
Normal file
579
.venv/lib/python3.12/site-packages/cmdstanpy/compilation.py
Normal file
@ -0,0 +1,579 @@
|
||||
"""
|
||||
Makefile options for stanc and C++ compilers
|
||||
"""
|
||||
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
import subprocess
|
||||
from copy import copy
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, Optional, Union
|
||||
|
||||
from cmdstanpy.utils import get_logger
|
||||
from cmdstanpy.utils.cmdstan import (
|
||||
EXTENSION,
|
||||
cmdstan_path,
|
||||
cmdstan_version,
|
||||
cmdstan_version_before,
|
||||
)
|
||||
from cmdstanpy.utils.command import do_command
|
||||
from cmdstanpy.utils.filesystem import SanitizedOrTmpFilePath
|
||||
|
||||
STANC_OPTS = [
|
||||
'O',
|
||||
'O0',
|
||||
'O1',
|
||||
'Oexperimental',
|
||||
'allow-undefined',
|
||||
'use-opencl',
|
||||
'warn-uninitialized',
|
||||
'include-paths',
|
||||
'name',
|
||||
'warn-pedantic',
|
||||
]
|
||||
|
||||
# TODO(2.0): remove
|
||||
STANC_DEPRECATED_OPTS = {
|
||||
'allow_undefined': 'allow-undefined',
|
||||
'include_paths': 'include-paths',
|
||||
}
|
||||
|
||||
STANC_IGNORE_OPTS = [
|
||||
'debug-lex',
|
||||
'debug-parse',
|
||||
'debug-ast',
|
||||
'debug-decorated-ast',
|
||||
'debug-generate-data',
|
||||
'debug-mir',
|
||||
'debug-mir-pretty',
|
||||
'debug-optimized-mir',
|
||||
'debug-optimized-mir-pretty',
|
||||
'debug-transformed-mir',
|
||||
'debug-transformed-mir-pretty',
|
||||
'dump-stan-math-signatures',
|
||||
'auto-format',
|
||||
'print-canonical',
|
||||
'print-cpp',
|
||||
'o',
|
||||
'help',
|
||||
'version',
|
||||
]
|
||||
|
||||
OptionalPath = Union[str, os.PathLike, None]
|
||||
|
||||
|
||||
# TODO(2.0): can remove add function and other logic
|
||||
class CompilerOptions:
|
||||
"""
|
||||
User-specified flags for stanc and C++ compiler.
|
||||
|
||||
Attributes:
|
||||
stanc_options - stanc compiler flags, options
|
||||
cpp_options - makefile options (NAME=value)
|
||||
user_header - path to a user .hpp file to include during compilation
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
stanc_options: Optional[Dict[str, Any]] = None,
|
||||
cpp_options: Optional[Dict[str, Any]] = None,
|
||||
user_header: OptionalPath = None,
|
||||
) -> None:
|
||||
"""Initialize object."""
|
||||
self._stanc_options = stanc_options if stanc_options is not None else {}
|
||||
self._cpp_options = cpp_options if cpp_options is not None else {}
|
||||
self._user_header = str(user_header) if user_header is not None else ''
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return 'stanc_options={}, cpp_options={}'.format(
|
||||
self._stanc_options, self._cpp_options
|
||||
)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Overrides the default implementation"""
|
||||
if self.is_empty() and other is None: # equiv w/r/t compiler
|
||||
return True
|
||||
if not isinstance(other, CompilerOptions):
|
||||
return False
|
||||
return (
|
||||
self._stanc_options == other.stanc_options
|
||||
and self._cpp_options == other.cpp_options
|
||||
and self._user_header == other.user_header
|
||||
)
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
"""True if no options specified."""
|
||||
return (
|
||||
self._stanc_options == {}
|
||||
and self._cpp_options == {}
|
||||
and self._user_header == ''
|
||||
)
|
||||
|
||||
@property
|
||||
def stanc_options(self) -> Dict[str, Union[bool, int, str, Iterable[str]]]:
|
||||
"""Stanc compiler options."""
|
||||
return self._stanc_options
|
||||
|
||||
@property
|
||||
def cpp_options(self) -> Dict[str, Union[bool, int]]:
|
||||
"""C++ compiler options."""
|
||||
return self._cpp_options
|
||||
|
||||
@property
|
||||
def user_header(self) -> str:
|
||||
"""user header."""
|
||||
return self._user_header
|
||||
|
||||
def validate(self) -> None:
|
||||
"""
|
||||
Check compiler args.
|
||||
Raise ValueError if invalid options are found.
|
||||
"""
|
||||
self.validate_stanc_opts()
|
||||
self.validate_cpp_opts()
|
||||
self.validate_user_header()
|
||||
|
||||
def validate_stanc_opts(self) -> None:
|
||||
"""
|
||||
Check stanc compiler args and consistency between stanc and C++ options.
|
||||
Raise ValueError if bad config is found.
|
||||
"""
|
||||
# pylint: disable=no-member
|
||||
if self._stanc_options is None:
|
||||
return
|
||||
ignore = []
|
||||
paths = None
|
||||
has_o_flag = False
|
||||
|
||||
for deprecated, replacement in STANC_DEPRECATED_OPTS.items():
|
||||
if deprecated in self._stanc_options:
|
||||
if replacement:
|
||||
get_logger().warning(
|
||||
'compiler option "%s" is deprecated, use "%s" instead',
|
||||
deprecated,
|
||||
replacement,
|
||||
)
|
||||
self._stanc_options[replacement] = copy(
|
||||
self._stanc_options[deprecated]
|
||||
)
|
||||
del self._stanc_options[deprecated]
|
||||
else:
|
||||
get_logger().warning(
|
||||
'compiler option "%s" is deprecated and '
|
||||
'should not be used',
|
||||
deprecated,
|
||||
)
|
||||
for key, val in self._stanc_options.items():
|
||||
if key in STANC_IGNORE_OPTS:
|
||||
get_logger().info('ignoring compiler option: %s', key)
|
||||
ignore.append(key)
|
||||
elif key not in STANC_OPTS:
|
||||
raise ValueError(f'unknown stanc compiler option: {key}')
|
||||
elif key == 'include-paths':
|
||||
paths = val
|
||||
if isinstance(val, str):
|
||||
paths = val.split(',')
|
||||
elif not isinstance(val, list):
|
||||
raise ValueError(
|
||||
'Invalid include-paths, expecting list or '
|
||||
f'string, found type: {type(val)}.'
|
||||
)
|
||||
elif key == 'use-opencl':
|
||||
if self._cpp_options is None:
|
||||
self._cpp_options = {'STAN_OPENCL': 'TRUE'}
|
||||
else:
|
||||
self._cpp_options['STAN_OPENCL'] = 'TRUE'
|
||||
elif key.startswith('O'):
|
||||
if has_o_flag:
|
||||
get_logger().warning(
|
||||
'More than one of (O, O1, O2, Oexperimental)'
|
||||
'optimizations passed. Only the last one will'
|
||||
'be used'
|
||||
)
|
||||
else:
|
||||
has_o_flag = True
|
||||
|
||||
for opt in ignore:
|
||||
del self._stanc_options[opt]
|
||||
if paths is not None:
|
||||
bad_paths = [dir for dir in paths if not os.path.exists(dir)]
|
||||
if any(bad_paths):
|
||||
raise ValueError(
|
||||
'invalid include paths: {}'.format(', '.join(bad_paths))
|
||||
)
|
||||
|
||||
self._stanc_options['include-paths'] = [
|
||||
os.path.abspath(os.path.expanduser(path)) for path in paths
|
||||
]
|
||||
|
||||
def validate_cpp_opts(self) -> None:
|
||||
"""
|
||||
Check cpp compiler args.
|
||||
Raise ValueError if bad config is found.
|
||||
"""
|
||||
if self._cpp_options is None:
|
||||
return
|
||||
for key in ['OPENCL_DEVICE_ID', 'OPENCL_PLATFORM_ID']:
|
||||
if key in self._cpp_options:
|
||||
self._cpp_options['STAN_OPENCL'] = 'TRUE'
|
||||
val = self._cpp_options[key]
|
||||
if not isinstance(val, int) or val < 0:
|
||||
raise ValueError(
|
||||
f'{key} must be a non-negative integer value,'
|
||||
f' found {val}.'
|
||||
)
|
||||
|
||||
def validate_user_header(self) -> None:
|
||||
"""
|
||||
User header exists.
|
||||
Raise ValueError if bad config is found.
|
||||
"""
|
||||
if self._user_header != "":
|
||||
if not (
|
||||
os.path.exists(self._user_header)
|
||||
and os.path.isfile(self._user_header)
|
||||
):
|
||||
raise ValueError(
|
||||
f"User header file {self._user_header} cannot be found"
|
||||
)
|
||||
if self._user_header[-4:] != '.hpp':
|
||||
raise ValueError(
|
||||
f"Header file must end in .hpp, got {self._user_header}"
|
||||
)
|
||||
if "allow-undefined" not in self._stanc_options:
|
||||
self._stanc_options["allow-undefined"] = True
|
||||
# set full path
|
||||
self._user_header = os.path.abspath(self._user_header)
|
||||
|
||||
if ' ' in self._user_header:
|
||||
raise ValueError(
|
||||
"User header must be in a location with no spaces in path!"
|
||||
)
|
||||
|
||||
if (
|
||||
'USER_HEADER' in self._cpp_options
|
||||
and self._user_header != self._cpp_options['USER_HEADER']
|
||||
):
|
||||
raise ValueError(
|
||||
"Disagreement in user_header C++ options found!\n"
|
||||
f"{self._user_header}, {self._cpp_options['USER_HEADER']}"
|
||||
)
|
||||
|
||||
self._cpp_options['USER_HEADER'] = self._user_header
|
||||
|
||||
def add(self, new_opts: "CompilerOptions") -> None: # noqa: disable=Q000
|
||||
"""Adds options to existing set of compiler options."""
|
||||
if new_opts.stanc_options is not None:
|
||||
if self._stanc_options is None:
|
||||
self._stanc_options = new_opts.stanc_options
|
||||
else:
|
||||
for key, val in new_opts.stanc_options.items():
|
||||
if key == 'include-paths':
|
||||
if isinstance(val, Iterable) and not isinstance(
|
||||
val, str
|
||||
):
|
||||
for path in val:
|
||||
self.add_include_path(str(path))
|
||||
else:
|
||||
self.add_include_path(str(val))
|
||||
else:
|
||||
self._stanc_options[key] = val
|
||||
if new_opts.cpp_options is not None:
|
||||
for key, val in new_opts.cpp_options.items():
|
||||
self._cpp_options[key] = val
|
||||
if new_opts._user_header != '' and self._user_header == '':
|
||||
self._user_header = new_opts._user_header
|
||||
|
||||
def add_include_path(self, path: str) -> None:
|
||||
"""Adds include path to existing set of compiler options."""
|
||||
path = os.path.abspath(os.path.expanduser(path))
|
||||
if 'include-paths' not in self._stanc_options:
|
||||
self._stanc_options['include-paths'] = [path]
|
||||
elif path not in self._stanc_options['include-paths']:
|
||||
self._stanc_options['include-paths'].append(path)
|
||||
|
||||
def compose_stanc(self, filename_in_msg: Optional[str]) -> List[str]:
|
||||
opts = []
|
||||
|
||||
if filename_in_msg is not None:
|
||||
opts.append(f'--filename-in-msg={filename_in_msg}')
|
||||
|
||||
if self._stanc_options is not None and len(self._stanc_options) > 0:
|
||||
for key, val in self._stanc_options.items():
|
||||
if key == 'include-paths':
|
||||
opts.append(
|
||||
'--include-paths='
|
||||
+ ','.join(
|
||||
(
|
||||
Path(p).as_posix()
|
||||
for p in self._stanc_options['include-paths']
|
||||
)
|
||||
)
|
||||
)
|
||||
elif key == 'name':
|
||||
opts.append(f'--name={val}')
|
||||
else:
|
||||
opts.append(f'--{key}')
|
||||
return opts
|
||||
|
||||
def compose(self, filename_in_msg: Optional[str] = None) -> List[str]:
|
||||
"""
|
||||
Format makefile options as list of strings.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
filename_in_msg : str, optional
|
||||
filename to be displayed in stanc3 error messages
|
||||
(if different from actual filename on disk), by default None
|
||||
"""
|
||||
opts = [
|
||||
'STANCFLAGS+=' + flag.replace(" ", "\\ ")
|
||||
for flag in self.compose_stanc(filename_in_msg)
|
||||
]
|
||||
if self._cpp_options is not None and len(self._cpp_options) > 0:
|
||||
for key, val in self._cpp_options.items():
|
||||
opts.append(f'{key}={val}')
|
||||
return opts
|
||||
|
||||
|
||||
def src_info(
|
||||
stan_file: str, compiler_options: CompilerOptions
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get source info for Stan program file.
|
||||
|
||||
This function is used in the implementation of
|
||||
:meth:`CmdStanModel.src_info`, and should not be called directly.
|
||||
"""
|
||||
cmd = (
|
||||
[os.path.join(cmdstan_path(), 'bin', 'stanc' + EXTENSION)]
|
||||
# handle include-paths, allow-undefined etc
|
||||
+ compiler_options.compose_stanc(None)
|
||||
+ ['--info', str(stan_file)]
|
||||
)
|
||||
proc = subprocess.run(cmd, capture_output=True, text=True, check=False)
|
||||
if proc.returncode:
|
||||
raise ValueError(
|
||||
f"Failed to get source info for Stan model "
|
||||
f"'{stan_file}'. Console:\n{proc.stderr}"
|
||||
)
|
||||
result: Dict[str, Any] = json.loads(proc.stdout)
|
||||
return result
|
||||
|
||||
|
||||
def compile_stan_file(
|
||||
src: Union[str, Path],
|
||||
force: bool = False,
|
||||
stanc_options: Optional[Dict[str, Any]] = None,
|
||||
cpp_options: Optional[Dict[str, Any]] = None,
|
||||
user_header: OptionalPath = None,
|
||||
) -> str:
|
||||
"""
|
||||
Compile the given Stan program file. Translates the Stan code to
|
||||
C++, then calls the C++ compiler.
|
||||
|
||||
By default, this function compares the timestamps on the source and
|
||||
executable files; if the executable is newer than the source file, it
|
||||
will not recompile the file, unless argument ``force`` is ``True``
|
||||
or unless the compiler options have been changed.
|
||||
|
||||
:param src: Path to Stan program file.
|
||||
|
||||
:param force: When ``True``, always compile, even if the executable file
|
||||
is newer than the source file. Used for Stan models which have
|
||||
``#include`` directives in order to force recompilation when changes
|
||||
are made to the included files.
|
||||
|
||||
:param stanc_options: Options for stanc compiler.
|
||||
:param cpp_options: Options for C++ compiler.
|
||||
:param user_header: A path to a header file to include during C++
|
||||
compilation.
|
||||
"""
|
||||
|
||||
src = Path(src).resolve()
|
||||
if not src.exists():
|
||||
raise ValueError(f'stan file does not exist: {src}')
|
||||
|
||||
compiler_options = CompilerOptions(
|
||||
stanc_options=stanc_options,
|
||||
cpp_options=cpp_options,
|
||||
user_header=user_header,
|
||||
)
|
||||
compiler_options.validate()
|
||||
|
||||
exe_target = src.with_suffix(EXTENSION)
|
||||
if exe_target.exists():
|
||||
exe_time = os.path.getmtime(exe_target)
|
||||
included_files = [src]
|
||||
included_files.extend(
|
||||
src_info(str(src), compiler_options).get('included_files', [])
|
||||
)
|
||||
out_of_date = any(
|
||||
os.path.getmtime(included_file) > exe_time
|
||||
for included_file in included_files
|
||||
)
|
||||
if not out_of_date and not force:
|
||||
get_logger().debug('found newer exe file, not recompiling')
|
||||
return str(exe_target)
|
||||
|
||||
compilation_failed = False
|
||||
# if target path has spaces or special characters, use a copy in a
|
||||
# temporary directory (GNU-Make constraint)
|
||||
with SanitizedOrTmpFilePath(str(src)) as (stan_file, is_copied):
|
||||
exe_file = os.path.splitext(stan_file)[0] + EXTENSION
|
||||
|
||||
hpp_file = os.path.splitext(exe_file)[0] + '.hpp'
|
||||
if os.path.exists(hpp_file):
|
||||
os.remove(hpp_file)
|
||||
if os.path.exists(exe_file):
|
||||
get_logger().debug('Removing %s', exe_file)
|
||||
os.remove(exe_file)
|
||||
|
||||
get_logger().info(
|
||||
'compiling stan file %s to exe file %s',
|
||||
stan_file,
|
||||
exe_target,
|
||||
)
|
||||
|
||||
make = os.getenv(
|
||||
'MAKE',
|
||||
'make' if platform.system() != 'Windows' else 'mingw32-make',
|
||||
)
|
||||
cmd = [make]
|
||||
cmd.extend(compiler_options.compose(filename_in_msg=src.name))
|
||||
cmd.append(Path(exe_file).as_posix())
|
||||
|
||||
sout = io.StringIO()
|
||||
try:
|
||||
do_command(cmd=cmd, cwd=cmdstan_path(), fd_out=sout)
|
||||
except RuntimeError as e:
|
||||
sout.write(f'\n{str(e)}\n')
|
||||
compilation_failed = True
|
||||
finally:
|
||||
console = sout.getvalue()
|
||||
|
||||
get_logger().debug('Console output:\n%s', console)
|
||||
if not compilation_failed:
|
||||
if is_copied:
|
||||
shutil.copy(exe_file, exe_target)
|
||||
get_logger().info('compiled model executable: %s', exe_target)
|
||||
if 'Warning' in console:
|
||||
lines = console.split('\n')
|
||||
warnings = [x for x in lines if x.startswith('Warning')]
|
||||
get_logger().warning(
|
||||
'Stan compiler has produced %d warnings:',
|
||||
len(warnings),
|
||||
)
|
||||
get_logger().warning(console)
|
||||
if compilation_failed:
|
||||
if 'PCH' in console or 'precompiled header' in console:
|
||||
get_logger().warning(
|
||||
"CmdStan's precompiled header (PCH) files "
|
||||
"may need to be rebuilt."
|
||||
"Please run cmdstanpy.rebuild_cmdstan().\n"
|
||||
"If the issue persists please open a bug report"
|
||||
)
|
||||
raise ValueError(
|
||||
f"Failed to compile Stan model '{src}'. " f"Console:\n{console}"
|
||||
)
|
||||
return str(exe_target)
|
||||
|
||||
|
||||
def format_stan_file(
    stan_file: Union[str, os.PathLike],
    *,
    overwrite_file: bool = False,
    canonicalize: Union[bool, str, Iterable[str]] = False,
    max_line_length: int = 78,
    backup: bool = True,
    stanc_options: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Run stanc's auto-formatter on the model code. Either saves directly
    back to the file or prints for inspection

    :param stan_file: Path to Stan program file.
    :param overwrite_file: If True, save the updated code to disk, rather
        than printing it. By default False
    :param canonicalize: Whether or not the compiler should 'canonicalize'
        the Stan model, removing things like deprecated syntax. Default is
        False. If True, all canonicalizations are run. If it is a list of
        strings, those options are passed to stanc (new in Stan 2.29)
    :param max_line_length: Set the wrapping point for the formatter. The
        default value is 78, which wraps most lines by the 80th character.
    :param backup: If True, create a stanfile.bak backup before
        writing to the file. Only disable this if you're sure you have other
        copies of the file or are using a version control system like Git.
    :param stanc_options: Additional options to pass to the stanc compiler.

    :raises ValueError: if the file does not exist.
    :raises RuntimeError: if stanc invocation or option validation fails.
    """
    stan_file = Path(stan_file).resolve()

    if not stan_file.exists():
        raise ValueError(f'File does not exist: {stan_file}')

    try:
        cmd = (
            [os.path.join(cmdstan_path(), 'bin', 'stanc' + EXTENSION)]
            # handle include-paths, allow-undefined etc
            + CompilerOptions(stanc_options=stanc_options).compose_stanc(None)
            + [str(stan_file)]
        )

        if canonicalize:
            if cmdstan_version_before(2, 29):
                # pre-2.29 stanc only supports all-or-nothing canonicalization
                if isinstance(canonicalize, bool):
                    cmd.append('--print-canonical')
                else:
                    raise ValueError(
                        "Invalid arguments passed for current CmdStan"
                        + " version({})\n".format(
                            cmdstan_version() or "Unknown"
                        )
                        + "--canonicalize requires 2.29 or higher"
                    )
            else:
                # 2.29+ accepts a comma-separated list of canonicalizations
                if isinstance(canonicalize, str):
                    cmd.append('--canonicalize=' + canonicalize)
                elif isinstance(canonicalize, Iterable):
                    cmd.append('--canonicalize=' + ','.join(canonicalize))
                else:
                    cmd.append('--print-canonical')

        # before 2.29, having both --print-canonical
        # and --auto-format printed twice
        if not (cmdstan_version_before(2, 29) and canonicalize):
            cmd.append('--auto-format')

        if not cmdstan_version_before(2, 29):
            cmd.append(f'--max-line-length={max_line_length}')
        elif max_line_length != 78:
            raise ValueError(
                "Invalid arguments passed for current CmdStan version"
                + " ({})\n".format(cmdstan_version() or "Unknown")
                + "--max-line-length requires 2.29 or higher"
            )

        # NOTE(review): check=True raises CalledProcessError (not
        # RuntimeError) on a non-zero stanc exit, so that case escapes
        # the except clause below — confirm this is intentional.
        out = subprocess.run(cmd, capture_output=True, text=True, check=True)
        if out.stderr:
            get_logger().warning(out.stderr)
        result = out.stdout
        if overwrite_file:
            if result:
                if backup:
                    # timestamped .bak-YYYYmmddHHMMSS copy next to the file
                    shutil.copyfile(
                        stan_file,
                        str(stan_file)
                        + '.bak-'
                        + datetime.now().strftime("%Y%m%d%H%M%S"),
                    )
                stan_file.write_text(result)
        else:
            print(result)

    except (ValueError, RuntimeError) as e:
        # option-validation ValueErrors raised above are re-wrapped here
        raise RuntimeError("Stanc formatting failed") from e
|
||||
711
.venv/lib/python3.12/site-packages/cmdstanpy/install_cmdstan.py
Normal file
711
.venv/lib/python3.12/site-packages/cmdstanpy/install_cmdstan.py
Normal file
@ -0,0 +1,711 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Download and install a CmdStan release from GitHub.
|
||||
Downloads the release tar.gz file to temporary storage.
|
||||
Retries GitHub requests in order to allow for transient network outages.
|
||||
Builds CmdStan executables and tests the compiler by building
|
||||
example model ``bernoulli.stan``.
|
||||
|
||||
Optional command line arguments:
|
||||
-i, --interactive: flag, when specified ignore other arguments and
|
||||
ask user for settings on STDIN
|
||||
-v, --version <release> : version, defaults to latest release version
|
||||
-d, --dir <path> : install directory, defaults to '$HOME/.cmdstan
|
||||
--overwrite: flag, when specified re-installs existing version
|
||||
--progress: flag, when specified show progress bar for CmdStan download
|
||||
--verbose: flag, when specified prints output from CmdStan build process
|
||||
--cores: int, number of cores to use when building, defaults to 1
|
||||
-c, --compiler : flag, add C++ compiler to path (Windows only)
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import tarfile
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from time import sleep
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from cmdstanpy import _DOT_CMDSTAN
|
||||
from cmdstanpy.utils import (
|
||||
cmdstan_path,
|
||||
do_command,
|
||||
pushd,
|
||||
validate_dir,
|
||||
wrap_url_progress_hook,
|
||||
)
|
||||
from cmdstanpy.utils.cmdstan import get_download_url
|
||||
|
||||
from . import progress as progbar
|
||||
|
||||
# Provide functools.cached_property on every supported interpreter.
if sys.version_info >= (3, 8) or TYPE_CHECKING:
    # mypy only knows about the new built-in cached_property
    from functools import cached_property
else:
    # on older Python versions, this is the recommended
    # way to get the same effect
    from functools import lru_cache

    def cached_property(fun):
        # property + unbounded lru_cache == compute once per instance
        return property(lru_cache(maxsize=None)(fun))
|
||||
|
||||
|
||||
try:
    # on MacOS and Linux, importing this
    # improves the UX of the input() function
    import readline

    # dummy statement to use import for flake8/pylint
    _ = readline.__doc__
except ImportError:
    # readline is unavailable (e.g. on Windows); input() still works
    pass
|
||||
|
||||
|
||||
class CmdStanRetrieveError(RuntimeError):
    """Raised when a CmdStan release cannot be fetched from GitHub."""
|
||||
|
||||
|
||||
class CmdStanInstallError(RuntimeError):
    """Raised when building or installing CmdStan fails."""
|
||||
|
||||
|
||||
def is_windows() -> bool:
    """Return True when running on a Windows host."""
    return 'Windows' == platform.system()
|
||||
|
||||
|
||||
# Build tool used for all CmdStan make invocations: GNU make on POSIX,
# mingw32-make on Windows; overridable via the MAKE environment variable.
MAKE = os.getenv('MAKE', 'make' if not is_windows() else 'mingw32-make')
# Filename suffix of compiled executables ('.exe' on Windows only).
EXTENSION = '.exe' if is_windows() else ''
|
||||
|
||||
|
||||
def get_headers() -> Dict[str, str]:
    """Build HTTP headers for GitHub API requests.

    When the GITHUB_PAT environment variable is set, an Authorization
    header carrying that token is included (raises GitHub rate limits).
    """
    token = os.environ.get("GITHUB_PAT")  # pylint:disable=invalid-name
    if token is None:
        return {}
    return {"Authorization": "token {}".format(token)}
|
||||
|
||||
|
||||
def latest_version() -> str:
    """Report latest CmdStan release version.

    Queries the GitHub releases API, retrying transient network failures
    up to 5 times before raising CmdStanRetrieveError.
    """
    url = 'https://api.github.com/repos/stan-dev/cmdstan/releases/latest'
    request = urllib.request.Request(url, headers=get_headers())
    for i in range(6):
        try:
            response = urllib.request.urlopen(request).read()
            break
        except urllib.error.URLError as e:
            print('Cannot connect to github.')
            print(e)
            if i < 5:
                print('retry ({}/5)'.format(i + 1))
                sleep(1)
                continue
            # retries exhausted
            raise CmdStanRetrieveError(
                'Cannot connect to CmdStan github repo.'
            ) from e
    content = json.loads(response.decode('utf-8'))
    tag = content['tag_name']
    # strip an optional leading 'v' from tags such as 'v2.33.1'
    match = re.search(r'v?(.+)', tag)
    if match is not None:
        tag = match.group(1)
    return tag  # type: ignore
|
||||
|
||||
|
||||
def home_cmdstan() -> str:
    """Return the default CmdStan install location, ``~/.cmdstan``."""
    location = os.path.join('~', _DOT_CMDSTAN)
    return os.path.expanduser(location)
|
||||
|
||||
|
||||
# pylint: disable=too-few-public-methods
class InstallationSettings:
    """
    A static installation settings object.

    Holds all options driving a CmdStan install; unspecified values fall
    back to the latest release and the default install directory.
    """

    def __init__(
        self,
        *,
        version: Optional[str] = None,
        dir: Optional[str] = None,
        progress: bool = False,
        verbose: bool = False,
        overwrite: bool = False,
        cores: int = 1,
        compiler: bool = False,
        **kwargs: Any,
    ):
        self.version = version or latest_version()
        self.dir = dir or home_cmdstan()
        self.progress = progress
        self.verbose = verbose
        self.overwrite = overwrite
        self.cores = cores
        # a compiler install is only meaningful on Windows
        self.compiler = compiler and is_windows()

        # remaining keyword arguments are deliberately ignored, so the
        # settings can be built straight from a vars(argparse) dict
        del kwargs
|
||||
|
||||
|
||||
def yes_no(answer: str, default: bool) -> bool:
    """Map a y/n console reply to a bool, falling back to *default*."""
    replies = {'y': True, 'yes': True, 'n': False, 'no': False}
    return replies.get(answer.lower(), default)
|
||||
|
||||
|
||||
class InteractiveSettings:
    """
    Installation settings provided on-demand in an interactive format.

    This provides the same set of properties as the ``InstallationSettings``
    object, but rather than them being fixed by the constructor the user is
    asked for input whenever they are accessed for the first time.
    """

    # Each property below prompts on STDIN once; cached_property memoizes
    # the reply so subsequent accesses reuse it without re-prompting.

    @cached_property
    def version(self) -> str:
        # default is the newest release reported by the GitHub API
        latest = latest_version()
        print("Which version would you like to install?")
        print(f"Default: {latest}")
        answer = input("Type version or hit enter to continue: ")
        return answer if answer else latest

    @cached_property
    def dir(self) -> str:
        directory = home_cmdstan()
        print("Where would you like to install CmdStan?")
        print(f"Default: {directory}")
        answer = input("Type full path or hit enter to continue: ")
        # a typed path may contain '~'; expand it like a shell would
        return os.path.expanduser(answer) if answer else directory

    @cached_property
    def progress(self) -> bool:
        print("Show installation progress bars?")
        print("Default: y")
        answer = input("[y/n]: ")
        return yes_no(answer, True)

    @cached_property
    def verbose(self) -> bool:
        print("Show verbose output of the installation process?")
        print("Default: n")
        answer = input("[y/n]: ")
        return yes_no(answer, False)

    @cached_property
    def overwrite(self) -> bool:
        print("Overwrite existing CmdStan installation?")
        print("Default: n")
        answer = input("[y/n]: ")
        return yes_no(answer, False)

    @cached_property
    def compiler(self) -> bool:
        # the RTools toolchain install is offered on Windows only
        if not is_windows():
            return False
        print("Would you like to install the RTools40 C++ toolchain?")
        print("A C++ toolchain is required for CmdStan.")
        print(
            "If you are not sure if you need the toolchain or not, "
            "the most likely case is you do need it, and should answer 'y'."
        )
        print("Default: n")
        answer = input("[y/n]: ")
        return yes_no(answer, False)

    @cached_property
    def cores(self) -> int:
        max_cpus = os.cpu_count() or 1
        print(
            "How many CPU cores would you like to use for installing "
            "and compiling CmdStan?"
        )
        print(f"Default: 1, Max: {max_cpus}")
        answer = input("Enter a number or hit enter to continue: ")
        try:
            # clamp the reply into [1, max_cpus]
            return min(max_cpus, max(int(answer), 1))
        except ValueError:
            # non-numeric reply falls back to a single core
            return 1
|
||||
|
||||
|
||||
def clean_all(verbose: bool = False) -> None:
    """
    Run `make clean-all` in the current directory (must be a cmdstan library).

    :param verbose: Boolean value; when ``True``, show output from make command.
    """
    clean_cmd = [MAKE, 'clean-all']
    try:
        if verbose:
            do_command(clean_cmd)
        else:
            # discard console output from make
            do_command(clean_cmd, fd_out=None)

    except RuntimeError as e:
        # pylint: disable=raise-missing-from
        raise CmdStanInstallError(f'Command "make clean-all" failed\n{str(e)}')
|
||||
|
||||
|
||||
def build(verbose: bool = False, progress: bool = True, cores: int = 1) -> None:
    """
    Run command ``make build`` in the current directory, which must be
    the home directory of a CmdStan version (or GitHub repo).
    By default, displays a progress bar which tracks make command outputs.
    If argument ``verbose=True``, instead of a progress bar, streams
    make command outputs to sys.stdout. When both ``verbose`` and ``progress``
    are ``False``, runs silently.

    :param verbose: Boolean value; when ``True``, show output from make command.
        Default is ``False``.
    :param progress: Boolean value; when ``True`` display progress progress bar.
        Default is ``True``.
    :param cores: Integer, number of cores to use in the ``make`` command.
        Default is 1 core.

    :raises CmdStanInstallError: when make fails or expected binaries
        are missing afterwards.
    """
    cmd = [MAKE, 'build', f'-j{cores}']
    try:
        if verbose:
            do_command(cmd)
        elif progress and progbar.allow_show_progress():
            progress_hook: Any = _wrap_build_progress_hook()
            do_command(cmd, fd_out=None, pbar=progress_hook)
        else:
            do_command(cmd, fd_out=None)

    except RuntimeError as e:
        # pylint: disable=raise-missing-from
        raise CmdStanInstallError(f'Command "make build" failed\n{str(e)}')
    # sanity-check that the core CmdStan utilities were actually produced
    if not os.path.exists(os.path.join('bin', 'stansummary' + EXTENSION)):
        raise CmdStanInstallError(
            f'bin/stansummary{EXTENSION} not found'
            ', please rebuild or report a bug!'
        )
    if not os.path.exists(os.path.join('bin', 'diagnose' + EXTENSION)):
        # fixed copy-paste bug: message previously named bin/stansummary
        raise CmdStanInstallError(
            f'bin/diagnose{EXTENSION} not found'
            ', please rebuild or report a bug!'
        )

    if is_windows():
        # Add tbb to the $PATH on Windows, de-duplicating entries while
        # keeping their original order
        libtbb = os.path.join(
            os.getcwd(), 'stan', 'lib', 'stan_math', 'lib', 'tbb'
        )
        os.environ['PATH'] = ';'.join(
            list(
                OrderedDict.fromkeys(
                    [libtbb] + os.environ.get('PATH', '').split(';')
                )
            )
        )
|
||||
|
||||
|
||||
@progbar.wrap_callback
def _wrap_build_progress_hook() -> Optional[Callable[[str], None]]:
    """Sets up tqdm callback for CmdStan sampler console msgs."""
    pad = ' ' * 20
    msgs_expected = 150  # hack: 2.27 make build send ~140 msgs to console
    pbar: tqdm = tqdm(
        total=msgs_expected,
        bar_format="{desc} ({elapsed}) | {bar} | {postfix[0][value]}",
        postfix=[{"value": f'Building CmdStan {pad}'}],
        colour='blue',
        desc='',
        position=0,
    )

    def build_progress_hook(line: str) -> None:
        # the '--- CmdStan' banner marks the end of the build
        if line.startswith('--- CmdStan'):
            pbar.set_description('Done')
            pbar.postfix[0]["value"] = line
            # jump the bar to 100% regardless of how many msgs were seen
            pbar.update(msgs_expected - pbar.n)
            pbar.close()
        else:
            if line.startswith('--'):
                # phase-marker lines are shown verbatim
                pbar.postfix[0]["value"] = line
            else:
                # long compiler lines are abbreviated: head ... tail
                pbar.postfix[0]["value"] = f'{line[:8]} ... {line[-20:]}'
                pbar.set_description('Compiling')
            pbar.update(1)

    return build_progress_hook
|
||||
|
||||
|
||||
def compile_example(verbose: bool = False) -> None:
    """
    Compile the example model.
    The current directory must be a cmdstan installation, i.e.,
    contains the makefile, Stanc compiler, and all libraries.

    :param verbose: Boolean value; when ``True``, show output from make command.
    """
    exe_path = Path('examples', 'bernoulli', 'bernoulli').with_suffix(EXTENSION)
    # drop any stale binary so make genuinely exercises the compiler
    if exe_path.is_file():
        exe_path.unlink()

    make_cmd = [MAKE, exe_path.as_posix()]
    try:
        if verbose:
            do_command(make_cmd)
        else:
            do_command(make_cmd, fd_out=None)
    except RuntimeError as e:
        # pylint: disable=raise-missing-from
        raise CmdStanInstallError(f'Command "{" ".join(make_cmd)}" failed:\n{e}')

    if not exe_path.is_file():
        raise CmdStanInstallError("Failed to generate example binary")
|
||||
|
||||
|
||||
def rebuild_cmdstan(
    verbose: bool = False, progress: bool = True, cores: int = 1
) -> None:
    """
    Rebuilds the existing CmdStan installation.
    This assumes CmdStan has already been installed,
    though it need not be installed via CmdStanPy for
    this function to work.

    :param verbose: Boolean value; when ``True``, show output from make command.
        Default is ``False``.
    :param progress: Boolean value; when ``True`` display progress progress bar.
        Default is ``True``.
    :param cores: Integer, number of cores to use in the ``make`` command.
        Default is 1 core.

    :raises CmdStanInstallError: when no CmdStan installation is found
        (cmdstan_path raises ValueError) or the rebuild fails.
    """
    try:
        # clean, rebuild, and smoke-test inside the install directory
        with pushd(cmdstan_path()):
            clean_all(verbose)
            build(verbose, progress, cores)
            compile_example(verbose)
    except ValueError as e:
        raise CmdStanInstallError(
            "Failed to rebuild CmdStan. Are you sure it is installed?"
        ) from e
|
||||
|
||||
|
||||
def install_version(
    cmdstan_version: str,
    overwrite: bool = False,
    verbose: bool = False,
    progress: bool = True,
    cores: int = 1,
) -> None:
    """
    Build specified CmdStan version by spawning subprocesses to
    run the Make utility on the downloaded CmdStan release src files.
    Assumes that current working directory is parent of release dir.

    :param cmdstan_version: CmdStan release, corresponds to release dirname.
    :param overwrite: when ``True``, run ``make clean-all`` before building.
    :param verbose: Boolean value; when ``True``, show output from make command.
    :param progress: Boolean value; when ``True``, display a build progress bar.
    :param cores: Integer, number of cores to use in the ``make`` command.
    """
    with pushd(cmdstan_version):
        print(
            'Building version {}, may take several minutes, '
            'depending on your system.'.format(cmdstan_version)
        )
        # The original also tested os.path.exists('.'), which is always
        # True inside pushd; the clean step is governed by overwrite alone.
        if overwrite:
            print(
                'Overwrite requested, remove existing build of version '
                '{}'.format(cmdstan_version)
            )
            clean_all(verbose)
            print('Rebuilding version {}'.format(cmdstan_version))
        build(verbose, progress=progress, cores=cores)
        print('Installed {}'.format(cmdstan_version))
|
||||
|
||||
|
||||
def is_version_available(version: str) -> bool:
    """Check whether the given CmdStan release can be downloaded.

    Transient network failures (URLError) are retried up to 5 times;
    a hard HTTP error (e.g. 404 for a non-existent release) fails
    immediately.

    :param version: release version string, or a ``git:<ref>`` spec.
    :return: True when the release URL is reachable (git refs are
        always assumed available).
    """
    if 'git:' in version:
        return True  # no good way in general to check if a git tag exists

    is_available = True
    url = get_download_url(version)
    for i in range(6):
        try:
            urllib.request.urlopen(url)
        except urllib.error.HTTPError as err:
            print(f'Release {version} is unavailable from URL {url}')
            print(f'HTTPError: {err.code}')
            is_available = False
            break
        except urllib.error.URLError as e:
            if i < 5:
                print(
                    'checking version {} availability, retry ({}/5)'.format(
                        version, i + 1
                    )
                )
                sleep(1)
                continue
            print('Release {} is unavailable from URL {}'.format(version, url))
            print('URLError: {}'.format(e.reason))
            is_available = False
            break
        else:
            # fixed: the URL answered — stop retrying instead of issuing
            # up to six redundant requests on success
            break
    return is_available
|
||||
|
||||
|
||||
def retrieve_version(version: str, progress: bool = True) -> None:
    """Download specified CmdStan version.

    Supports both plain release versions (downloaded as a tarball from
    GitHub and unpacked into the current directory) and ``git:<ref>``
    specs (shallow-cloned from the stan-dev/cmdstan repository).

    :param version: release version string or ``git:<ref>`` spec.
    :param progress: when ``True``, show download/extraction progress bars.
    :raises ValueError: when *version* is empty.
    :raises CmdStanRetrieveError: when the download fails.
    :raises CmdStanInstallError: when the tarball cannot be unpacked.
    """
    if version is None or version == '':
        raise ValueError('Argument "version" unspecified.')

    if 'git:' in version:
        tag = version.split(':')[1]
        # make the ref filesystem-safe for use as a directory name
        tag_folder = version.replace(':', '-').replace('/', '_')
        print(f"Cloning CmdStan branch '{tag}' from stan-dev/cmdstan on GitHub")
        do_command(
            [
                'git',
                'clone',
                '--depth',
                '1',
                '--branch',
                tag,
                '--recursive',
                '--shallow-submodules',
                'https://github.com/stan-dev/cmdstan.git',
                f'cmdstan-{tag_folder}',
            ]
        )
        return

    print('Downloading CmdStan version {}'.format(version))
    url = get_download_url(version)
    for i in range(6):  # always retry to allow for transient URLErrors
        try:
            if progress and progbar.allow_show_progress():
                progress_hook: Optional[
                    Callable[[int, int, int], None]
                ] = wrap_url_progress_hook()
            else:
                progress_hook = None
            file_tmp, _ = urllib.request.urlretrieve(
                url, filename=None, reporthook=progress_hook
            )
            break
        except urllib.error.HTTPError as e:
            # hard HTTP failure: the release simply isn't there
            raise CmdStanRetrieveError(
                'HTTPError: {}\n'
                'Version {} not available from github.com.'.format(
                    e.code, version
                )
            ) from e
        except urllib.error.URLError as e:
            print(
                'Failed to download CmdStan version {} from github.com'.format(
                    version
                )
            )
            print(e)
            if i < 5:
                print('retry ({}/5)'.format(i + 1))
                sleep(1)
                continue
            print('Version {} not available from github.com.'.format(version))
            raise CmdStanRetrieveError(
                'Version {} not available from github.com.'.format(version)
            ) from e
    print('Download successful, file: {}'.format(file_tmp))
    try:
        print('Extracting distribution')
        # Fixed: use a context manager. The previous `tar.close()` in a
        # `finally` raised UnboundLocalError whenever tarfile.open itself
        # failed, masking the real error.
        with tarfile.open(file_tmp) as tar:
            first = tar.next()
            top_dir = first.name if first is not None else ''
            cmdstan_dir = f'cmdstan-{version}'
            if top_dir != cmdstan_dir:
                raise CmdStanInstallError(
                    'tarfile should contain top-level dir {},'
                    'but found dir {} instead.'.format(cmdstan_dir, top_dir)
                )
            # NOTE(review): members are extracted without a path filter;
            # the tarball comes from the official stan-dev release URL.
            if progress and progbar.allow_show_progress():
                for member in tqdm(
                    iterable=tar.getmembers(),
                    total=len(tar.getmembers()),
                    colour='blue',
                    leave=False,
                ):
                    tar.extract(member=member)
            else:
                tar.extractall()
    except Exception as e:  # pylint: disable=broad-except
        raise CmdStanInstallError(
            f'Failed to unpack file {file_tmp}, error:\n\t{str(e)}'
        ) from e
    print(f'Unpacked download as {cmdstan_dir}')
|
||||
|
||||
|
||||
def run_compiler_install(dir: str, verbose: bool, progress: bool) -> None:
    """Ensure an RTools C++ toolchain is available, installing if needed.

    Searches the usual RTools locations for an existing toolchain; when
    none is found, installs RTools40 into *dir*. The toolchain is then
    added to the process $PATH. Intended for Windows installs.

    :param dir: install directory for a freshly-installed toolchain.
    :param verbose: when ``True``, show installer output.
    :param progress: when ``True``, show a download progress bar.
    """
    # imported lazily so the toolchain machinery is only loaded on demand
    from .install_cxx_toolchain import is_installed as _is_installed_cxx
    from .install_cxx_toolchain import run_rtools_install as _main_cxx
    from .utils import cxx_toolchain_path

    compiler_found = False
    rtools40_home = os.environ.get('RTOOLS40_HOME')
    # candidate locations, $RTOOLS40_HOME (if set) first
    for cxx_loc in ([rtools40_home] if rtools40_home is not None else []) + [
        home_cmdstan(),
        os.path.join(os.path.abspath("/"), "RTools40"),
        os.path.join(os.path.abspath("/"), "RTools"),
        os.path.join(os.path.abspath("/"), "RTools35"),
        os.path.join(os.path.abspath("/"), "RBuildTools"),
    ]:
        # prefer RTools 4.0 over 3.5 at each location
        for cxx_version in ['40', '35']:
            if _is_installed_cxx(cxx_loc, cxx_version):
                compiler_found = True
                break
        if compiler_found:
            break
    if not compiler_found:
        print('Installing RTools40')
        # copy argv and clear sys.argv
        _main_cxx(
            {
                'dir': dir,
                'progress': progress,
                'version': None,
                'verbose': verbose,
            }
        )
        cxx_version = '40'
    # Add toolchain to $PATH
    cxx_toolchain_path(cxx_version, dir)
|
||||
|
||||
|
||||
def run_install(args: Union[InteractiveSettings, InstallationSettings]) -> None:
    """
    Run a (potentially interactive) installation

    :param args: settings object; with ``InteractiveSettings`` each
        attribute access prompts the user, so access order matters here.
    """
    validate_dir(args.dir)
    print('CmdStan install directory: {}'.format(args.dir))

    # these accesses just 'warm up' the interactive install
    _ = args.progress
    _ = args.verbose

    if args.compiler:
        run_compiler_install(args.dir, args.verbose, args.progress)

    # normalize a 'git:<ref>' spec into a filesystem-safe directory name
    if 'git:' in args.version:
        tag = args.version.replace(':', '-').replace('/', '_')
        cmdstan_version = f'cmdstan-{tag}'
    else:
        cmdstan_version = f'cmdstan-{args.version}'
    with pushd(args.dir):
        # a version counts as installed once its example binary exists
        already_installed = os.path.exists(cmdstan_version) and os.path.exists(
            os.path.join(
                cmdstan_version,
                'examples',
                'bernoulli',
                'bernoulli' + EXTENSION,
            )
        )
        if not already_installed or args.overwrite:
            if is_version_available(args.version):
                print('Installing CmdStan version: {}'.format(args.version))
            else:
                raise ValueError(
                    f'Version {args.version} cannot be downloaded. '
                    'Connection to GitHub failed. '
                    'Check firewall settings or ensure this version exists.'
                )
            # remove any partial install before re-downloading
            shutil.rmtree(cmdstan_version, ignore_errors=True)
            retrieve_version(args.version, args.progress)
            install_version(
                cmdstan_version=cmdstan_version,
                overwrite=already_installed and args.overwrite,
                verbose=args.verbose,
                progress=args.progress,
                cores=args.cores,
            )
        else:
            print('CmdStan version {} already installed'.format(args.version))

        # smoke-test the fresh (or pre-existing) install
        with pushd(cmdstan_version):
            print('Test model compilation')
            compile_example(args.verbose)
|
||||
|
||||
|
||||
def parse_cmdline_args() -> Dict[str, Any]:
    """Parse install_cmdstan's command-line flags into a plain dict.

    The resulting dict feeds ``InstallationSettings(**args)`` directly;
    argument order below determines --help output order.
    """
    parser = argparse.ArgumentParser("install_cmdstan")
    parser.add_argument(
        '--interactive',
        '-i',
        action='store_true',
        help="Ignore other arguments and run the installation in "
        + "interactive mode",
    )
    parser.add_argument(
        '--version',
        '-v',
        help="version, defaults to latest release version. "
        "If git is installed, you can also specify a git tag or branch, "
        "e.g. git:develop",
    )
    parser.add_argument(
        '--dir', '-d', help="install directory, defaults to '$HOME/.cmdstan"
    )
    parser.add_argument(
        '--overwrite',
        action='store_true',
        help="flag, when specified re-installs existing version",
    )
    parser.add_argument(
        '--verbose',
        action='store_true',
        help="flag, when specified prints output from CmdStan build process",
    )
    parser.add_argument(
        '--progress',
        action='store_true',
        help="flag, when specified show progress bar for CmdStan download",
    )
    parser.add_argument(
        "--cores",
        default=1,
        type=int,
        help="number of cores to use while building",
    )
    if is_windows():
        # use compiler installed with install_cxx_toolchain
        # Install a new compiler if compiler not found
        # Search order is RTools40, RTools35
        parser.add_argument(
            '--compiler',
            '-c',
            dest='compiler',
            action='store_true',
            help="flag, add C++ compiler to path (Windows only)",
        )
    return vars(parser.parse_args(sys.argv[1:]))
|
||||
|
||||
|
||||
def __main__() -> None:
    """Script entry point: parse CLI flags and run the installer."""
    parsed = parse_cmdline_args()
    if parsed.get('interactive', False):
        # interactive mode ignores all other command-line flags
        run_install(InteractiveSettings())
    else:
        run_install(InstallationSettings(**parsed))
|
||||
|
||||
|
||||
# Allow direct execution (`python install_cmdstan.py` or the console script).
if __name__ == '__main__':
    __main__()
|
||||
@ -0,0 +1,372 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Download and install a C++ toolchain.
|
||||
Currently implemented platforms (platform.system)
|
||||
Windows: RTools 3.5, 4.0 (default)
|
||||
Darwin (macOS): Not implemented
|
||||
Linux: Not implemented
|
||||
Optional command line arguments:
|
||||
-v, --version : version, defaults to latest
|
||||
-d, --dir : install directory, defaults to '~/.cmdstan
|
||||
-s (--silent) : install with /VERYSILENT instead of /SILENT for RTools
|
||||
-m --no-make : don't install mingw32-make (Windows RTools 4.0 only)
|
||||
--progress : flag, when specified show progress bar for RTools download
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import urllib.request
|
||||
from collections import OrderedDict
|
||||
from time import sleep
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from cmdstanpy import _DOT_CMDSTAN
|
||||
from cmdstanpy.utils import pushd, validate_dir, wrap_url_progress_hook
|
||||
|
||||
# Filename suffix of Windows executables ('' elsewhere).
EXTENSION = '.exe' if platform.system() == 'Windows' else ''
# True when running under a 64-bit Python interpreter.
IS_64BITS = sys.maxsize > 2**32
|
||||
|
||||
|
||||
def usage() -> None:
    """Print usage."""
    # keep this text in sync with the module docstring's argument list
    print(
        """Arguments:
        -v (--version) :CmdStan version
        -d (--dir) : install directory
        -s (--silent) : install with /VERYSILENT instead of /SILENT for RTools
        -m (--no-make) : don't install mingw32-make (Windows RTools 4.0 only)
        --progress : flag, when specified show progress bar for RTools download
        -h (--help) : this message
        """
    )
|
||||
|
||||
|
||||
def get_config(dir: str, silent: bool) -> List[str]:
    """Assemble installer command-line flags for the RTools setup binary.

    Returns an empty list on non-Windows platforms.
    """
    if platform.system() != 'Windows':
        return []
    # the installer wants a drive-relative path without a leading backslash
    _, install_dir = os.path.splitdrive(os.path.abspath(dir))
    if install_dir.startswith('\\'):
        install_dir = install_dir[1:]
    return [
        '/SP-',
        '/VERYSILENT' if silent else '/SILENT',
        '/SUPPRESSMSGBOXES',
        '/CURRENTUSER',
        'LANG="English"',
        '/DIR="{}"'.format(install_dir),
        '/NOICONS',
        '/NORESTART',
    ]
|
||||
|
||||
|
||||
def install_version(
    installation_dir: str,
    installation_file: str,
    version: str,
    silent: bool,
    verbose: bool = False,
) -> None:
    """Install specified toolchain version.

    Runs the downloaded RTools installer as a subprocess, streaming its
    output, and exits the process (code 3) on failure.

    :param installation_dir: target directory for the toolchain.
    :param installation_file: path to the downloaded installer binary.
    :param version: toolchain version string used for the install check.
    :param silent: when ``True``, pass /VERYSILENT instead of /SILENT.
    :param verbose: when ``True``, echo installer output.
    """
    with pushd('.'):
        print(
            'Installing the C++ toolchain: {}'.format(
                os.path.splitext(installation_file)[0]
            )
        )
        cmd = [installation_file]
        cmd.extend(get_config(installation_dir, silent))
        print(' '.join(cmd))
        proc = subprocess.Popen(
            cmd,
            cwd=None,
            stdin=subprocess.DEVNULL,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            env=os.environ,
        )
        # stream installer stdout while the process runs
        while proc.poll() is None:
            if proc.stdout:
                output = proc.stdout.readline().decode('utf-8').strip()
                if output and verbose:
                    print(output, flush=True)
        _, stderr = proc.communicate()
        if proc.returncode:
            print('Installation failed: returncode={}'.format(proc.returncode))
            if stderr:
                print(stderr.decode('utf-8').strip())
            # a non-zero exit may still have left usable files behind
            if is_installed(installation_dir, version):
                print('Installation files found at the installation location.')
            sys.exit(3)
        # check installation; remove the installer only on verified success
        if is_installed(installation_dir, version):
            os.remove(installation_file)
        print('Installed {}'.format(os.path.splitext(installation_file)[0]))
|
||||
|
||||
|
||||
def install_mingw32_make(toolchain_loc: str, verbose: bool = False) -> None:
    """Install mingw32-make for Windows RTools 4.0.

    Prepends the toolchain's bin directories to $PATH, then installs the
    make package via pacman; exits the process (code 3) on failure.

    :param toolchain_loc: RTools 4.0 installation directory.
    :param verbose: when ``True``, echo pacman output.
    """
    # put the toolchain's bin dirs first on $PATH, de-duplicating entries
    # while preserving order
    os.environ['PATH'] = ';'.join(
        list(
            OrderedDict.fromkeys(
                [
                    os.path.join(
                        toolchain_loc,
                        'mingw_64' if IS_64BITS else 'mingw_32',
                        'bin',
                    ),
                    os.path.join(toolchain_loc, 'usr', 'bin'),
                ]
                + os.environ.get('PATH', '').split(';')
            )
        )
    )
    cmd = [
        'pacman',
        '-Sy',
        'mingw-w64-x86_64-make' if IS_64BITS else 'mingw-w64-i686-make',
        '--noconfirm',
    ]
    with pushd('.'):
        print(' '.join(cmd))
        proc = subprocess.Popen(
            cmd,
            cwd=None,
            stdin=subprocess.DEVNULL,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            env=os.environ,
        )
        # stream pacman stdout while the process runs
        while proc.poll() is None:
            if proc.stdout:
                output = proc.stdout.readline().decode('utf-8').strip()
                if output and verbose:
                    print(output, flush=True)
        _, stderr = proc.communicate()
        if proc.returncode:
            print(
                'mingw32-make installation failed: returncode={}'.format(
                    proc.returncode
                )
            )
            if stderr:
                print(stderr.decode('utf-8').strip())
            sys.exit(3)
        print('Installed mingw32-make.exe')
|
||||
|
||||
|
||||
def is_installed(toolchain_loc: str, version: str) -> bool:
    """Returns True is toolchain is installed.

    Checks for the ``g++`` executable under the version-appropriate
    mingw subfolder; always False on non-Windows platforms.
    """
    if platform.system() != 'Windows':
        return False
    if version in ('35', '3.5'):
        # RTools 3.5 layout uses an underscore in the mingw folder name.
        if not os.path.exists(os.path.join(toolchain_loc, 'bin')):
            return False
        compiler = os.path.join(
            toolchain_loc,
            'mingw_64' if IS_64BITS else 'mingw_32',
            'bin',
            'g++' + EXTENSION,
        )
        return os.path.exists(compiler)
    if version in ('40', '4.0', '4'):
        # RTools 4.0 layout has no underscore in the mingw folder name.
        compiler = os.path.join(
            toolchain_loc,
            'mingw64' if IS_64BITS else 'mingw32',
            'bin',
            'g++' + EXTENSION,
        )
        return os.path.exists(compiler)
    # Unrecognized version string.
    return False
|
||||
|
||||
|
||||
def latest_version() -> str:
    """Windows version hardcoded to 4.0."""
    # Only Windows has a managed toolchain; elsewhere there is no version.
    return '4.0' if platform.system() == 'Windows' else ''
|
||||
|
||||
|
||||
def retrieve_toolchain(filename: str, url: str, progress: bool = True) -> None:
    """Download toolchain from URL.

    Retries up to five times on URL errors with a one-second pause
    between attempts; exits the process with code 3 if all attempts fail.

    :param filename: local destination path for the download
    :param url: toolchain download URL
    :param progress: when True, show a download progress bar
    """
    print('Downloading C++ toolchain: {}'.format(filename))
    max_retries = 5
    for attempt in range(max_retries + 1):
        try:
            hook = wrap_url_progress_hook() if progress else None
            urllib.request.urlretrieve(url, filename=filename, reporthook=hook)
            break
        except urllib.error.URLError as err:
            print('Failed to download C++ toolchain')
            print(err)
            if attempt < max_retries:
                print('retry ({}/{})'.format(attempt + 1, max_retries))
                sleep(1)
            else:
                # All attempts exhausted.
                sys.exit(3)
    print('Download successful, file: {}'.format(filename))
|
||||
|
||||
|
||||
def normalize_version(version: str) -> str:
    """Return maj.min part of version string."""
    if platform.system() != 'Windows':
        # Only Windows versions have shorthand aliases.
        return version
    aliases = {'4': '4.0', '40': '4.0', '35': '3.5'}
    return aliases.get(version, version)
|
||||
|
||||
|
||||
def get_toolchain_name() -> str:
    """Return toolchain name."""
    # RTools is the only supported toolchain, and only on Windows.
    return 'RTools' if platform.system() == 'Windows' else ''
|
||||
|
||||
|
||||
# TODO(2.0): drop 3.5 support
def get_url(version: str) -> str:
    """Return URL for toolchain.

    Empty string for non-Windows platforms or unknown versions.
    """
    if platform.system() != 'Windows':
        return ''
    if version == '4.0':
        # pylint: disable=line-too-long
        if IS_64BITS:
            return 'https://cran.r-project.org/bin/windows/Rtools/rtools40-x86_64.exe'  # noqa: disable=E501
        return 'https://cran.r-project.org/bin/windows/Rtools/rtools40-i686.exe'  # noqa: disable=E501
    if version == '3.5':
        return 'https://cran.r-project.org/bin/windows/Rtools/Rtools35.exe'
    # Unknown version: no download available.
    return ''
|
||||
|
||||
|
||||
def get_toolchain_version(name: str, version: str) -> str:
    """Toolchain version.

    Builds the install-folder name, e.g. ``RTools40``; empty off-Windows.
    """
    if platform.system() != 'Windows':
        return ''
    return '{}{}'.format(name, version.replace('.', ''))
|
||||
|
||||
|
||||
def run_rtools_install(args: Dict[str, Any]) -> None:
    """Install the RTools C++ toolchain (Windows only).

    :param args: dict of install options; recognized keys are 'version',
        'dir', 'verbose', 'silent', 'no_make', 'progress', as produced
        by ``parse_cmdline_args``.
    :raises NotImplementedError: on non-Windows platforms.
    """
    if platform.system() not in {'Windows'}:
        raise NotImplementedError(
            'Download for the C++ toolchain '
            'on the current platform has not '
            f'been implemented: {platform.system()}'
        )
    toolchain = get_toolchain_name()
    version = args['version']
    if version is None:
        version = latest_version()
    version = normalize_version(version)
    print("C++ toolchain '{}' version: {}".format(toolchain, version))

    url = get_url(version)

    verbose = args.get('verbose', False)

    install_dir = args['dir']
    if install_dir is None:
        install_dir = os.path.expanduser(os.path.join('~', _DOT_CMDSTAN))
    validate_dir(install_dir)
    print('Install directory: {}'.format(install_dir))

    progress = args.get('progress', False)

    if platform.system() == 'Windows':
        # BUGFIX: the previous check `'silent' in args` tested only key
        # presence; argparse always supplies the 'silent' key, so the
        # installer always ran silently regardless of the flag.  Use the
        # flag's value instead; the False default also covers the old
        # "force silent == False for 4.0" special case.
        silent = bool(args.get('silent', False))
    else:
        silent = False

    toolchain_folder = get_toolchain_version(toolchain, version)
    with pushd(install_dir):
        if is_installed(toolchain_folder, version):
            print('C++ toolchain {} already installed'.format(toolchain_folder))
        else:
            if os.path.exists(toolchain_folder):
                # Wipe any partial/corrupt prior install before re-installing.
                shutil.rmtree(toolchain_folder, ignore_errors=False)
            retrieve_toolchain(
                toolchain_folder + EXTENSION, url, progress=progress
            )
            install_version(
                toolchain_folder,
                toolchain_folder + EXTENSION,
                version,
                silent,
                verbose,
            )
        # BUGFIX: argparse stores `--no-make` under the key 'no_make'
        # (False when the flag is given, True otherwise), so the old test
        # `'no-make' not in args` was always True.  Honor the argparse
        # value, falling back to the legacy 'no-make' key for dict callers.
        install_make = args.get('no_make', 'no-make' not in args)
        if (
            install_make
            and (platform.system() == 'Windows')
            and (version in ('4.0', '4', '40'))
        ):
            if os.path.exists(
                os.path.join(
                    toolchain_folder, 'mingw64', 'bin', 'mingw32-make.exe'
                )
            ):
                print('mingw32-make.exe already installed')
            else:
                install_mingw32_make(toolchain_folder, verbose)
|
||||
|
||||
|
||||
def parse_cmdline_args() -> Dict[str, Any]:
    """Parse command-line arguments for the C++ toolchain installer.

    :return: dict of parsed arguments; keys follow argparse conventions,
        i.e. ``--no-make`` is stored under key ``no_make``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--version', '-v', help="version, defaults to latest")
    parser.add_argument(
        # BUGFIX: help text had an unbalanced quote ("'~/.cmdstan").
        '--dir', '-d', help="install directory, defaults to '~/.cmdstan'"
    )
    parser.add_argument(
        '--silent',
        '-s',
        action='store_true',
        help="install with /VERYSILENT instead of /SILENT for RTools",
    )
    parser.add_argument(
        '--no-make',
        '-m',
        action='store_false',
        help="don't install mingw32-make (Windows RTools 4.0 only)",
    )
    parser.add_argument(
        '--verbose',
        action='store_true',
        help="flag, when specified prints output from RTools build process",
    )
    parser.add_argument(
        '--progress',
        action='store_true',
        help="flag, when specified show progress bar for CmdStan download",
    )
    return vars(parser.parse_args(sys.argv[1:]))
|
||||
|
||||
|
||||
def __main__() -> None:
    """Entry point: parse CLI arguments and run the RTools installer."""
    run_rtools_install(parse_cmdline_args())


if __name__ == '__main__':
    __main__()
|
||||
2173
.venv/lib/python3.12/site-packages/cmdstanpy/model.py
Normal file
2173
.venv/lib/python3.12/site-packages/cmdstanpy/model.py
Normal file
File diff suppressed because it is too large
Load Diff
49
.venv/lib/python3.12/site-packages/cmdstanpy/progress.py
Normal file
49
.venv/lib/python3.12/site-packages/cmdstanpy/progress.py
Normal file
@ -0,0 +1,49 @@
|
||||
"""
|
||||
Record tqdm progress bar fail during session
|
||||
"""
|
||||
import functools
|
||||
import logging
|
||||
|
||||
_SHOW_PROGRESS: bool = True
|
||||
|
||||
|
||||
def allow_show_progress() -> bool:
    """Return False if any progressbar errors have occurred this session"""
    # Module-level flag; flipped to False by _disable_progress() on the
    # first progress-bar failure.
    return _SHOW_PROGRESS
|
||||
|
||||
|
||||
def _disable_progress(e: Exception) -> None:
    """Log an exception and disable progress bars for this session.

    Only the first failure is logged; subsequent calls are no-ops.
    """
    # pylint: disable=global-statement
    global _SHOW_PROGRESS
    if not _SHOW_PROGRESS:
        # Already disabled; stay silent on repeat failures.
        return
    _SHOW_PROGRESS = False
    logging.getLogger('cmdstanpy').error(
        'Error in progress bar initialization:\n'
        '\t%s\n'
        'Disabling progress bars for this session',
        str(e),
    )
|
||||
|
||||
|
||||
def wrap_callback(func):  # type: ignore
    """Wrap a callback generator so it fails safely.

    The wrapped function returns whatever ``func`` returns; if progress
    bars are disabled or ``func`` raises, a do-nothing callback is
    returned instead and progress is disabled for the session.
    """

    # pylint: disable=unused-argument
    def _noop(*args, **kwargs):  # type: ignore
        # Stand-in callback: accepts anything, does nothing.
        return None

    @functools.wraps(func)
    def safe_progress(*args, **kwargs):  # type: ignore
        if not allow_show_progress():
            return _noop
        try:
            return func(*args, **kwargs)
        # pylint: disable=broad-except
        except Exception as e:
            _disable_progress(e)
            return _noop

    return safe_progress
|
||||
269
.venv/lib/python3.12/site-packages/cmdstanpy/stanfit/__init__.py
Normal file
269
.venv/lib/python3.12/site-packages/cmdstanpy/stanfit/__init__.py
Normal file
@ -0,0 +1,269 @@
|
||||
"""Container objects for results of CmdStan run(s)."""
|
||||
|
||||
import glob
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from cmdstanpy.cmdstan_args import (
|
||||
CmdStanArgs,
|
||||
LaplaceArgs,
|
||||
OptimizeArgs,
|
||||
PathfinderArgs,
|
||||
SamplerArgs,
|
||||
VariationalArgs,
|
||||
)
|
||||
from cmdstanpy.utils import check_sampler_csv, get_logger, scan_config
|
||||
|
||||
from .gq import CmdStanGQ
|
||||
from .laplace import CmdStanLaplace
|
||||
from .mcmc import CmdStanMCMC
|
||||
from .metadata import InferenceMetadata
|
||||
from .mle import CmdStanMLE
|
||||
from .pathfinder import CmdStanPathfinder
|
||||
from .runset import RunSet
|
||||
from .vb import CmdStanVB
|
||||
|
||||
__all__ = [
|
||||
"RunSet",
|
||||
"InferenceMetadata",
|
||||
"CmdStanMCMC",
|
||||
"CmdStanMLE",
|
||||
"CmdStanVB",
|
||||
"CmdStanGQ",
|
||||
"CmdStanLaplace",
|
||||
"CmdStanPathfinder",
|
||||
]
|
||||
|
||||
|
||||
def from_csv(
    path: Union[str, List[str], os.PathLike, None] = None,
    method: Optional[str] = None,
) -> Union[
    CmdStanMCMC, CmdStanMLE, CmdStanVB, CmdStanPathfinder, CmdStanLaplace, None
]:
    """
    Instantiate a CmdStan object from a the Stan CSV files from a CmdStan run.
    CSV files are specified from either a list of Stan CSV files or a single
    filepath which can be either a directory name, a Stan CSV filename, or
    a pathname pattern (i.e., a Python glob). The optional argument 'method'
    checks that the CSV files were produced by that method.
    Stan CSV files from CmdStan methods 'sample', 'optimize', 'variational',
    'laplace', and 'pathfinder' result in objects of class CmdStanMCMC,
    CmdStanMLE, CmdStanVB, CmdStanLaplace, and CmdStanPathfinder,
    respectively.

    :param path: directory path
    :param method: method name (optional)

    :return: a CmdStanMCMC, CmdStanMLE, CmdStanVB, CmdStanLaplace, or
        CmdStanPathfinder object, or None for an unrecognized method
    """
    if path is None:
        raise ValueError('Must specify path to Stan CSV files.')
    if method is not None and method not in [
        'sample',
        'optimize',
        'variational',
        'laplace',
        'pathfinder',
    ]:
        # BUGFIX: error message previously omitted "laplace" and
        # "pathfinder", though both are accepted above.
        raise ValueError(
            'Bad method argument {}, must be one of: '
            '"sample", "optimize", "variational", '
            '"laplace", "pathfinder"'.format(method)
        )

    # Resolve 'path' into a concrete list of CSV file paths.
    csvfiles = []
    if isinstance(path, list):
        csvfiles = path
    elif isinstance(path, str) and '*' in path:
        # Glob pattern: validate the directory part before expanding.
        splits = os.path.split(path)
        if splits[0] is not None:
            if not (os.path.exists(splits[0]) and os.path.isdir(splits[0])):
                raise ValueError(
                    'Invalid path specification, {} '
                    ' unknown directory: {}'.format(path, splits[0])
                )
        csvfiles = glob.glob(path)
    elif isinstance(path, (str, os.PathLike)):
        if os.path.exists(path) and os.path.isdir(path):
            # Directory: collect every .csv file inside it.
            for file in os.listdir(path):
                if os.path.splitext(file)[1] == ".csv":
                    csvfiles.append(os.path.join(path, file))
        elif os.path.exists(path):
            csvfiles.append(str(path))
        else:
            raise ValueError('Invalid path specification: {}'.format(path))
    else:
        raise ValueError('Invalid path specification: {}'.format(path))

    if len(csvfiles) == 0:
        raise ValueError('No CSV files found in directory {}'.format(path))
    for file in csvfiles:
        if not (os.path.exists(file) and os.path.splitext(file)[1] == ".csv"):
            raise ValueError(
                'Bad CSV file path spec,'
                ' includes non-csv file: {}'.format(file)
            )

    # Read the CmdStan config header from the first CSV file; it
    # determines which fit class to construct.
    config_dict: Dict[str, Any] = {}
    try:
        with open(csvfiles[0], 'r') as fd:
            scan_config(fd, config_dict, 0)
    except (IOError, OSError, PermissionError) as e:
        raise ValueError('Cannot read CSV file: {}'.format(csvfiles[0])) from e
    if 'model' not in config_dict or 'method' not in config_dict:
        raise ValueError("File {} is not a Stan CSV file.".format(csvfiles[0]))
    if method is not None and method != config_dict['method']:
        raise ValueError(
            'Expecting Stan CSV output files from method {}, '
            ' found outputs from method {}'.format(
                method, config_dict['method']
            )
        )
    try:
        if config_dict['method'] == 'sample':
            chains = len(csvfiles)
            sampler_args = SamplerArgs(
                iter_sampling=config_dict['num_samples'],
                iter_warmup=config_dict['num_warmup'],
                thin=config_dict['thin'],
                save_warmup=config_dict['save_warmup'],
            )
            # bugfix 425, check for fixed_params output
            try:
                check_sampler_csv(
                    csvfiles[0],
                    iter_sampling=config_dict['num_samples'],
                    iter_warmup=config_dict['num_warmup'],
                    thin=config_dict['thin'],
                    save_warmup=config_dict['save_warmup'],
                )
            except ValueError:
                # Retry as fixed-param output before giving up.
                try:
                    check_sampler_csv(
                        csvfiles[0],
                        is_fixed_param=True,
                        iter_sampling=config_dict['num_samples'],
                        iter_warmup=config_dict['num_warmup'],
                        thin=config_dict['thin'],
                        save_warmup=config_dict['save_warmup'],
                    )
                    sampler_args = SamplerArgs(
                        iter_sampling=config_dict['num_samples'],
                        iter_warmup=config_dict['num_warmup'],
                        thin=config_dict['thin'],
                        save_warmup=config_dict['save_warmup'],
                        fixed_param=True,
                    )
                except ValueError as e:
                    raise ValueError(
                        'Invalid or corrupt Stan CSV output file, '
                    ) from e

            cmdstan_args = CmdStanArgs(
                model_name=config_dict['model'],
                model_exe=config_dict['model'],
                chain_ids=[x + 1 for x in range(chains)],
                method_args=sampler_args,
            )
            runset = RunSet(args=cmdstan_args, chains=chains)
            runset._csv_files = csvfiles
            # Mark every chain as successful; these are completed runs.
            for i in range(len(runset._retcodes)):
                runset._set_retcode(i, 0)
            fit = CmdStanMCMC(runset)
            fit.draws()
            return fit
        elif config_dict['method'] == 'optimize':
            if 'algorithm' not in config_dict:
                raise ValueError(
                    "Cannot find optimization algorithm"
                    " in file {}.".format(csvfiles[0])
                )
            optimize_args = OptimizeArgs(
                algorithm=config_dict['algorithm'],
                save_iterations=config_dict['save_iterations'],
                jacobian=config_dict.get('jacobian', 0),
            )
            cmdstan_args = CmdStanArgs(
                model_name=config_dict['model'],
                model_exe=config_dict['model'],
                chain_ids=None,
                method_args=optimize_args,
            )
            runset = RunSet(args=cmdstan_args)
            runset._csv_files = csvfiles
            for i in range(len(runset._retcodes)):
                runset._set_retcode(i, 0)
            return CmdStanMLE(runset)
        elif config_dict['method'] == 'variational':
            if 'algorithm' not in config_dict:
                raise ValueError(
                    "Cannot find variational algorithm"
                    " in file {}.".format(csvfiles[0])
                )
            variational_args = VariationalArgs(
                algorithm=config_dict['algorithm'],
                iter=config_dict['iter'],
                grad_samples=config_dict['grad_samples'],
                elbo_samples=config_dict['elbo_samples'],
                eta=config_dict['eta'],
                tol_rel_obj=config_dict['tol_rel_obj'],
                eval_elbo=config_dict['eval_elbo'],
                output_samples=config_dict['output_samples'],
            )
            cmdstan_args = CmdStanArgs(
                model_name=config_dict['model'],
                model_exe=config_dict['model'],
                chain_ids=None,
                method_args=variational_args,
            )
            runset = RunSet(args=cmdstan_args)
            runset._csv_files = csvfiles
            for i in range(len(runset._retcodes)):
                runset._set_retcode(i, 0)
            return CmdStanVB(runset)
        elif config_dict['method'] == 'laplace':
            laplace_args = LaplaceArgs(
                mode=config_dict['mode'],
                draws=config_dict['draws'],
                jacobian=config_dict['jacobian'],
            )
            cmdstan_args = CmdStanArgs(
                model_name=config_dict['model'],
                model_exe=config_dict['model'],
                chain_ids=None,
                method_args=laplace_args,
            )
            runset = RunSet(args=cmdstan_args)
            runset._csv_files = csvfiles
            for i in range(len(runset._retcodes)):
                runset._set_retcode(i, 0)
            # Recursively load the optimization mode that seeded Laplace.
            mode: CmdStanMLE = from_csv(
                config_dict['mode'],
                method='optimize',
            )  # type: ignore
            return CmdStanLaplace(runset, mode=mode)
        elif config_dict['method'] == 'pathfinder':
            pathfinder_args = PathfinderArgs(
                num_draws=config_dict['num_draws'],
                num_paths=config_dict['num_paths'],
            )
            cmdstan_args = CmdStanArgs(
                model_name=config_dict['model'],
                model_exe=config_dict['model'],
                chain_ids=None,
                method_args=pathfinder_args,
            )
            runset = RunSet(args=cmdstan_args)
            runset._csv_files = csvfiles
            for i in range(len(runset._retcodes)):
                runset._set_retcode(i, 0)
            return CmdStanPathfinder(runset)
        else:
            get_logger().info(
                'Unable to process CSV output files from method %s.',
                (config_dict['method']),
            )
            return None
    except (IOError, OSError, PermissionError) as e:
        raise ValueError(
            'An error occurred processing the CSV files:\n\t{}'.format(str(e))
        ) from e
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
734
.venv/lib/python3.12/site-packages/cmdstanpy/stanfit/gq.py
Normal file
734
.venv/lib/python3.12/site-packages/cmdstanpy/stanfit/gq.py
Normal file
@ -0,0 +1,734 @@
|
||||
"""
|
||||
Container for the result of running the
|
||||
generate quantities (GQ) method
|
||||
"""
|
||||
|
||||
from collections import Counter
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Generic,
|
||||
Hashable,
|
||||
List,
|
||||
MutableMapping,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
try:
|
||||
import xarray as xr
|
||||
|
||||
XARRAY_INSTALLED = True
|
||||
except ImportError:
|
||||
XARRAY_INSTALLED = False
|
||||
|
||||
|
||||
from cmdstanpy.cmdstan_args import Method
|
||||
from cmdstanpy.utils import build_xarray_data, flatten_chains, get_logger
|
||||
from cmdstanpy.utils.stancsv import scan_generic_csv
|
||||
|
||||
from .mcmc import CmdStanMCMC
|
||||
from .metadata import InferenceMetadata
|
||||
from .mle import CmdStanMLE
|
||||
from .runset import RunSet
|
||||
from .vb import CmdStanVB
|
||||
|
||||
Fit = TypeVar('Fit', CmdStanMCMC, CmdStanMLE, CmdStanVB)
|
||||
|
||||
|
||||
class CmdStanGQ(Generic[Fit]):
|
||||
"""
|
||||
Container for outputs from CmdStan generate_quantities run.
|
||||
Created by :meth:`CmdStanModel.generate_quantities`.
|
||||
"""
|
||||
|
||||
    def __init__(
        self,
        runset: RunSet,
        previous_fit: Fit,
    ) -> None:
        """Initialize object from a generate_quantities RunSet.

        :param runset: completed generate_quantities run
        :param previous_fit: the fit (MCMC, MLE, or VB) whose draws
            were the inputs to the generate_quantities run
        :raises ValueError: if ``runset`` was not produced by the
            generate_quantities method
        """
        if not runset.method == Method.GENERATE_QUANTITIES:
            raise ValueError(
                'Wrong runset method, expecting generate_quantities runset, '
                'found method {}'.format(runset.method)
            )
        self.runset = runset

        self.previous_fit: Fit = previous_fit

        # Draws are loaded lazily; empty until CSV files are assembled.
        self._draws: np.ndarray = np.array(())
        # Validating the CSV files also yields the config used as metadata.
        config = self._validate_csv_files()
        self._metadata = InferenceMetadata(config)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
repr = 'CmdStanGQ: model={} chains={}{}'.format(
|
||||
self.runset.model,
|
||||
self.chains,
|
||||
self.runset._args.method_args.compose(0, cmd=[]),
|
||||
)
|
||||
repr = '{}\n csv_files:\n\t{}\n output_files:\n\t{}'.format(
|
||||
repr,
|
||||
'\n\t'.join(self.runset.csv_files),
|
||||
'\n\t'.join(self.runset.stdout_files),
|
||||
)
|
||||
return repr
|
||||
|
||||
def __getattr__(self, attr: str) -> np.ndarray:
|
||||
"""Synonymous with ``fit.stan_variable(attr)"""
|
||||
if attr.startswith("_"):
|
||||
raise AttributeError(f"Unknown variable name {attr}")
|
||||
try:
|
||||
return self.stan_variable(attr)
|
||||
except ValueError as e:
|
||||
# pylint: disable=raise-missing-from
|
||||
raise AttributeError(*e.args)
|
||||
|
||||
    def __getstate__(self) -> dict:
        """Return the picklable state of this object."""
        # This function returns the mapping of objects to serialize with pickle.
        # See https://docs.python.org/3/library/pickle.html#object.__getstate__
        # for details. We call _assemble_generated_quantities to ensure
        # the data are loaded prior to serialization.
        self._assemble_generated_quantities()
        return self.__dict__
|
||||
|
||||
def _validate_csv_files(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Checks that Stan CSV output files for all chains are consistent
|
||||
and returns dict containing config and column names.
|
||||
|
||||
Raises exception when inconsistencies detected.
|
||||
"""
|
||||
dzero = {}
|
||||
for i in range(self.chains):
|
||||
if i == 0:
|
||||
dzero = scan_generic_csv(
|
||||
path=self.runset.csv_files[i],
|
||||
)
|
||||
else:
|
||||
drest = scan_generic_csv(
|
||||
path=self.runset.csv_files[i],
|
||||
)
|
||||
for key in dzero:
|
||||
if (
|
||||
key
|
||||
not in [
|
||||
'id',
|
||||
'fitted_params',
|
||||
'diagnostic_file',
|
||||
'metric_file',
|
||||
'profile_file',
|
||||
'init',
|
||||
'seed',
|
||||
'start_datetime',
|
||||
]
|
||||
and dzero[key] != drest[key]
|
||||
):
|
||||
raise ValueError(
|
||||
'CmdStan config mismatch in Stan CSV file {}: '
|
||||
'arg {} is {}, expected {}'.format(
|
||||
self.runset.csv_files[i],
|
||||
key,
|
||||
dzero[key],
|
||||
drest[key],
|
||||
)
|
||||
)
|
||||
return dzero
|
||||
|
||||
    @property
    def chains(self) -> int:
        """Number of chains."""
        return self.runset.chains

    @property
    def chain_ids(self) -> List[int]:
        """Chain ids."""
        return self.runset.chain_ids

    @property
    def column_names(self) -> Tuple[str, ...]:
        """
        Names of generated quantities of interest.
        """
        # Column names come straight from the parsed CSV header config.
        return self._metadata.cmdstan_config['column_names']  # type: ignore

    @property
    def metadata(self) -> InferenceMetadata:
        """
        Returns object which contains CmdStan configuration as well as
        information about the names and structure of the inference method
        and model output variables.
        """
        return self._metadata
|
||||
|
||||
    def draws(
        self,
        *,
        inc_warmup: bool = False,
        inc_iterations: bool = False,
        concat_chains: bool = False,
        inc_sample: bool = False,
    ) -> np.ndarray:
        """
        Returns a numpy.ndarray over the generated quantities draws from
        all chains which is stored column major so that the values
        for a parameter are contiguous in memory, likewise all draws from
        a chain are contiguous. By default, returns a 3D array arranged
        (draws, chains, columns); parameter ``concat_chains=True`` will
        return a 2D array where all chains are flattened into a single column,
        preserving chain order, so that given M chains of N draws,
        the first N draws are from chain 1, ..., and the the last N draws
        are from chain M.

        :param inc_warmup: When ``True`` and the warmup draws are present in
            the output, i.e., the sampler was run with ``save_warmup=True``,
            then the warmup draws are included. Default value is ``False``.

        :param inc_iterations: When ``True``, treated the same as
            ``inc_warmup=True`` (the two flags are OR-ed together).
            Default value is ``False``.

        :param concat_chains: When ``True`` return a 2D array flattening all
            all draws from all chains. Default value is ``False``.

        :param inc_sample: When ``True`` include all columns in the previous_fit
            draws array as well, excepting columns for variables already present
            in the generated quantities drawset. Default value is ``False``.

        See Also
        --------
        CmdStanGQ.draws_pd
        CmdStanGQ.draws_xr
        CmdStanMCMC.draws
        """
        # Ensure CSV data are loaded into self._draws.
        self._assemble_generated_quantities()
        inc_warmup |= inc_iterations
        if inc_warmup:
            # Warn (don't fail) when the requested warmup/iteration draws
            # were never saved by the previous fit.
            if (
                isinstance(self.previous_fit, CmdStanMCMC)
                and not self.previous_fit._save_warmup
            ):
                get_logger().warning(
                    "Sample doesn't contain draws from warmup iterations,"
                    ' rerun sampler with "save_warmup=True".'
                )
            elif (
                isinstance(self.previous_fit, CmdStanMLE)
                and not self.previous_fit._save_iterations
            ):
                get_logger().warning(
                    "MLE doesn't contain draws from pre-convergence iterations,"
                    ' rerun optimization with "save_iterations=True".'
                )
            elif isinstance(self.previous_fit, CmdStanVB):
                get_logger().warning(
                    "Variational fit doesn't make sense with argument "
                    '"inc_warmup=True"'
                )

        if inc_sample:
            # Columns present in both the previous fit and the GQ output
            # are dropped from the previous fit's side of the stack.
            cols_1 = self.previous_fit.column_names
            cols_2 = self.column_names
            dups = [
                item
                for item, count in Counter(cols_1 + cols_2).items()
                if count > 1
            ]
            drop_cols: List[int] = []
            for dup in dups:
                drop_cols.extend(
                    self.previous_fit._metadata.stan_vars[dup].columns()
                )

        # start_idx skips warmup rows when they are excluded.
        start_idx, _ = self._draws_start(inc_warmup)
        previous_draws = self._previous_draws(True)
        if concat_chains and inc_sample:
            return flatten_chains(
                np.dstack(
                    (
                        np.delete(previous_draws, drop_cols, axis=1),
                        self._draws,
                    )
                )[start_idx:, :, :]
            )
        if concat_chains:
            return flatten_chains(self._draws[start_idx:, :, :])
        if inc_sample:
            return np.dstack(
                (
                    np.delete(previous_draws, drop_cols, axis=1),
                    self._draws,
                )
            )[start_idx:, :, :]
        return self._draws[start_idx:, :, :]
|
||||
|
||||
    def draws_pd(
        self,
        vars: Union[List[str], str, None] = None,
        inc_warmup: bool = False,
        inc_sample: bool = False,
    ) -> pd.DataFrame:
        """
        Returns the generated quantities draws as a pandas DataFrame.
        Flattens all chains into single column. Container variables
        (array, vector, matrix) will span multiple columns, one column
        per element. E.g. variable 'matrix[2,2] foo' spans 4 columns:
        'foo[1,1], ... foo[2,2]'.

        :param vars: optional list of variable names.

        :param inc_warmup: When ``True`` and the warmup draws are present in
            the output, i.e., the sampler was run with ``save_warmup=True``,
            then the warmup draws are included. Default value is ``False``.

        :param inc_sample: When ``True``, include columns from the
            previous fit's draws as well. Default value is ``False``.

        See Also
        --------
        CmdStanGQ.draws
        CmdStanGQ.draws_xr
        CmdStanMCMC.draws_pd
        """
        if vars is not None:
            if isinstance(vars, str):
                vars_list = [vars]
            else:
                vars_list = vars

            # De-duplicate while preserving the caller's ordering.
            vars_list = list(dict.fromkeys(vars_list))

        if inc_warmup:
            # Warn (don't fail) when warmup/iteration draws were not saved.
            if (
                isinstance(self.previous_fit, CmdStanMCMC)
                and not self.previous_fit._save_warmup
            ):
                get_logger().warning(
                    "Sample doesn't contain draws from warmup iterations,"
                    ' rerun sampler with "save_warmup=True".'
                )
            elif (
                isinstance(self.previous_fit, CmdStanMLE)
                and not self.previous_fit._save_iterations
            ):
                get_logger().warning(
                    "MLE doesn't contain draws from pre-convergence iterations,"
                    ' rerun optimization with "save_iterations=True".'
                )
            elif isinstance(self.previous_fit, CmdStanVB):
                get_logger().warning(
                    "Variational fit doesn't make sense with argument "
                    '"inc_warmup=True"'
                )

        self._assemble_generated_quantities()

        all_columns = ['chain__', 'iter__', 'draw__'] + list(self.column_names)

        # Partition requested vars into GQ columns vs previous-fit columns.
        gq_cols: List[str] = []
        mcmc_vars: List[str] = []
        if vars is not None:
            for var in vars_list:
                if var in self._metadata.stan_vars:
                    info = self._metadata.stan_vars[var]
                    gq_cols.extend(
                        self.column_names[info.start_idx : info.end_idx]
                    )
                elif (
                    inc_sample and var in self.previous_fit._metadata.stan_vars
                ):
                    info = self.previous_fit._metadata.stan_vars[var]
                    mcmc_vars.extend(
                        self.previous_fit.column_names[
                            info.start_idx : info.end_idx
                        ]
                    )
                elif var in ['chain__', 'iter__', 'draw__']:
                    gq_cols.append(var)
                else:
                    raise ValueError('Unknown variable: {}'.format(var))
        else:
            gq_cols = all_columns
            vars_list = gq_cols

        previous_draws_pd = self._previous_draws_pd(mcmc_vars, inc_warmup)

        draws = self.draws(inc_warmup=inc_warmup)
        # add long-form columns for chain, iteration, draw
        n_draws, n_chains, _ = draws.shape
        chains_col = (
            np.repeat(np.arange(1, n_chains + 1), n_draws)
            .reshape(1, n_chains, n_draws)
            .T
        )
        iter_col = (
            np.tile(np.arange(1, n_draws + 1), n_chains)
            .reshape(1, n_chains, n_draws)
            .T
        )
        draw_col = (
            np.arange(1, (n_draws * n_chains) + 1)
            .reshape(1, n_chains, n_draws)
            .T
        )
        draws = np.concatenate([chains_col, iter_col, draw_col, draws], axis=2)

        draws_pd = pd.DataFrame(
            data=flatten_chains(draws),
            columns=all_columns,
        )

        if inc_sample and mcmc_vars:
            if gq_cols:
                # Merge previous-fit and GQ columns, keeping request order.
                return pd.concat(
                    [
                        previous_draws_pd,
                        draws_pd[gq_cols],
                    ],
                    axis='columns',
                )[vars_list]
            else:
                return previous_draws_pd
        elif inc_sample and vars is None:
            # Full merge: drop duplicated columns from the previous fit.
            cols_1 = list(previous_draws_pd.columns)
            cols_2 = list(draws_pd.columns)
            dups = [
                item
                for item, count in Counter(cols_1 + cols_2).items()
                if count > 1
            ]
            return pd.concat(
                [
                    previous_draws_pd.drop(columns=dups).reset_index(drop=True),
                    draws_pd,
                ],
                axis=1,
            )
        elif gq_cols:
            return draws_pd[gq_cols]

        return draws_pd
|
||||
|
||||
    @overload
    def draws_xr(
        self: Union["CmdStanGQ[CmdStanMLE]", "CmdStanGQ[CmdStanVB]"],
        vars: Union[str, List[str], None] = None,
        inc_warmup: bool = False,
        inc_sample: bool = False,
    ) -> NoReturn:
        ...

    @overload
    def draws_xr(
        self: "CmdStanGQ[CmdStanMCMC]",
        vars: Union[str, List[str], None] = None,
        inc_warmup: bool = False,
        inc_sample: bool = False,
    ) -> "xr.Dataset":
        ...

    def draws_xr(
        self,
        vars: Union[str, List[str], None] = None,
        inc_warmup: bool = False,
        inc_sample: bool = False,
    ) -> "xr.Dataset":
        """
        Returns the generated quantities draws as a xarray Dataset.

        This method can only be called when the underlying fit was made
        through sampling, it cannot be used on MLE or VB outputs.

        :param vars: optional list of variable names.

        :param inc_warmup: When ``True`` and the warmup draws are present in
            the MCMC sample, then the warmup draws are included.
            Default value is ``False``.

        :param inc_sample: When ``True``, also include variables from the
            previous fit. Default value is ``False``.

        See Also
        --------
        CmdStanGQ.draws
        CmdStanGQ.draws_pd
        CmdStanMCMC.draws_xr
        """
        if not XARRAY_INSTALLED:
            raise RuntimeError(
                'Package "xarray" is not installed, cannot produce draws array.'
            )
        if not isinstance(self.previous_fit, CmdStanMCMC):
            raise RuntimeError(
                'Method "draws_xr" is only available when '
                'original fit is done via Sampling.'
            )
        # Variables that come from the previous (MCMC) fit rather than GQ;
        # dup_vars tracks requested names found only in the previous fit.
        mcmc_vars_list = []
        dup_vars = []
        if vars is not None:
            if isinstance(vars, str):
                vars_list = [vars]
            else:
                vars_list = vars
            for var in vars_list:
                if var not in self._metadata.stan_vars:
                    if inc_sample and (
                        var in self.previous_fit._metadata.stan_vars
                    ):
                        mcmc_vars_list.append(var)
                        dup_vars.append(var)
                    else:
                        raise ValueError('Unknown variable: {}'.format(var))
        else:
            vars_list = list(self._metadata.stan_vars.keys())
            if inc_sample:
                for var in self.previous_fit._metadata.stan_vars.keys():
                    if var not in vars_list and var not in mcmc_vars_list:
                        mcmc_vars_list.append(var)
        # Previous-fit-only vars are handled via mcmc_vars_list, not GQ.
        for var in dup_vars:
            vars_list.remove(var)

        self._assemble_generated_quantities()

        num_draws = self.previous_fit.num_draws_sampling
        sample_config = self.previous_fit._metadata.cmdstan_config
        attrs: MutableMapping[Hashable, Any] = {
            "stan_version": f"{sample_config['stan_version_major']}."
            f"{sample_config['stan_version_minor']}."
            f"{sample_config['stan_version_patch']}",
            "model": sample_config["model"],
            "num_draws_sampling": num_draws,
        }
        if inc_warmup and sample_config['save_warmup']:
            # Extend the draw coordinate to cover saved warmup iterations.
            num_draws += self.previous_fit.num_draws_warmup
            attrs["num_draws_warmup"] = self.previous_fit.num_draws_warmup

        data: MutableMapping[Hashable, Any] = {}
        coordinates: MutableMapping[Hashable, Any] = {
            "chain": self.chain_ids,
            "draw": np.arange(num_draws),
        }

        for var in vars_list:
            build_xarray_data(
                data,
                self._metadata.stan_vars[var],
                self.draws(inc_warmup=inc_warmup),
            )
        if inc_sample:
            for var in mcmc_vars_list:
                build_xarray_data(
                    data,
                    self.previous_fit._metadata.stan_vars[var],
                    self.previous_fit.draws(inc_warmup=inc_warmup),
                )

        return xr.Dataset(data, coords=coordinates, attrs=attrs).transpose(
            'chain', 'draw', ...
        )
|
||||
|
||||
def stan_variable(self, var: str, **kwargs: bool) -> np.ndarray:
    """
    Return a numpy.ndarray which contains the set of draws
    for the named Stan program variable. Flattens the chains,
    leaving the draws in chain order. The first array dimension,
    corresponds to number of draws in the sample.
    The remaining dimensions correspond to
    the shape of the Stan program variable.

    Underlyingly draws are in chain order, i.e., for a sample with
    N chains of M draws each, the first M array elements are from chain 1,
    the next M are from chain 2, and the last M elements are from chain N.

    * If the variable is a scalar variable, the return array has shape
      ( draws * chains, 1).
    * If the variable is a vector, the return array has shape
      ( draws * chains, len(vector))
    * If the variable is a matrix, the return array has shape
      ( draws * chains, size(dim 1), size(dim 2) )
    * If the variable is an array with N dimensions, the return array
      has shape ( draws * chains, size(dim 1), ..., size(dim N))

    For example, if the Stan program variable ``theta`` is a 3x3 matrix,
    and the sample consists of 4 chains with 1000 post-warmup draws,
    this function will return a numpy.ndarray with shape (4000,3,3).

    This functionaltiy is also available via a shortcut using ``.`` -
    writing ``fit.a`` is a synonym for ``fit.stan_variable("a")``

    :param var: variable name

    :param kwargs: Additional keyword arguments are passed to the underlying
        fit's ``stan_variable`` method if the variable is not a generated
        quantity.

    See Also
    --------
    CmdStanGQ.stan_variables
    CmdStanMCMC.stan_variable
    CmdStanMLE.stan_variable
    CmdStanPathfinder.stan_variable
    CmdStanVB.stan_variable
    CmdStanLaplace.stan_variable
    """
    model_var_names = self.previous_fit._metadata.stan_vars.keys()
    gq_var_names = self._metadata.stan_vars.keys()
    if not (var in model_var_names or var in gq_var_names):
        raise ValueError(
            f'Unknown variable name: {var}\n'
            'Available variables are '
            + ", ".join(model_var_names | gq_var_names)
        )
    if var not in gq_var_names:
        # variable lives in the previous fit only; delegate to it
        # TODO(2.0) atleast1d may not be needed
        return np.atleast_1d(  # type: ignore
            self.previous_fit.stan_variable(var, **kwargs)
        )

    # is gq variable
    self._assemble_generated_quantities()

    # draw1: index of the first draw to keep (skips warmup/iterations
    # unless the caller asked for them via kwargs)
    draw1, _ = self._draws_start(
        inc_warmup=kwargs.get('inc_warmup', False)
        or kwargs.get('inc_iterations', False)
    )
    draws = flatten_chains(self._draws[draw1:])
    out: np.ndarray = self._metadata.stan_vars[var].extract_reshape(draws)
    return out
|
||||
|
||||
def stan_variables(self, **kwargs: bool) -> Dict[str, np.ndarray]:
    """
    Return a dictionary mapping Stan program variables names
    to the corresponding numpy.ndarray containing the inferred values.
    Generated-quantity variables come first; variables that exist only
    in the previous fit follow.

    :param kwargs: Additional keyword arguments are passed to the underlying
        fit's ``stan_variable`` method if the variable is not a generated
        quantity.

    See Also
    --------
    CmdStanGQ.stan_variable
    CmdStanMCMC.stan_variables
    CmdStanMLE.stan_variables
    CmdStanPathfinder.stan_variables
    CmdStanVB.stan_variables
    CmdStanLaplace.stan_variables
    """
    gq_names = self._metadata.stan_vars.keys()
    # sample-only variables: present upstream but not re-generated here
    upstream_only = [
        name
        for name in self.previous_fit._metadata.stan_vars.keys()
        if name not in gq_names
    ]
    return {
        name: self.stan_variable(name, **kwargs)
        for name in list(gq_names) + upstream_only
    }
|
||||
|
||||
def _assemble_generated_quantities(self) -> None:
    """Lazily parse the per-chain Stan CSV files into ``self._draws``.

    No-op if the draws array has already been populated.
    """
    if self._draws.shape != (0,):
        return
    # use numpy loadtxt
    _, num_draws = self._draws_start(inc_warmup=True)

    # (draws, chains, columns), column-major so per-column values are
    # contiguous in memory
    gq_sample: np.ndarray = np.empty(
        (num_draws, self.chains, len(self.column_names)),
        dtype=float,
        order='F',
    )
    for chain in range(self.chains):
        with open(self.runset.csv_files[chain], 'r') as fd:
            # filter out CmdStan comment lines; skiprows=1 drops the header
            lines = (line for line in fd if not line.startswith('#'))
            # NOTE(review): dtype=np.ndarray makes loadtxt return an
            # object array; values are coerced to float on assignment
            # into gq_sample — confirm this is intentional
            gq_sample[:, chain, :] = np.loadtxt(
                lines, dtype=np.ndarray, ndmin=2, skiprows=1, delimiter=','
            )
    self._draws = gq_sample
|
||||
|
||||
def _draws_start(self, inc_warmup: bool) -> Tuple[int, int]:
    """Return ``(first_draw_index, total_draws)`` for the previous fit.

    The first element is the index of the first draw to keep when
    warmup draws / optimizer iterations are excluded; the second is
    the total number of draws the previous fit produced.
    """
    p_fit = self.previous_fit
    if isinstance(p_fit, CmdStanMCMC):
        total = p_fit.num_draws_sampling
        first = 0
        if p_fit._save_warmup:
            if inc_warmup:
                total += p_fit.num_draws_warmup
            else:
                first = p_fit.num_draws_warmup
        return first, total

    if isinstance(p_fit, CmdStanMLE):
        if p_fit._save_iterations:
            n_iters = len(p_fit.optimized_iterations_np)  # type: ignore
            if inc_warmup:
                return 0, n_iters
            # only the final iteration (the optimum) is kept
            return n_iters - 1, 1
        return 0, 1

    # CmdStanVB: row 0 of the sample is the mean, skipped unless
    # inc_warmup is requested
    total = p_fit.variational_sample.shape[0]
    if inc_warmup:
        return 1, total + 1
    return 1, total
|
||||
|
||||
def _previous_draws(self, inc_warmup: bool) -> np.ndarray:
    """
    Extract the draws from self.previous_fit.
    Return is always 3-d

    The middle (chain) axis is synthesized via ``[:, None]`` for the
    MLE and VB cases, which have a single chain.
    """
    p_fit = self.previous_fit
    if isinstance(p_fit, CmdStanMCMC):
        # already (draws, chains, columns)
        return p_fit.draws(inc_warmup=inc_warmup)
    elif isinstance(p_fit, CmdStanMLE):
        if inc_warmup and p_fit._save_iterations:
            return p_fit.optimized_iterations_np[:, None]  # type: ignore

        return np.atleast_2d(  # type: ignore
            p_fit.optimized_params_np,
        )[:, None]
    else:  # CmdStanVB:
        if inc_warmup:
            # prepend the variational mean as the first row
            return np.vstack(
                [p_fit.variational_params_np, p_fit.variational_sample]
            )[:, None]
        return p_fit.variational_sample[:, None]
|
||||
|
||||
def _previous_draws_pd(
    self, vars: List[str], inc_warmup: bool
) -> pd.DataFrame:
    """Return the previous fit's draws as a DataFrame, restricted to
    ``vars`` when non-empty."""
    p_fit = self.previous_fit
    if isinstance(p_fit, CmdStanMCMC):
        return p_fit.draws_pd(vars or None, inc_warmup=inc_warmup)

    # column selector: the requested names, or all columns
    sel: Union[List[str], slice] = vars if vars else slice(None, None)

    if isinstance(p_fit, CmdStanMLE):
        if inc_warmup and p_fit._save_iterations:
            return p_fit.optimized_iterations_pd[sel]  # type: ignore
        return p_fit.optimized_params_pd[sel]

    # CmdStanVB
    return p_fit.variational_sample_pd[sel]
|
||||
|
||||
def save_csvfiles(self, dir: Optional[str] = None) -> None:
    """
    Move output CSV files to specified directory. If files were
    written to the temporary session directory, clean filename.
    E.g., save 'bernoulli-201912081451-1-5nm6as7u.csv' as
    'bernoulli-201912081451-1.csv'.

    :param dir: directory path

    See Also
    --------
    stanfit.RunSet.save_csvfiles
    cmdstanpy.from_csv
    """
    # delegates entirely to the RunSet, which owns the files
    self.runset.save_csvfiles(dir)
|
||||
|
||||
# TODO(2.0): remove
@property
def mcmc_sample(self) -> Union[CmdStanMCMC, CmdStanMLE, CmdStanVB]:
    """Deprecated alias for :attr:`previous_fit`; logs a warning on use."""
    get_logger().warning(
        "Property `mcmc_sample` is deprecated, use `previous_fit` instead"
    )
    return self.previous_fit
|
||||
304
.venv/lib/python3.12/site-packages/cmdstanpy/stanfit/laplace.py
Normal file
304
.venv/lib/python3.12/site-packages/cmdstanpy/stanfit/laplace.py
Normal file
@ -0,0 +1,304 @@
|
||||
"""
|
||||
Container for the result of running a laplace approximation.
|
||||
"""
|
||||
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Hashable,
|
||||
List,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
try:
|
||||
import xarray as xr
|
||||
|
||||
XARRAY_INSTALLED = True
|
||||
except ImportError:
|
||||
XARRAY_INSTALLED = False
|
||||
|
||||
from cmdstanpy.cmdstan_args import Method
|
||||
from cmdstanpy.utils.data_munging import build_xarray_data
|
||||
from cmdstanpy.utils.stancsv import scan_generic_csv
|
||||
|
||||
from .metadata import InferenceMetadata
|
||||
from .mle import CmdStanMLE
|
||||
from .runset import RunSet
|
||||
|
||||
# TODO list:
|
||||
# - docs and example notebook
|
||||
# - make sure features like standalone GQ are updated/working
|
||||
|
||||
|
||||
class CmdStanLaplace:
    """
    Container for the result of running a Laplace approximation.

    Wraps the :class:`RunSet` of CSV outputs together with the mode
    (a :class:`CmdStanMLE`) around which the approximation was built.
    Draws are parsed lazily from the first CSV file on first access.
    """

    def __init__(self, runset: RunSet, mode: CmdStanMLE) -> None:
        """Initialize object.

        :param runset: completed RunSet for a ``laplace`` run.
        :param mode: the MLE fit providing the mode of the approximation.
        :raises ValueError: if ``runset.method`` is not ``Method.LAPLACE``.
        """
        if not runset.method == Method.LAPLACE:
            raise ValueError(
                'Wrong runset method, expecting laplace runset, '
                'found method {}'.format(runset.method)
            )
        self._runset = runset
        self._mode = mode

        # draws are parsed lazily; empty array means "not yet assembled"
        self._draws: np.ndarray = np.array(())

        config = scan_generic_csv(runset.csv_files[0])
        self._metadata = InferenceMetadata(config)

    def _assemble_draws(self) -> None:
        """Lazily parse the CSV file into ``self._draws``; no-op if done."""
        if self._draws.shape != (0,):
            return

        with open(self._runset.csv_files[0], 'r') as fd:
            # skip initial comment lines and the header row
            while (fd.readline()).startswith("#"):
                pass
            self._draws = np.loadtxt(
                fd,
                dtype=float,
                ndmin=2,
                delimiter=',',
                comments="#",
            )

    def stan_variable(self, var: str) -> np.ndarray:
        """
        Return a numpy.ndarray which contains the estimates for the
        for the named Stan program variable where the dimensions of the
        numpy.ndarray match the shape of the Stan program variable.

        This functionaltiy is also available via a shortcut using ``.`` -
        writing ``fit.a`` is a synonym for ``fit.stan_variable("a")``

        :param var: variable name

        :raises ValueError: if ``var`` is not a known variable name.

        See Also
        --------
        CmdStanMLE.stan_variables
        CmdStanMCMC.stan_variable
        CmdStanPathfinder.stan_variable
        CmdStanVB.stan_variable
        CmdStanGQ.stan_variable
        """
        self._assemble_draws()
        try:
            out: np.ndarray = self._metadata.stan_vars[var].extract_reshape(
                self._draws
            )
            return out
        except KeyError:
            # pylint: disable=raise-missing-from
            raise ValueError(
                f'Unknown variable name: {var}\n'
                'Available variables are '
                + ", ".join(self._metadata.stan_vars.keys())
            )

    def stan_variables(self) -> Dict[str, np.ndarray]:
        """
        Return a dictionary mapping Stan program variables names
        to the corresponding numpy.ndarray containing the inferred values.

        See Also
        --------
        CmdStanGQ.stan_variable
        CmdStanMCMC.stan_variables
        CmdStanMLE.stan_variables
        CmdStanPathfinder.stan_variables
        CmdStanVB.stan_variables
        """
        result = {}
        for name in self._metadata.stan_vars:
            result[name] = self.stan_variable(name)
        return result

    def method_variables(self) -> Dict[str, np.ndarray]:
        """
        Returns a dictionary of all sampler variables, i.e., all
        output column names ending in `__`. Assumes that all variables
        are scalar variables where column name is variable name.
        Maps each column name to a numpy.ndarray (draws x chains x 1)
        containing per-draw diagnostic values.
        """
        self._assemble_draws()
        return {
            name: var.extract_reshape(self._draws)
            for name, var in self._metadata.method_vars.items()
        }

    def draws(self) -> np.ndarray:
        """
        Return a numpy.ndarray containing the draws from the
        approximate posterior distribution. This is a 2-D array
        of shape (draws, parameters).
        """
        self._assemble_draws()
        return self._draws

    def draws_pd(
        self,
        vars: Union[List[str], str, None] = None,
    ) -> pd.DataFrame:
        """Return the draws as a pandas DataFrame.

        :param vars: optional variable name or list of variable names;
            may include method (``__``-suffixed) variables. Columns are
            ordered per the (de-duplicated) request.
        :raises ValueError: for a name not in the output.
        """
        if vars is not None:
            if isinstance(vars, str):
                vars_list = [vars]
            else:
                vars_list = vars

        self._assemble_draws()
        cols = []
        if vars is not None:
            # dict.fromkeys de-duplicates while preserving order
            for var in dict.fromkeys(vars_list):
                if var in self._metadata.method_vars:
                    cols.append(var)
                elif var in self._metadata.stan_vars:
                    info = self._metadata.stan_vars[var]
                    cols.extend(
                        self.column_names[info.start_idx : info.end_idx]
                    )
                else:
                    raise ValueError(f'Unknown variable: {var}')

        else:
            cols = list(self.column_names)

        return pd.DataFrame(self._draws, columns=self.column_names)[cols]

    def draws_xr(
        self,
        vars: Union[str, List[str], None] = None,
    ) -> "xr.Dataset":
        """
        Returns the sampler draws as a xarray Dataset.

        :param vars: optional list of variable names.

        :raises RuntimeError: if the ``xarray`` package is not installed.

        See Also
        --------
        CmdStanMCMC.draws_xr
        CmdStanGQ.draws_xr
        """
        if not XARRAY_INSTALLED:
            raise RuntimeError(
                'Package "xarray" is not installed, cannot produce draws array.'
            )

        if vars is None:
            vars_list = list(self._metadata.stan_vars.keys())
        elif isinstance(vars, str):
            vars_list = [vars]
        else:
            vars_list = vars

        self._assemble_draws()

        meta = self._metadata.cmdstan_config
        attrs: MutableMapping[Hashable, Any] = {
            "stan_version": f"{meta['stan_version_major']}."
            f"{meta['stan_version_minor']}.{meta['stan_version_patch']}",
            "model": meta["model"],
        }

        data: MutableMapping[Hashable, Any] = {}
        coordinates: MutableMapping[Hashable, Any] = {
            "draw": np.arange(self._draws.shape[0]),
        }

        for var in vars_list:
            # build_xarray_data expects (draws, chains, columns);
            # insert a singleton chain axis, squeezed away below
            build_xarray_data(
                data,
                self._metadata.stan_vars[var],
                self._draws[:, np.newaxis, :],
            )
        return (
            xr.Dataset(data, coords=coordinates, attrs=attrs)
            .transpose('draw', ...)
            .squeeze()
        )

    @property
    def mode(self) -> CmdStanMLE:
        """
        Return the maximum a posteriori estimate (mode)
        as a :class:`CmdStanMLE` object.
        """
        return self._mode

    @property
    def metadata(self) -> InferenceMetadata:
        """
        Returns object which contains CmdStan configuration as well as
        information about the names and structure of the inference method
        and model output variables.
        """
        return self._metadata

    def __repr__(self) -> str:
        # indent the mode's repr so it reads as a nested object
        mode = '\n'.join(
            ['\t' + line for line in repr(self.mode).splitlines()]
        )[1:]
        rep = 'CmdStanLaplace: model={} \nmode=({})\n{}'.format(
            self._runset.model,
            mode,
            self._runset._args.method_args.compose(0, cmd=[]),
        )
        rep = '{}\n csv_files:\n\t{}\n output_files:\n\t{}'.format(
            rep,
            '\n\t'.join(self._runset.csv_files),
            '\n\t'.join(self._runset.stdout_files),
        )
        return rep

    def __getattr__(self, attr: str) -> np.ndarray:
        """Synonymous with ``fit.stan_variable(attr)"""
        # underscore names are never Stan variables; failing fast here
        # also keeps pickle/copy protocols working
        if attr.startswith("_"):
            raise AttributeError(f"Unknown variable name {attr}")
        try:
            return self.stan_variable(attr)
        except ValueError as e:
            # pylint: disable=raise-missing-from
            raise AttributeError(*e.args)

    def __getstate__(self) -> dict:
        # This function returns the mapping of objects to serialize with pickle.
        # See https://docs.python.org/3/library/pickle.html#object.__getstate__
        # for details. We call _assemble_draws to ensure posterior samples have
        # been loaded prior to serialization.
        self._assemble_draws()
        return self.__dict__

    @property
    def column_names(self) -> Tuple[str, ...]:
        """
        Names of all outputs from the sampler, comprising sampler parameters
        and all components of all model parameters, transformed parameters,
        and quantities of interest. Corresponds to Stan CSV file header row,
        with names munged to array notation, e.g. `beta[1]` not `beta.1`.
        """
        return self._metadata.cmdstan_config['column_names']  # type: ignore

    def save_csvfiles(self, dir: Optional[str] = None) -> None:
        """
        Move output CSV files to specified directory. If files were
        written to the temporary session directory, clean filename.
        E.g., save 'bernoulli-201912081451-1-5nm6as7u.csv' as
        'bernoulli-201912081451-1.csv'.

        :param dir: directory path

        See Also
        --------
        stanfit.RunSet.save_csvfiles
        cmdstanpy.from_csv
        """
        self._runset.save_csvfiles(dir)
|
||||
826
.venv/lib/python3.12/site-packages/cmdstanpy/stanfit/mcmc.py
Normal file
826
.venv/lib/python3.12/site-packages/cmdstanpy/stanfit/mcmc.py
Normal file
@ -0,0 +1,826 @@
|
||||
"""
|
||||
Container for the result of running the sample (MCMC) method
|
||||
"""
|
||||
|
||||
import math
|
||||
import os
|
||||
from io import StringIO
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Hashable,
|
||||
List,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
try:
|
||||
import xarray as xr
|
||||
|
||||
XARRAY_INSTALLED = True
|
||||
except ImportError:
|
||||
XARRAY_INSTALLED = False
|
||||
|
||||
from cmdstanpy import _CMDSTAN_SAMPLING, _CMDSTAN_THIN, _CMDSTAN_WARMUP, _TMPDIR
|
||||
from cmdstanpy.cmdstan_args import Method, SamplerArgs
|
||||
from cmdstanpy.utils import (
|
||||
EXTENSION,
|
||||
build_xarray_data,
|
||||
check_sampler_csv,
|
||||
cmdstan_path,
|
||||
cmdstan_version_before,
|
||||
create_named_text_file,
|
||||
do_command,
|
||||
flatten_chains,
|
||||
get_logger,
|
||||
)
|
||||
|
||||
from .metadata import InferenceMetadata
|
||||
from .runset import RunSet
|
||||
|
||||
|
||||
class CmdStanMCMC:
|
||||
"""
|
||||
Container for outputs from CmdStan sampler run.
|
||||
Provides methods to summarize and diagnose the model fit
|
||||
and accessor methods to access the entire sample or
|
||||
individual items. Created by :meth:`CmdStanModel.sample`
|
||||
|
||||
The sample is lazily instantiated on first access of either
|
||||
the resulting sample or the HMC tuning parameters, i.e., the
|
||||
step size and metric.
|
||||
"""
|
||||
|
||||
# pylint: disable=too-many-public-methods
|
||||
def __init__(
    self,
    runset: RunSet,
) -> None:
    """Initialize object.

    :param runset: completed RunSet for a ``sample`` (MCMC) run.
    :raises ValueError: if ``runset.method`` is not ``Method.SAMPLE``.
    """
    if not runset.method == Method.SAMPLE:
        raise ValueError(
            'Wrong runset method, expecting sample runset, '
            'found method {}'.format(runset.method)
        )
    self.runset = runset

    # info from runset to be exposed
    sampler_args = self.runset._args.method_args
    assert isinstance(
        sampler_args, SamplerArgs
    )  # make the typechecker happy
    # fall back to CmdStan defaults for args left unspecified
    self._iter_sampling: int = _CMDSTAN_SAMPLING
    if sampler_args.iter_sampling is not None:
        self._iter_sampling = sampler_args.iter_sampling
    self._iter_warmup: int = _CMDSTAN_WARMUP
    if sampler_args.iter_warmup is not None:
        self._iter_warmup = sampler_args.iter_warmup
    self._thin: int = _CMDSTAN_THIN
    if sampler_args.thin is not None:
        self._thin = sampler_args.thin
    self._is_fixed_param = sampler_args.fixed_param
    self._save_warmup: bool = sampler_args.save_warmup
    self._sig_figs = runset._args.sig_figs

    # info from CSV values, instantiated lazily
    self._draws: np.ndarray = np.array(())
    # only valid when not is_fixed_param
    self._metric: np.ndarray = np.array(())
    self._step_size: np.ndarray = np.array(())
    self._divergences: np.ndarray = np.zeros(self.runset.chains, dtype=int)
    self._max_treedepths: np.ndarray = np.zeros(
        self.runset.chains, dtype=int
    )

    # info from CSV initial comments and header; also tabulates
    # per-chain divergence / treedepth counts
    config = self._validate_csv_files()
    self._metadata: InferenceMetadata = InferenceMetadata(config)
    if not self._is_fixed_param:
        self._check_sampler_diagnostics()
|
||||
|
||||
def __repr__(self) -> str:
    """Summarize model, chains, sampler args, and output file paths."""
    repr = 'CmdStanMCMC: model={} chains={}{}'.format(
        self.runset.model,
        self.runset.chains,
        self.runset._args.method_args.compose(0, cmd=[]),
    )
    repr = '{}\n csv_files:\n\t{}\n output_files:\n\t{}'.format(
        repr,
        '\n\t'.join(self.runset.csv_files),
        '\n\t'.join(self.runset.stdout_files),
    )
    # TODO - hamiltonian, profiling files
    return repr
|
||||
|
||||
def __getattr__(self, attr: str) -> np.ndarray:
    """Synonymous with ``fit.stan_variable(attr)"""
    # underscore names are never Stan variables; failing fast here
    # also keeps pickle/copy protocols working
    if attr.startswith("_"):
        raise AttributeError(f"Unknown variable name {attr}")
    try:
        return self.stan_variable(attr)
    except ValueError as e:
        # re-raise as AttributeError so getattr()/hasattr() semantics hold
        # pylint: disable=raise-missing-from
        raise AttributeError(*e.args)
|
||||
|
||||
def __getstate__(self) -> dict:
    # This function returns the mapping of objects to serialize with pickle.
    # See https://docs.python.org/3/library/pickle.html#object.__getstate__
    # for details. We call _assemble_draws to ensure posterior samples have
    # been loaded prior to serialization.
    self._assemble_draws()
    return self.__dict__
|
||||
|
||||
@property
def chains(self) -> int:
    """Number of chains."""
    return self.runset.chains

@property
def chain_ids(self) -> List[int]:
    """Chain ids."""
    return self.runset.chain_ids

@property
def num_draws_warmup(self) -> int:
    """Number of warmup draws per chain, i.e., thinned warmup iterations."""
    return int(math.ceil((self._iter_warmup) / self._thin))

@property
def num_draws_sampling(self) -> int:
    """
    Number of sampling (post-warmup) draws per chain, i.e.,
    thinned sampling iterations.
    """
    return int(math.ceil((self._iter_sampling) / self._thin))

@property
def metadata(self) -> InferenceMetadata:
    """
    Returns object which contains CmdStan configuration as well as
    information about the names and structure of the inference method
    and model output variables.
    """
    return self._metadata

@property
def column_names(self) -> Tuple[str, ...]:
    """
    Names of all outputs from the sampler, comprising sampler parameters
    and all components of all model parameters, transformed parameters,
    and quantities of interest. Corresponds to Stan CSV file header row,
    with names munged to array notation, e.g. `beta[1]` not `beta.1`.
    """
    return self._metadata.cmdstan_config['column_names']  # type: ignore

@property
def metric_type(self) -> Optional[str]:
    """
    Metric type used for adaptation, either 'diag_e' or 'dense_e', according
    to CmdStan arg 'metric'.
    When sampler algorithm 'fixed_param' is specified, metric_type is None.
    """
    return (
        self._metadata.cmdstan_config['metric']
        if not self._is_fixed_param
        else None
    )

@property
def metric(self) -> Optional[np.ndarray]:
    """
    Metric used by sampler for each chain.
    When sampler algorithm 'fixed_param' is specified, metric is None.
    Also None for the 'unit_e' metric, whose size cannot be recovered
    from the CSV output.
    """
    if self._is_fixed_param:
        return None
    if self._metadata.cmdstan_config['metric'] == 'unit_e':
        get_logger().info(
            'Unit diagnonal metric, inverse mass matrix size unknown.'
        )
        return None
    # metric is parsed lazily along with the draws
    self._assemble_draws()
    return self._metric

@property
def step_size(self) -> Optional[np.ndarray]:
    """
    Step size used by sampler for each chain.
    When sampler algorithm 'fixed_param' is specified, step size is None.
    """
    # step size is parsed lazily along with the draws
    self._assemble_draws()
    return self._step_size if not self._is_fixed_param else None

@property
def thin(self) -> int:
    """
    Period between recorded iterations. (Default is 1).
    """
    return self._thin

@property
def divergences(self) -> Optional[np.ndarray]:
    """
    Per-chain total number of post-warmup divergent iterations.
    When sampler algorithm 'fixed_param' is specified, returns None.
    """
    return self._divergences if not self._is_fixed_param else None

@property
def max_treedepths(self) -> Optional[np.ndarray]:
    """
    Per-chain total number of post-warmup iterations where the NUTS sampler
    reached the maximum allowed treedepth.
    When sampler algorithm 'fixed_param' is specified, returns None.
    """
    return self._max_treedepths if not self._is_fixed_param else None
|
||||
|
||||
def draws(
    self, *, inc_warmup: bool = False, concat_chains: bool = False
) -> np.ndarray:
    """
    Returns a numpy.ndarray over all draws from all chains which is
    stored column major so that the values for a parameter are contiguous
    in memory, likewise all draws from a chain are contiguous.
    By default, returns a 3D array arranged (draws, chains, columns);
    parameter ``concat_chains=True`` will return a 2D array where all
    chains are flattened into a single column, preserving chain order,
    so that given M chains of N draws, the first N draws are from chain 1,
    up through the last N draws from chain M.

    :param inc_warmup: When ``True`` and the warmup draws are present in
        the output, i.e., the sampler was run with ``save_warmup=True``,
        then the warmup draws are included. Default value is ``False``.

    :param concat_chains: When ``True`` return a 2D array flattening all
        all draws from all chains. Default value is ``False``.

    See Also
    --------
    CmdStanMCMC.draws_pd
    CmdStanMCMC.draws_xr
    CmdStanGQ.draws
    """
    self._assemble_draws()

    # requesting warmup when none was saved is a warning, not an error
    if inc_warmup and not self._save_warmup:
        get_logger().warning(
            "Sample doesn't contain draws from warmup iterations,"
            ' rerun sampler with "save_warmup=True".'
        )

    # skip over leading warmup rows unless they were requested
    start_idx = 0
    if not inc_warmup and self._save_warmup:
        start_idx = self.num_draws_warmup

    if concat_chains:
        return flatten_chains(self._draws[start_idx:, :, :])
    return self._draws[start_idx:, :, :]
|
||||
|
||||
def _validate_csv_files(self) -> Dict[str, Any]:
    """
    Checks that Stan CSV output files for all chains are consistent
    and returns dict containing config and column names.

    Tabulates sampling iters which are divergent or at max treedepth
    Raises exception when inconsistencies detected.
    """
    # config parsed from chain 0; all other chains must agree with it
    dzero = {}
    for i in range(self.chains):
        if i == 0:
            dzero = check_sampler_csv(
                path=self.runset.csv_files[i],
                is_fixed_param=self._is_fixed_param,
                iter_sampling=self._iter_sampling,
                iter_warmup=self._iter_warmup,
                save_warmup=self._save_warmup,
                thin=self._thin,
            )
            if not self._is_fixed_param:
                self._divergences[i] = dzero['ct_divergences']
                self._max_treedepths[i] = dzero['ct_max_treedepth']
        else:
            drest = check_sampler_csv(
                path=self.runset.csv_files[i],
                is_fixed_param=self._is_fixed_param,
                iter_sampling=self._iter_sampling,
                iter_warmup=self._iter_warmup,
                save_warmup=self._save_warmup,
                thin=self._thin,
            )
            for key in dzero:
                # check args that matter for parsing, plus name, version
                if (
                    key
                    in [
                        'stan_version_major',
                        'stan_version_minor',
                        'stan_version_patch',
                        'stanc_version',
                        'model',
                        'num_samples',
                        'num_warmup',
                        'save_warmup',
                        'thin',
                        'refresh',
                    ]
                    and dzero[key] != drest[key]
                ):
                    raise ValueError(
                        'CmdStan config mismatch in Stan CSV file {}: '
                        'arg {} is {}, expected {}'.format(
                            self.runset.csv_files[i],
                            key,
                            dzero[key],
                            drest[key],
                        )
                    )
            if not self._is_fixed_param:
                self._divergences[i] = drest['ct_divergences']
                self._max_treedepths[i] = drest['ct_max_treedepth']
    return dzero
|
||||
|
||||
def _check_sampler_diagnostics(self) -> None:
    """
    Warn if any iterations ended in divergences or hit maxtreedepth.
    """
    if np.any(self._divergences) or np.any(self._max_treedepths):
        diagnostics = ['Some chains may have failed to converge.']
        # percentages are relative to the configured number of
        # post-warmup iterations
        ct_iters = self._metadata.cmdstan_config['num_samples']
        for i in range(self.runset._chains):
            if self._divergences[i] > 0:
                diagnostics.append(
                    f'Chain {i + 1} had {self._divergences[i]} '
                    'divergent transitions '
                    f'({((self._divergences[i]/ct_iters)*100):.1f}%)'
                )
            if self._max_treedepths[i] > 0:
                diagnostics.append(
                    f'Chain {i + 1} had {self._max_treedepths[i]} '
                    'iterations at max treedepth '
                    f'({((self._max_treedepths[i]/ct_iters)*100):.1f}%)'
                )
        diagnostics.append(
            'Use the "diagnose()" method on the CmdStanMCMC object'
            ' to see further information.'
        )
        # emit as a single multi-line warning
        get_logger().warning('\n\t'.join(diagnostics))
|
||||
|
||||
def _assemble_draws(self) -> None:
|
||||
"""
|
||||
Allocates and populates the step size, metric, and sample arrays
|
||||
by parsing the validated stan_csv files.
|
||||
"""
|
||||
if self._draws.shape != (0,):
|
||||
return
|
||||
num_draws = self.num_draws_sampling
|
||||
sampling_iter_start = 0
|
||||
if self._save_warmup:
|
||||
num_draws += self.num_draws_warmup
|
||||
sampling_iter_start = self.num_draws_warmup
|
||||
self._draws = np.empty(
|
||||
(num_draws, self.chains, len(self.column_names)),
|
||||
dtype=float,
|
||||
order='F',
|
||||
)
|
||||
self._step_size = np.empty(self.chains, dtype=float)
|
||||
for chain in range(self.chains):
|
||||
with open(self.runset.csv_files[chain], 'r') as fd:
|
||||
line = fd.readline().strip()
|
||||
# read initial comments, CSV header row
|
||||
while len(line) > 0 and line.startswith('#'):
|
||||
line = fd.readline().strip()
|
||||
if not self._is_fixed_param:
|
||||
# handle warmup draws, if any
|
||||
if self._save_warmup:
|
||||
for i in range(self.num_draws_warmup):
|
||||
line = fd.readline().strip()
|
||||
xs = line.split(',')
|
||||
self._draws[i, chain, :] = [float(x) for x in xs]
|
||||
line = fd.readline().strip()
|
||||
if line != '# Adaptation terminated': # shouldn't happen?
|
||||
while line != '# Adaptation terminated':
|
||||
line = fd.readline().strip()
|
||||
# step_size, metric (diag_e and dense_e only)
|
||||
line = fd.readline().strip()
|
||||
_, step_size = line.split('=')
|
||||
self._step_size[chain] = float(step_size.strip())
|
||||
if self._metadata.cmdstan_config['metric'] != 'unit_e':
|
||||
line = fd.readline().strip() # metric type
|
||||
line = fd.readline().lstrip(' #\t').rstrip()
|
||||
num_unconstrained_params = len(line.split(','))
|
||||
if chain == 0: # can't allocate w/o num params
|
||||
if self.metric_type == 'diag_e':
|
||||
self._metric = np.empty(
|
||||
(self.chains, num_unconstrained_params),
|
||||
dtype=float,
|
||||
)
|
||||
else:
|
||||
self._metric = np.empty(
|
||||
(
|
||||
self.chains,
|
||||
num_unconstrained_params,
|
||||
num_unconstrained_params,
|
||||
),
|
||||
dtype=float,
|
||||
)
|
||||
if line:
|
||||
if self.metric_type == 'diag_e':
|
||||
xs = line.split(',')
|
||||
self._metric[chain, :] = [float(x) for x in xs]
|
||||
else:
|
||||
xs = line.strip().split(',')
|
||||
self._metric[chain, 0, :] = [
|
||||
float(x) for x in xs
|
||||
]
|
||||
for i in range(1, num_unconstrained_params):
|
||||
line = fd.readline().lstrip(' #\t').rstrip()
|
||||
xs = line.split(',')
|
||||
self._metric[chain, i, :] = [
|
||||
float(x) for x in xs
|
||||
]
|
||||
else: # unit_e changed in 2.34 to have an extra line
|
||||
pos = fd.tell()
|
||||
line = fd.readline().strip()
|
||||
if not line.startswith('#'):
|
||||
fd.seek(pos)
|
||||
|
||||
# process draws
|
||||
for i in range(sampling_iter_start, num_draws):
|
||||
line = fd.readline().strip()
|
||||
xs = line.split(',')
|
||||
self._draws[i, chain, :] = [float(x) for x in xs]
|
||||
assert self._draws is not None
|
||||
|
||||
def summary(
|
||||
self,
|
||||
percentiles: Sequence[int] = (5, 50, 95),
|
||||
sig_figs: int = 6,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Run cmdstan/bin/stansummary over all output CSV files, assemble
|
||||
summary into DataFrame object. The first row contains statistics
|
||||
for the total joint log probability `lp__`, but is omitted when the
|
||||
Stan model has no parameters. The remaining rows contain summary
|
||||
statistics for all parameters, transformed parameters, and generated
|
||||
quantities variables, in program declaration order.
|
||||
|
||||
:param percentiles: Ordered non-empty sequence of percentiles to report.
|
||||
Must be integers from (1, 99), inclusive. Defaults to
|
||||
``(5, 50, 95)``
|
||||
|
||||
:param sig_figs: Number of significant figures to report.
|
||||
Must be an integer between 1 and 18. If unspecified, the default
|
||||
precision for the system file I/O is used; the usual value is 6.
|
||||
If precision above 6 is requested, sample must have been produced
|
||||
by CmdStan version 2.25 or later and sampler output precision
|
||||
must equal to or greater than the requested summary precision.
|
||||
|
||||
:return: pandas.DataFrame
|
||||
"""
|
||||
if len(percentiles) == 0:
|
||||
raise ValueError(
|
||||
'Invalid percentiles argument, must be ordered'
|
||||
' non-empty list from (1, 99), inclusive.'
|
||||
)
|
||||
cur_pct = 0
|
||||
for pct in percentiles:
|
||||
if pct > 99 or not pct > cur_pct:
|
||||
raise ValueError(
|
||||
'Invalid percentiles spec, must be ordered'
|
||||
' non-empty list from (1, 99), inclusive.'
|
||||
)
|
||||
cur_pct = pct
|
||||
percentiles_str = (
|
||||
f"--percentiles= {','.join(str(x) for x in percentiles)}"
|
||||
)
|
||||
|
||||
if not isinstance(sig_figs, int) or sig_figs < 1 or sig_figs > 18:
|
||||
raise ValueError(
|
||||
'Keyword "sig_figs" must be an integer between 1 and 18,'
|
||||
' found {}'.format(sig_figs)
|
||||
)
|
||||
csv_sig_figs = self._sig_figs or 6
|
||||
if sig_figs > csv_sig_figs:
|
||||
get_logger().warning(
|
||||
'Requesting %d significant digits of output, but CSV files'
|
||||
' only have %d digits of precision.',
|
||||
sig_figs,
|
||||
csv_sig_figs,
|
||||
)
|
||||
sig_figs_str = f'--sig_figs={sig_figs}'
|
||||
cmd_path = os.path.join(
|
||||
cmdstan_path(), 'bin', 'stansummary' + EXTENSION
|
||||
)
|
||||
tmp_csv_file = 'stansummary-{}-'.format(self.runset._args.model_name)
|
||||
tmp_csv_path = create_named_text_file(
|
||||
dir=_TMPDIR, prefix=tmp_csv_file, suffix='.csv', name_only=True
|
||||
)
|
||||
csv_str = '--csv_filename={}'.format(tmp_csv_path)
|
||||
# TODO: remove at some future release
|
||||
if cmdstan_version_before(2, 24):
|
||||
csv_str = '--csv_file={}'.format(tmp_csv_path)
|
||||
cmd = [
|
||||
cmd_path,
|
||||
percentiles_str,
|
||||
sig_figs_str,
|
||||
csv_str,
|
||||
] + self.runset.csv_files
|
||||
do_command(cmd, fd_out=None)
|
||||
with open(tmp_csv_path, 'rb') as fd:
|
||||
summary_data = pd.read_csv(
|
||||
fd,
|
||||
delimiter=',',
|
||||
header=0,
|
||||
index_col=0,
|
||||
comment='#',
|
||||
float_precision='high',
|
||||
)
|
||||
mask = (
|
||||
[not x.endswith('__') for x in summary_data.index]
|
||||
if self._is_fixed_param
|
||||
else [
|
||||
x == 'lp__' or not x.endswith('__') for x in summary_data.index
|
||||
]
|
||||
)
|
||||
summary_data.index.name = None
|
||||
return summary_data[mask]
|
||||
|
||||
def diagnose(self) -> Optional[str]:
|
||||
"""
|
||||
Run cmdstan/bin/diagnose over all output CSV files,
|
||||
return console output.
|
||||
|
||||
The diagnose utility reads the outputs of all chains
|
||||
and checks for the following potential problems:
|
||||
|
||||
+ Transitions that hit the maximum treedepth
|
||||
+ Divergent transitions
|
||||
+ Low E-BFMI values (sampler transitions HMC potential energy)
|
||||
+ Low effective sample sizes
|
||||
+ High R-hat values
|
||||
"""
|
||||
cmd_path = os.path.join(cmdstan_path(), 'bin', 'diagnose' + EXTENSION)
|
||||
cmd = [cmd_path] + self.runset.csv_files
|
||||
result = StringIO()
|
||||
do_command(cmd=cmd, fd_out=result)
|
||||
return result.getvalue()
|
||||
|
||||
def draws_pd(
|
||||
self,
|
||||
vars: Union[List[str], str, None] = None,
|
||||
inc_warmup: bool = False,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Returns the sample draws as a pandas DataFrame.
|
||||
Flattens all chains into single column. Container variables
|
||||
(array, vector, matrix) will span multiple columns, one column
|
||||
per element. E.g. variable 'matrix[2,2] foo' spans 4 columns:
|
||||
'foo[1,1], ... foo[2,2]'.
|
||||
|
||||
:param vars: optional list of variable names.
|
||||
|
||||
:param inc_warmup: When ``True`` and the warmup draws are present in
|
||||
the output, i.e., the sampler was run with ``save_warmup=True``,
|
||||
then the warmup draws are included. Default value is ``False``.
|
||||
|
||||
See Also
|
||||
--------
|
||||
CmdStanMCMC.draws
|
||||
CmdStanMCMC.draws_xr
|
||||
CmdStanGQ.draws_pd
|
||||
"""
|
||||
if vars is not None:
|
||||
if isinstance(vars, str):
|
||||
vars_list = [vars]
|
||||
else:
|
||||
vars_list = vars
|
||||
|
||||
if inc_warmup and not self._save_warmup:
|
||||
get_logger().warning(
|
||||
'Draws from warmup iterations not available,'
|
||||
' must run sampler with "save_warmup=True".'
|
||||
)
|
||||
|
||||
self._assemble_draws()
|
||||
cols = []
|
||||
if vars is not None:
|
||||
for var in dict.fromkeys(vars_list):
|
||||
if var in self._metadata.method_vars:
|
||||
cols.append(var)
|
||||
elif var in self._metadata.stan_vars:
|
||||
info = self._metadata.stan_vars[var]
|
||||
cols.extend(
|
||||
self.column_names[info.start_idx : info.end_idx]
|
||||
)
|
||||
elif var in ['chain__', 'iter__', 'draw__']:
|
||||
cols.append(var)
|
||||
else:
|
||||
raise ValueError(f'Unknown variable: {var}')
|
||||
else:
|
||||
cols = ['chain__', 'iter__', 'draw__'] + list(self.column_names)
|
||||
|
||||
draws = self.draws(inc_warmup=inc_warmup)
|
||||
# add long-form columns for chain, iteration, draw
|
||||
n_draws, n_chains, _ = draws.shape
|
||||
chains_col = (
|
||||
np.repeat(np.arange(1, n_chains + 1), n_draws)
|
||||
.reshape(1, n_chains, n_draws)
|
||||
.T
|
||||
)
|
||||
iter_col = (
|
||||
np.tile(np.arange(1, n_draws + 1), n_chains)
|
||||
.reshape(1, n_chains, n_draws)
|
||||
.T
|
||||
)
|
||||
draw_col = (
|
||||
np.arange(1, (n_draws * n_chains) + 1)
|
||||
.reshape(1, n_chains, n_draws)
|
||||
.T
|
||||
)
|
||||
draws = np.concatenate([chains_col, iter_col, draw_col, draws], axis=2)
|
||||
|
||||
return pd.DataFrame(
|
||||
data=flatten_chains(draws),
|
||||
columns=['chain__', 'iter__', 'draw__'] + list(self.column_names),
|
||||
)[cols]
|
||||
|
||||
def draws_xr(
|
||||
self, vars: Union[str, List[str], None] = None, inc_warmup: bool = False
|
||||
) -> "xr.Dataset":
|
||||
"""
|
||||
Returns the sampler draws as a xarray Dataset.
|
||||
|
||||
:param vars: optional list of variable names.
|
||||
|
||||
:param inc_warmup: When ``True`` and the warmup draws are present in
|
||||
the output, i.e., the sampler was run with ``save_warmup=True``,
|
||||
then the warmup draws are included. Default value is ``False``.
|
||||
|
||||
See Also
|
||||
--------
|
||||
CmdStanMCMC.draws
|
||||
CmdStanMCMC.draws_pd
|
||||
CmdStanGQ.draws_xr
|
||||
"""
|
||||
if not XARRAY_INSTALLED:
|
||||
raise RuntimeError(
|
||||
'Package "xarray" is not installed, cannot produce draws array.'
|
||||
)
|
||||
if inc_warmup and not self._save_warmup:
|
||||
get_logger().warning(
|
||||
"Draws from warmup iterations not available,"
|
||||
' must run sampler with "save_warmup=True".'
|
||||
)
|
||||
if vars is None:
|
||||
vars_list = list(self._metadata.stan_vars.keys())
|
||||
elif isinstance(vars, str):
|
||||
vars_list = [vars]
|
||||
else:
|
||||
vars_list = vars
|
||||
|
||||
self._assemble_draws()
|
||||
|
||||
num_draws = self.num_draws_sampling
|
||||
meta = self._metadata.cmdstan_config
|
||||
attrs: MutableMapping[Hashable, Any] = {
|
||||
"stan_version": f"{meta['stan_version_major']}."
|
||||
f"{meta['stan_version_minor']}.{meta['stan_version_patch']}",
|
||||
"model": meta["model"],
|
||||
"num_draws_sampling": num_draws,
|
||||
}
|
||||
if inc_warmup and self._save_warmup:
|
||||
num_draws += self.num_draws_warmup
|
||||
attrs["num_draws_warmup"] = self.num_draws_warmup
|
||||
|
||||
data: MutableMapping[Hashable, Any] = {}
|
||||
coordinates: MutableMapping[Hashable, Any] = {
|
||||
"chain": self.chain_ids,
|
||||
"draw": np.arange(num_draws),
|
||||
}
|
||||
|
||||
for var in vars_list:
|
||||
build_xarray_data(
|
||||
data,
|
||||
self._metadata.stan_vars[var],
|
||||
self.draws(inc_warmup=inc_warmup),
|
||||
)
|
||||
return xr.Dataset(data, coords=coordinates, attrs=attrs).transpose(
|
||||
'chain', 'draw', ...
|
||||
)
|
||||
|
||||
def stan_variable(
|
||||
self,
|
||||
var: str,
|
||||
inc_warmup: bool = False,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Return a numpy.ndarray which contains the set of draws
|
||||
for the named Stan program variable. Flattens the chains,
|
||||
leaving the draws in chain order. The first array dimension,
|
||||
corresponds to number of draws or post-warmup draws in the sample,
|
||||
per argument ``inc_warmup``. The remaining dimensions correspond to
|
||||
the shape of the Stan program variable.
|
||||
|
||||
Underlyingly draws are in chain order, i.e., for a sample with
|
||||
N chains of M draws each, the first M array elements are from chain 1,
|
||||
the next M are from chain 2, and the last M elements are from chain N.
|
||||
|
||||
* If the variable is a scalar variable, the return array has shape
|
||||
( draws * chains, 1).
|
||||
* If the variable is a vector, the return array has shape
|
||||
( draws * chains, len(vector))
|
||||
* If the variable is a matrix, the return array has shape
|
||||
( draws * chains, size(dim 1), size(dim 2) )
|
||||
* If the variable is an array with N dimensions, the return array
|
||||
has shape ( draws * chains, size(dim 1), ..., size(dim N))
|
||||
|
||||
For example, if the Stan program variable ``theta`` is a 3x3 matrix,
|
||||
and the sample consists of 4 chains with 1000 post-warmup draws,
|
||||
this function will return a numpy.ndarray with shape (4000,3,3).
|
||||
|
||||
This functionaltiy is also available via a shortcut using ``.`` -
|
||||
writing ``fit.a`` is a synonym for ``fit.stan_variable("a")``
|
||||
|
||||
:param var: variable name
|
||||
|
||||
:param inc_warmup: When ``True`` and the warmup draws are present in
|
||||
the output, i.e., the sampler was run with ``save_warmup=True``,
|
||||
then the warmup draws are included. Default value is ``False``.
|
||||
|
||||
See Also
|
||||
--------
|
||||
CmdStanMCMC.stan_variables
|
||||
CmdStanMLE.stan_variable
|
||||
CmdStanPathfinder.stan_variable
|
||||
CmdStanVB.stan_variable
|
||||
CmdStanGQ.stan_variable
|
||||
CmdStanLaplace.stan_variable
|
||||
"""
|
||||
try:
|
||||
draws = self.draws(inc_warmup=inc_warmup, concat_chains=True)
|
||||
out: np.ndarray = self._metadata.stan_vars[var].extract_reshape(
|
||||
draws
|
||||
)
|
||||
return out
|
||||
except KeyError:
|
||||
# pylint: disable=raise-missing-from
|
||||
raise ValueError(
|
||||
f'Unknown variable name: {var}\n'
|
||||
'Available variables are '
|
||||
+ ", ".join(self._metadata.stan_vars.keys())
|
||||
)
|
||||
|
||||
def stan_variables(self) -> Dict[str, np.ndarray]:
|
||||
"""
|
||||
Return a dictionary mapping Stan program variables names
|
||||
to the corresponding numpy.ndarray containing the inferred values.
|
||||
|
||||
See Also
|
||||
--------
|
||||
CmdStanMCMC.stan_variable
|
||||
CmdStanMLE.stan_variables
|
||||
CmdStanPathfinder.stan_variables
|
||||
CmdStanVB.stan_variables
|
||||
CmdStanGQ.stan_variables
|
||||
CmdStanLaplace.stan_variables
|
||||
"""
|
||||
result = {}
|
||||
for name in self._metadata.stan_vars:
|
||||
result[name] = self.stan_variable(name)
|
||||
return result
|
||||
|
||||
def method_variables(self) -> Dict[str, np.ndarray]:
|
||||
"""
|
||||
Returns a dictionary of all sampler variables, i.e., all
|
||||
output column names ending in `__`. Assumes that all variables
|
||||
are scalar variables where column name is variable name.
|
||||
Maps each column name to a numpy.ndarray (draws x chains x 1)
|
||||
containing per-draw diagnostic values.
|
||||
"""
|
||||
self._assemble_draws()
|
||||
return {
|
||||
name: var.extract_reshape(self._draws)
|
||||
for name, var in self._metadata.method_vars.items()
|
||||
}
|
||||
|
||||
def save_csvfiles(self, dir: Optional[str] = None) -> None:
|
||||
"""
|
||||
Move output CSV files to specified directory. If files were
|
||||
written to the temporary session directory, clean filename.
|
||||
E.g., save 'bernoulli-201912081451-1-5nm6as7u.csv' as
|
||||
'bernoulli-201912081451-1.csv'.
|
||||
|
||||
:param dir: directory path
|
||||
|
||||
See Also
|
||||
--------
|
||||
stanfit.RunSet.save_csvfiles
|
||||
cmdstanpy.from_csv
|
||||
"""
|
||||
self.runset.save_csvfiles(dir)
|
||||
@ -0,0 +1,53 @@
|
||||
"""Container for metadata parsed from the output of a CmdStan run"""
|
||||
|
||||
import copy
|
||||
from typing import Any, Dict
|
||||
|
||||
import stanio
|
||||
|
||||
|
||||
class InferenceMetadata:
|
||||
"""
|
||||
CmdStan configuration and contents of output file parsed out of
|
||||
the Stan CSV file header comments and column headers.
|
||||
Assumes valid CSV files.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]) -> None:
|
||||
"""Initialize object from CSV headers"""
|
||||
self._cmdstan_config = config
|
||||
vars = stanio.parse_header(config['raw_header'])
|
||||
|
||||
self._method_vars = {
|
||||
k: v for (k, v) in vars.items() if k.endswith('__')
|
||||
}
|
||||
self._stan_vars = {
|
||||
k: v for (k, v) in vars.items() if not k.endswith('__')
|
||||
}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return 'Metadata:\n{}\n'.format(self._cmdstan_config)
|
||||
|
||||
@property
|
||||
def cmdstan_config(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Returns a dictionary containing a set of name, value pairs
|
||||
parsed out of the Stan CSV file header. These include the
|
||||
command configuration and the CSV file header row information.
|
||||
Uses deepcopy for immutability.
|
||||
"""
|
||||
return copy.deepcopy(self._cmdstan_config)
|
||||
|
||||
@property
|
||||
def method_vars(self) -> Dict[str, stanio.Variable]:
|
||||
"""
|
||||
Method variable names always end in `__`, e.g. `lp__`.
|
||||
"""
|
||||
return self._method_vars
|
||||
|
||||
@property
|
||||
def stan_vars(self) -> Dict[str, stanio.Variable]:
|
||||
"""
|
||||
These are the user-defined variables in the Stan program.
|
||||
"""
|
||||
return self._stan_vars
|
||||
284
.venv/lib/python3.12/site-packages/cmdstanpy/stanfit/mle.py
Normal file
284
.venv/lib/python3.12/site-packages/cmdstanpy/stanfit/mle.py
Normal file
@ -0,0 +1,284 @@
|
||||
"""Container for the result of running optimization"""
|
||||
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from cmdstanpy.cmdstan_args import Method, OptimizeArgs
|
||||
from cmdstanpy.utils import get_logger, scan_optimize_csv
|
||||
|
||||
from .metadata import InferenceMetadata
|
||||
from .runset import RunSet
|
||||
|
||||
|
||||
class CmdStanMLE:
|
||||
"""
|
||||
Container for outputs from CmdStan optimization.
|
||||
Created by :meth:`CmdStanModel.optimize`.
|
||||
"""
|
||||
|
||||
def __init__(self, runset: RunSet) -> None:
|
||||
"""Initialize object."""
|
||||
if not runset.method == Method.OPTIMIZE:
|
||||
raise ValueError(
|
||||
'Wrong runset method, expecting optimize runset, '
|
||||
'found method {}'.format(runset.method)
|
||||
)
|
||||
self.runset = runset
|
||||
# info from runset to be exposed
|
||||
self.converged = runset._check_retcodes()
|
||||
optimize_args = self.runset._args.method_args
|
||||
assert isinstance(
|
||||
optimize_args, OptimizeArgs
|
||||
) # make the typechecker happy
|
||||
self._save_iterations: bool = optimize_args.save_iterations
|
||||
self._set_mle_attrs(runset.csv_files[0])
|
||||
|
||||
def __repr__(self) -> str:
|
||||
repr = 'CmdStanMLE: model={}{}'.format(
|
||||
self.runset.model, self.runset._args.method_args.compose(0, cmd=[])
|
||||
)
|
||||
repr = '{}\n csv_file:\n\t{}\n output_file:\n\t{}'.format(
|
||||
repr,
|
||||
'\n\t'.join(self.runset.csv_files),
|
||||
'\n\t'.join(self.runset.stdout_files),
|
||||
)
|
||||
if not self.converged:
|
||||
repr = '{}\n Warning: invalid estimate, '.format(repr)
|
||||
repr = '{} optimization failed to converge.'.format(repr)
|
||||
return repr
|
||||
|
||||
def __getattr__(self, attr: str) -> Union[np.ndarray, float]:
|
||||
"""Synonymous with ``fit.stan_variable(attr)"""
|
||||
if attr.startswith("_"):
|
||||
raise AttributeError(f"Unknown variable name {attr}")
|
||||
try:
|
||||
return self.stan_variable(attr)
|
||||
except ValueError as e:
|
||||
# pylint: disable=raise-missing-from
|
||||
raise AttributeError(*e.args)
|
||||
|
||||
def _set_mle_attrs(self, sample_csv_0: str) -> None:
|
||||
meta = scan_optimize_csv(sample_csv_0, self._save_iterations)
|
||||
self._metadata = InferenceMetadata(meta)
|
||||
self._column_names: Tuple[str, ...] = meta['column_names']
|
||||
self._mle: np.ndarray = meta['mle']
|
||||
if self._save_iterations:
|
||||
self._all_iters: np.ndarray = meta['all_iters']
|
||||
|
||||
@property
|
||||
def column_names(self) -> Tuple[str, ...]:
|
||||
"""
|
||||
Names of estimated quantities, includes joint log probability,
|
||||
and all parameters, transformed parameters, and generated quantities.
|
||||
"""
|
||||
return self._column_names
|
||||
|
||||
@property
|
||||
def metadata(self) -> InferenceMetadata:
|
||||
"""
|
||||
Returns object which contains CmdStan configuration as well as
|
||||
information about the names and structure of the inference method
|
||||
and model output variables.
|
||||
"""
|
||||
return self._metadata
|
||||
|
||||
@property
|
||||
def optimized_params_np(self) -> np.ndarray:
|
||||
"""
|
||||
Returns all final estimates from the optimizer as a numpy.ndarray
|
||||
which contains all optimizer outputs, i.e., the value for `lp__`
|
||||
as well as all Stan program variables.
|
||||
"""
|
||||
if not self.converged:
|
||||
get_logger().warning(
|
||||
'Invalid estimate, optimization failed to converge.'
|
||||
)
|
||||
return self._mle
|
||||
|
||||
@property
|
||||
def optimized_iterations_np(self) -> Optional[np.ndarray]:
|
||||
"""
|
||||
Returns all saved iterations from the optimizer and final estimate
|
||||
as a numpy.ndarray which contains all optimizer outputs, i.e.,
|
||||
the value for `lp__` as well as all Stan program variables.
|
||||
|
||||
"""
|
||||
if not self._save_iterations:
|
||||
get_logger().warning(
|
||||
'Intermediate iterations not saved to CSV output file. '
|
||||
'Rerun the optimize method with "save_iterations=True".'
|
||||
)
|
||||
return None
|
||||
if not self.converged:
|
||||
get_logger().warning(
|
||||
'Invalid estimate, optimization failed to converge.'
|
||||
)
|
||||
return self._all_iters
|
||||
|
||||
@property
|
||||
def optimized_params_pd(self) -> pd.DataFrame:
|
||||
"""
|
||||
Returns all final estimates from the optimizer as a pandas.DataFrame
|
||||
which contains all optimizer outputs, i.e., the value for `lp__`
|
||||
as well as all Stan program variables.
|
||||
"""
|
||||
if not self.runset._check_retcodes():
|
||||
get_logger().warning(
|
||||
'Invalid estimate, optimization failed to converge.'
|
||||
)
|
||||
return pd.DataFrame([self._mle], columns=self.column_names)
|
||||
|
||||
@property
|
||||
def optimized_iterations_pd(self) -> Optional[pd.DataFrame]:
|
||||
"""
|
||||
Returns all saved iterations from the optimizer and final estimate
|
||||
as a pandas.DataFrame which contains all optimizer outputs, i.e.,
|
||||
the value for `lp__` as well as all Stan program variables.
|
||||
|
||||
"""
|
||||
if not self._save_iterations:
|
||||
get_logger().warning(
|
||||
'Intermediate iterations not saved to CSV output file. '
|
||||
'Rerun the optimize method with "save_iterations=True".'
|
||||
)
|
||||
return None
|
||||
if not self.converged:
|
||||
get_logger().warning(
|
||||
'Invalid estimate, optimization failed to converge.'
|
||||
)
|
||||
return pd.DataFrame(self._all_iters, columns=self.column_names)
|
||||
|
||||
@property
|
||||
def optimized_params_dict(self) -> Dict[str, np.float64]:
|
||||
"""
|
||||
Returns all estimates from the optimizer, including `lp__` as a
|
||||
Python Dict. Only returns estimate from final iteration.
|
||||
"""
|
||||
if not self.runset._check_retcodes():
|
||||
get_logger().warning(
|
||||
'Invalid estimate, optimization failed to converge.'
|
||||
)
|
||||
return OrderedDict(zip(self.column_names, self._mle))
|
||||
|
||||
def stan_variable(
|
||||
self,
|
||||
var: str,
|
||||
*,
|
||||
inc_iterations: bool = False,
|
||||
warn: bool = True,
|
||||
) -> Union[np.ndarray, float]:
|
||||
"""
|
||||
Return a numpy.ndarray which contains the estimates for the
|
||||
for the named Stan program variable where the dimensions of the
|
||||
numpy.ndarray match the shape of the Stan program variable.
|
||||
|
||||
This functionaltiy is also available via a shortcut using ``.`` -
|
||||
writing ``fit.a`` is a synonym for ``fit.stan_variable("a")``
|
||||
|
||||
:param var: variable name
|
||||
|
||||
:param inc_iterations: When ``True`` and the intermediate estimates
|
||||
are included in the output, i.e., the optimizer was run with
|
||||
``save_iterations=True``, then intermediate estimates are included.
|
||||
Default value is ``False``.
|
||||
|
||||
See Also
|
||||
--------
|
||||
CmdStanMLE.stan_variables
|
||||
CmdStanMCMC.stan_variable
|
||||
CmdStanPathfinder.stan_variable
|
||||
CmdStanVB.stan_variable
|
||||
CmdStanGQ.stan_variable
|
||||
CmdStanLaplace.stan_variable
|
||||
"""
|
||||
if var not in self._metadata.stan_vars:
|
||||
raise ValueError(
|
||||
f'Unknown variable name: {var}\n'
|
||||
'Available variables are ' + ", ".join(self._metadata.stan_vars)
|
||||
)
|
||||
if warn and inc_iterations and not self._save_iterations:
|
||||
get_logger().warning(
|
||||
'Intermediate iterations not saved to CSV output file. '
|
||||
'Rerun the optimize method with "save_iterations=True".'
|
||||
)
|
||||
if warn and not self.runset._check_retcodes():
|
||||
get_logger().warning(
|
||||
'Invalid estimate, optimization failed to converge.'
|
||||
)
|
||||
if inc_iterations and self._save_iterations:
|
||||
data = self._all_iters
|
||||
else:
|
||||
data = self._mle
|
||||
|
||||
try:
|
||||
out: np.ndarray = self._metadata.stan_vars[var].extract_reshape(
|
||||
data
|
||||
)
|
||||
# TODO(2.0) remove
|
||||
if out.shape == () or out.shape == (1,):
|
||||
get_logger().warning(
|
||||
"The default behavior of CmdStanMLE.stan_variable() "
|
||||
"will change in a future release to always return a "
|
||||
"numpy.ndarray, even for scalar variables."
|
||||
)
|
||||
return out.item() # type: ignore
|
||||
return out
|
||||
except KeyError:
|
||||
# pylint: disable=raise-missing-from
|
||||
raise ValueError(
|
||||
f'Unknown variable name: {var}\n'
|
||||
'Available variables are '
|
||||
+ ", ".join(self._metadata.stan_vars.keys())
|
||||
)
|
||||
|
||||
def stan_variables(
|
||||
self, inc_iterations: bool = False
|
||||
) -> Dict[str, Union[np.ndarray, float]]:
|
||||
"""
|
||||
Return a dictionary mapping Stan program variables names
|
||||
to the corresponding numpy.ndarray containing the inferred values.
|
||||
|
||||
:param inc_iterations: When ``True`` and the intermediate estimates
|
||||
are included in the output, i.e., the optimizer was run with
|
||||
``save_iterations=True``, then intermediate estimates are included.
|
||||
Default value is ``False``.
|
||||
|
||||
|
||||
See Also
|
||||
--------
|
||||
CmdStanMLE.stan_variable
|
||||
CmdStanMCMC.stan_variables
|
||||
CmdStanPathfinder.stan_variables
|
||||
CmdStanVB.stan_variables
|
||||
CmdStanGQ.stan_variables
|
||||
CmdStanLaplace.stan_variables
|
||||
"""
|
||||
if not self.runset._check_retcodes():
|
||||
get_logger().warning(
|
||||
'Invalid estimate, optimization failed to converge.'
|
||||
)
|
||||
result = {}
|
||||
for name in self._metadata.stan_vars:
|
||||
result[name] = self.stan_variable(
|
||||
name, inc_iterations=inc_iterations, warn=False
|
||||
)
|
||||
return result
|
||||
|
||||
def save_csvfiles(self, dir: Optional[str] = None) -> None:
|
||||
"""
|
||||
Move output CSV files to specified directory. If files were
|
||||
written to the temporary session directory, clean filename.
|
||||
E.g., save 'bernoulli-201912081451-1-5nm6as7u.csv' as
|
||||
'bernoulli-201912081451-1.csv'.
|
||||
|
||||
:param dir: directory path
|
||||
|
||||
See Also
|
||||
--------
|
||||
stanfit.RunSet.save_csvfiles
|
||||
cmdstanpy.from_csv
|
||||
"""
|
||||
self.runset.save_csvfiles(dir)
|
||||
@ -0,0 +1,237 @@
|
||||
"""
|
||||
Container for the result of running Pathfinder.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from cmdstanpy.cmdstan_args import Method
|
||||
from cmdstanpy.stanfit.metadata import InferenceMetadata
|
||||
from cmdstanpy.stanfit.runset import RunSet
|
||||
from cmdstanpy.utils.stancsv import scan_generic_csv
|
||||
|
||||
|
||||
class CmdStanPathfinder:
|
||||
"""
|
||||
Container for outputs from the Pathfinder algorithm.
|
||||
Created by :meth:`CmdStanModel.pathfinder()`.
|
||||
"""
|
||||
|
||||
def __init__(self, runset: RunSet):
|
||||
"""Initialize object."""
|
||||
if not runset.method == Method.PATHFINDER:
|
||||
raise ValueError(
|
||||
'Wrong runset method, expecting Pathfinder runset, '
|
||||
'found method {}'.format(runset.method)
|
||||
)
|
||||
self._runset = runset
|
||||
|
||||
self._draws: np.ndarray = np.array(())
|
||||
|
||||
config = scan_generic_csv(runset.csv_files[0])
|
||||
self._metadata = InferenceMetadata(config)
|
||||
|
||||
def create_inits(
|
||||
self, seed: Optional[int] = None, chains: int = 4
|
||||
) -> Union[List[Dict[str, np.ndarray]], Dict[str, np.ndarray]]:
|
||||
"""
|
||||
Create initial values for the parameters of the model
|
||||
by randomly selecting draws from the Pathfinder approximation.
|
||||
|
||||
:param seed: Used for random selection, defaults to None
|
||||
:param chains: Number of initial values to return, defaults to 4
|
||||
:return: The initial values for the parameters of the model.
|
||||
|
||||
If ``chains`` is 1, a dictionary is returned, otherwise a list
|
||||
of dictionaries is returned, in the format expected for the
|
||||
``inits`` argument. of :meth:`CmdStanModel.sample`.
|
||||
"""
|
||||
self._assemble_draws()
|
||||
rng = np.random.default_rng(seed)
|
||||
idxs = rng.choice(self._draws.shape[0], size=chains, replace=False)
|
||||
if chains == 1:
|
||||
draw = self._draws[idxs[0]]
|
||||
return {
|
||||
name: var.extract_reshape(draw)
|
||||
for name, var in self._metadata.stan_vars.items()
|
||||
}
|
||||
else:
|
||||
return [
|
||||
{
|
||||
name: var.extract_reshape(self._draws[idx])
|
||||
for name, var in self._metadata.stan_vars.items()
|
||||
}
|
||||
for idx in idxs
|
||||
]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
rep = 'CmdStanPathfinder: model={}{}'.format(
|
||||
self._runset.model,
|
||||
self._runset._args.method_args.compose(0, cmd=[]),
|
||||
)
|
||||
rep = '{}\n csv_files:\n\t{}\n output_files:\n\t{}'.format(
|
||||
rep,
|
||||
'\n\t'.join(self._runset.csv_files),
|
||||
'\n\t'.join(self._runset.stdout_files),
|
||||
)
|
||||
return rep
|
||||
|
||||
# below this is identical to same functions in Laplace
|
||||
def _assemble_draws(self) -> None:
|
||||
if self._draws.shape != (0,):
|
||||
return
|
||||
|
||||
with open(self._runset.csv_files[0], 'r') as fd:
|
||||
while (fd.readline()).startswith("#"):
|
||||
pass
|
||||
self._draws = np.loadtxt(
|
||||
fd,
|
||||
dtype=float,
|
||||
ndmin=2,
|
||||
delimiter=',',
|
||||
comments="#",
|
||||
)
|
||||
|
||||
def stan_variable(self, var: str) -> np.ndarray:
|
||||
"""
|
||||
Return a numpy.ndarray which contains the estimates for the
|
||||
for the named Stan program variable where the dimensions of the
|
||||
numpy.ndarray match the shape of the Stan program variable.
|
||||
|
||||
This functionaltiy is also available via a shortcut using ``.`` -
|
||||
writing ``fit.a`` is a synonym for ``fit.stan_variable("a")``
|
||||
|
||||
:param var: variable name
|
||||
|
||||
See Also
|
||||
--------
|
||||
CmdStanPathfinder.stan_variables
|
||||
CmdStanMLE.stan_variable
|
||||
CmdStanMCMC.stan_variable
|
||||
CmdStanVB.stan_variable
|
||||
CmdStanGQ.stan_variable
|
||||
CmdStanLaplace.stan_variable
|
||||
"""
|
||||
self._assemble_draws()
|
||||
try:
|
||||
out: np.ndarray = self._metadata.stan_vars[var].extract_reshape(
|
||||
self._draws
|
||||
)
|
||||
return out
|
||||
except KeyError:
|
||||
# pylint: disable=raise-missing-from
|
||||
raise ValueError(
|
||||
f'Unknown variable name: {var}\n'
|
||||
'Available variables are '
|
||||
+ ", ".join(self._metadata.stan_vars.keys())
|
||||
)
|
||||
|
||||
def stan_variables(self) -> Dict[str, np.ndarray]:
    """
    Return a dictionary mapping Stan program variable names to the
    corresponding numpy.ndarray of inferred values.

    See Also
    --------
    CmdStanPathfinder.stan_variable
    CmdStanMCMC.stan_variables
    CmdStanMLE.stan_variables
    CmdStanVB.stan_variables
    CmdStanGQ.stan_variables
    CmdStanLaplace.stan_variables
    """
    return {
        name: self.stan_variable(name)
        for name in self._metadata.stan_vars
    }
|
||||
|
||||
def method_variables(self) -> Dict[str, np.ndarray]:
    """
    Returns a dictionary of all sampler variables, i.e., all
    output column names ending in `__`. Assumes that all variables
    are scalar variables where column name is variable name.
    Maps each column name to a numpy.ndarray (draws x chains x 1)
    containing per-draw diagnostic values.
    """
    self._assemble_draws()
    result = {}
    for name, var in self._metadata.method_vars.items():
        result[name] = var.extract_reshape(self._draws)
    return result
|
||||
|
||||
def draws(self) -> np.ndarray:
    """
    Return the draws from the approximate posterior distribution as a
    2-D numpy.ndarray with shape (draws, parameters), loading them from
    the output CSV on first access.
    """
    self._assemble_draws()
    return self._draws
|
||||
|
||||
def __getattr__(self, attr: str) -> np.ndarray:
    """Attribute-style shortcut: ``fit.a`` == ``fit.stan_variable("a")``."""
    if not attr.startswith("_"):
        try:
            return self.stan_variable(attr)
        except ValueError as e:
            # pylint: disable=raise-missing-from
            raise AttributeError(*e.args)
    raise AttributeError(f"Unknown variable name {attr}")
|
||||
|
||||
def __getstate__(self) -> dict:
    """Return the mapping of objects to serialize with pickle.

    See https://docs.python.org/3/library/pickle.html#object.__getstate__
    for details. Draws are assembled first so the posterior sample is
    loaded prior to serialization.
    """
    self._assemble_draws()
    return self.__dict__
|
||||
|
||||
@property
def metadata(self) -> InferenceMetadata:
    """
    Object holding the CmdStan configuration together with the names
    and structure of the inference-method and model output variables.
    """
    return self._metadata
|
||||
|
||||
@property
def column_names(self) -> Tuple[str, ...]:
    """
    Names of all outputs from the sampler, comprising sampler parameters
    and all components of all model parameters, transformed parameters,
    and quantities of interest. Corresponds to Stan CSV file header row,
    with names munged to array notation, e.g. `beta[1]` not `beta.1`.
    """
    return self._metadata.cmdstan_config['column_names']  # type: ignore
|
||||
|
||||
@property
def is_resampled(self) -> bool:
    """
    True when the draws were resampled from several Pathfinder
    approximations, False otherwise.
    """
    cfg = self._metadata.cmdstan_config
    # CmdStan config values may be ints or the string 'true'
    multi_path = cfg.get("num_paths", 4) > 1
    psis = cfg.get('psis_resample', 1) in (1, 'true')
    with_lp = cfg.get('calculate_lp', 1) in (1, 'true')
    return multi_path and psis and with_lp  # type: ignore
|
||||
|
||||
def save_csvfiles(self, dir: Optional[str] = None) -> None:
    """
    Move output CSV files to specified directory. If files were
    written to the temporary session directory, clean filename.
    E.g., save 'bernoulli-201912081451-1-5nm6as7u.csv' as
    'bernoulli-201912081451-1.csv'.

    Delegates to the underlying RunSet.

    :param dir: directory path

    See Also
    --------
    stanfit.RunSet.save_csvfiles
    cmdstanpy.from_csv
    """
    self._runset.save_csvfiles(dir)
|
||||
307
.venv/lib/python3.12/site-packages/cmdstanpy/stanfit/runset.py
Normal file
307
.venv/lib/python3.12/site-packages/cmdstanpy/stanfit/runset.py
Normal file
@ -0,0 +1,307 @@
|
||||
"""
|
||||
Container for the information used in a generic CmdStan run,
|
||||
such as file locations
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
from time import time
|
||||
from typing import List, Optional
|
||||
|
||||
from cmdstanpy import _TMPDIR
|
||||
from cmdstanpy.cmdstan_args import CmdStanArgs, Method
|
||||
from cmdstanpy.utils import get_logger
|
||||
|
||||
|
||||
class RunSet:
    """
    Encapsulates the configuration and results of a call to any CmdStan
    inference method. Records the method return code and locations of
    all console, error, and output files.

    RunSet objects are instantiated by the CmdStanModel class inference
    methods which validate all inputs, therefore "__init__" method skips
    input checks.
    """

    def __init__(
        self,
        args: CmdStanArgs,
        chains: int = 1,
        *,
        chain_ids: Optional[List[int]] = None,
        time_fmt: str = "%Y%m%d%H%M%S",
        one_process_per_chain: bool = True,
    ) -> None:
        """Initialize object (no input arg checks).

        :param args: validated CmdStan argument set
        :param chains: number of chains to run
        :param chain_ids: per-chain ids; defaults to ``1..chains``
        :param time_fmt: timestamp format used in output filenames
        :param one_process_per_chain: when True, one subprocess per chain;
            when False, a single process runs all chains (CmdStan >= 2.28)
        """
        self._args = args
        self._chains = chains
        self._one_process_per_chain = one_process_per_chain
        self._num_procs = chains if one_process_per_chain else 1
        # per-process bookkeeping; -1 means "not yet run"
        self._retcodes = [-1 for _ in range(self._num_procs)]
        self._timeout_flags = [False for _ in range(self._num_procs)]
        if chain_ids is None:
            chain_ids = [i + 1 for i in range(chains)]
        self._chain_ids = chain_ids

        if args.output_dir is not None:
            self._output_dir = args.output_dir
        else:
            # make a per-run subdirectory of our master temp directory
            self._output_dir = tempfile.mkdtemp(
                prefix=args.model_name, dir=_TMPDIR
            )

        # output files prefix: ``<model_name>-<YYYYMMDDHHMM>_<chain_id>``
        self._base_outfile = (
            f'{args.model_name}-{datetime.now().strftime(time_fmt)}'
        )
        # per-process outputs
        self._stdout_files = [''] * self._num_procs
        self._profile_files = [''] * self._num_procs  # optional
        if one_process_per_chain:
            for i in range(chains):
                # NOTE(review): stdout files are numbered by process index
                # ``i`` while profile files use ``chain_ids[i]`` — confirm
                # this asymmetry is intentional.
                self._stdout_files[i] = self.file_path("-stdout.txt", id=i)
                if args.save_profile:
                    self._profile_files[i] = self.file_path(
                        ".csv", extra="-profile", id=chain_ids[i]
                    )
        else:
            self._stdout_files[0] = self.file_path("-stdout.txt")
            if args.save_profile:
                self._profile_files[0] = self.file_path(
                    ".csv", extra="-profile"
                )

        # per-chain output files
        self._csv_files: List[str] = [''] * chains
        self._diagnostic_files = [''] * chains  # optional

        if chains == 1:
            self._csv_files[0] = self.file_path(".csv")
            if args.save_latent_dynamics:
                self._diagnostic_files[0] = self.file_path(
                    ".csv", extra="-diagnostic"
                )
        else:
            for i in range(chains):
                self._csv_files[i] = self.file_path(".csv", id=chain_ids[i])
                if args.save_latent_dynamics:
                    self._diagnostic_files[i] = self.file_path(
                        ".csv", extra="-diagnostic", id=chain_ids[i]
                    )

    def __repr__(self) -> str:
        # local renamed from ``repr`` to avoid shadowing the builtin
        desc = 'RunSet: chains={}, chain_ids={}, num_processes={}'.format(
            self._chains, self._chain_ids, self._num_procs
        )
        desc = '{}\n cmd (chain 1):\n\t{}'.format(desc, self.cmd(0))
        desc = '{}\n retcodes={}'.format(desc, self._retcodes)
        desc = f'{desc}\n per-chain output files (showing chain 1 only):'
        desc = '{}\n csv_file:\n\t{}'.format(desc, self._csv_files[0])
        if self._args.save_latent_dynamics:
            desc = '{}\n diagnostics_file:\n\t{}'.format(
                desc, self._diagnostic_files[0]
            )
        if self._args.save_profile:
            desc = '{}\n profile_file:\n\t{}'.format(
                desc, self._profile_files[0]
            )
        desc = '{}\n console_msgs (if any):\n\t{}'.format(
            desc, self._stdout_files[0]
        )
        return desc

    @property
    def model(self) -> str:
        """Stan model name."""
        return self._args.model_name

    @property
    def method(self) -> Method:
        """CmdStan method used to generate this fit."""
        return self._args.method

    @property
    def num_procs(self) -> int:
        """Number of processes run."""
        return self._num_procs

    @property
    def one_process_per_chain(self) -> bool:
        """
        When True, for each chain, call CmdStan in its own subprocess.
        When False, use CmdStan's `num_chains` arg to run parallel chains.
        Always True if CmdStan < 2.28.
        For CmdStan 2.28 and up, `sample` method determines value.
        """
        return self._one_process_per_chain

    @property
    def chains(self) -> int:
        """Number of chains."""
        return self._chains

    @property
    def chain_ids(self) -> List[int]:
        """Chain ids."""
        return self._chain_ids

    def cmd(self, idx: int) -> List[str]:
        """
        Assemble CmdStan invocation.
        When running parallel chains from single process (2.28 and up),
        specify CmdStan arg `num_chains` and leave chain idx off CSV files.
        """
        if self._one_process_per_chain:
            return self._args.compose_command(
                idx,
                csv_file=self.csv_files[idx],
                diagnostic_file=self.diagnostic_files[idx]
                if self._args.save_latent_dynamics
                else None,
                profile_file=self.profile_files[idx]
                if self._args.save_profile
                else None,
            )
        else:
            return self._args.compose_command(
                idx,
                csv_file=self.file_path('.csv'),
                diagnostic_file=self.file_path(".csv", extra="-diagnostic")
                if self._args.save_latent_dynamics
                else None,
                profile_file=self.file_path(".csv", extra="-profile")
                if self._args.save_profile
                else None,
            )

    @property
    def csv_files(self) -> List[str]:
        """List of paths to CmdStan output files."""
        return self._csv_files

    @property
    def stdout_files(self) -> List[str]:
        """
        List of paths to transcript of CmdStan messages sent to the console.
        Transcripts include config information, progress, and error messages.
        """
        return self._stdout_files

    def _check_retcodes(self) -> bool:
        """Returns ``True`` when all chains have retcode 0."""
        return all(code == 0 for code in self._retcodes)

    @property
    def diagnostic_files(self) -> List[str]:
        """List of paths to CmdStan hamiltonian diagnostic files."""
        return self._diagnostic_files

    @property
    def profile_files(self) -> List[str]:
        """List of paths to CmdStan profiler files."""
        return self._profile_files

    # pylint: disable=invalid-name
    def file_path(
        self, suffix: str, *, extra: str = "", id: Optional[int] = None
    ) -> str:
        """Build an output-file path from the run's base filename.

        :param suffix: filename suffix, e.g. ``".csv"``
        :param extra: optional tag inserted before the suffix
        :param id: optional chain/process id, appended as ``"_<id>"``
        """
        if id is not None:
            suffix = f"_{id}{suffix}"
        return os.path.join(
            self._output_dir, f"{self._base_outfile}{extra}{suffix}"
        )

    def _retcode(self, idx: int) -> int:
        """Get retcode for process[idx]."""
        return self._retcodes[idx]

    def _set_retcode(self, idx: int, val: int) -> None:
        """Set retcode at process[idx] to val."""
        self._retcodes[idx] = val

    def _set_timeout_flag(self, idx: int, val: bool) -> None:
        """Set timeout_flag at process[idx] to val."""
        self._timeout_flags[idx] = val

    def get_err_msgs(self) -> str:
        """Checks console messages for each CmdStan run."""
        msgs = []
        for i in range(self._num_procs):
            if (
                os.path.exists(self._stdout_files[i])
                and os.stat(self._stdout_files[i]).st_size > 0
            ):
                if self._args.method == Method.OPTIMIZE:
                    msgs.append('console log output:\n')
                    # fix: read process i's transcript; previously this
                    # branch always re-read process 0's file inside the loop
                    with open(self._stdout_files[i], 'r') as fd:
                        msgs.append(fd.read())
                else:
                    with open(self._stdout_files[i], 'r') as fd:
                        contents = fd.read()
                    # pattern matches initial "Exception" or "Error" msg
                    pat = re.compile(r'^E[rx].*$', re.M)
                    errors = re.findall(pat, contents)
                    if len(errors) > 0:
                        msgs.append('\n\t'.join(errors))
        return '\n'.join(msgs)

    def save_csvfiles(self, dir: Optional[str] = None) -> None:
        """
        Moves CSV files to specified directory.

        :param dir: directory path; defaults to the current directory

        See Also
        --------
        cmdstanpy.from_csv
        """
        if dir is None:
            dir = os.path.realpath('.')
        test_path = os.path.join(dir, str(time()))
        try:
            os.makedirs(dir, exist_ok=True)
            # probe that the target directory is writable
            with open(test_path, 'w'):
                pass
            os.remove(test_path)  # cleanup
        except (IOError, OSError, PermissionError) as exc:
            raise RuntimeError('Cannot save to path: {}'.format(dir)) from exc

        for i in range(self.chains):
            if not os.path.exists(self._csv_files[i]):
                raise ValueError(
                    'Cannot access CSV file {}'.format(self._csv_files[i])
                )

            to_path = os.path.join(dir, os.path.basename(self._csv_files[i]))
            if os.path.exists(to_path):
                raise ValueError(
                    'File exists, not overwriting: {}'.format(to_path)
                )
            try:
                get_logger().debug(
                    'saving tmpfile: "%s" as: "%s"', self._csv_files[i], to_path
                )
                shutil.move(self._csv_files[i], to_path)
                self._csv_files[i] = to_path
            except (IOError, OSError, PermissionError) as e:
                raise ValueError(
                    'Cannot save to file: {}'.format(to_path)
                ) from e

    def raise_for_timeouts(self) -> None:
        """Raise TimeoutError when any process exceeded its timeout."""
        if any(self._timeout_flags):
            raise TimeoutError(
                f"{sum(self._timeout_flags)} of {self.num_procs} processes "
                "timed out"
            )
|
||||
240
.venv/lib/python3.12/site-packages/cmdstanpy/stanfit/vb.py
Normal file
240
.venv/lib/python3.12/site-packages/cmdstanpy/stanfit/vb.py
Normal file
@ -0,0 +1,240 @@
|
||||
"""Container for the results of running autodiff variational inference"""
|
||||
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from cmdstanpy.cmdstan_args import Method
|
||||
from cmdstanpy.utils import scan_variational_csv
|
||||
from cmdstanpy.utils.logging import get_logger
|
||||
|
||||
from .metadata import InferenceMetadata
|
||||
from .runset import RunSet
|
||||
|
||||
|
||||
class CmdStanVB:
    """
    Container for outputs from CmdStan variational run.
    Created by :meth:`CmdStanModel.variational`.
    """

    def __init__(self, runset: RunSet) -> None:
        """Initialize object; requires a RunSet from a variational run."""
        if not runset.method == Method.VARIATIONAL:
            raise ValueError(
                'Wrong runset method, expecting variational inference, '
                'found method {}'.format(runset.method)
            )
        self.runset = runset
        # parse the single output CSV into metadata plus estimate arrays
        self._set_variational_attrs(runset.csv_files[0])

    def __repr__(self) -> str:
        # NOTE(review): local name ``repr`` shadows the builtin; harmless
        # here but worth renaming in a future cleanup.
        repr = 'CmdStanVB: model={}{}'.format(
            self.runset.model, self.runset._args.method_args.compose(0, cmd=[])
        )
        repr = '{}\n csv_file:\n\t{}\n output_file:\n\t{}'.format(
            repr,
            '\n\t'.join(self.runset.csv_files),
            '\n\t'.join(self.runset.stdout_files),
        )
        # TODO - diagnostic, profiling files
        return repr

    def __getattr__(self, attr: str) -> Union[np.ndarray, float]:
        """Synonymous with ``fit.stan_variable(attr)"""
        # leading-underscore names are never Stan variables
        if attr.startswith("_"):
            raise AttributeError(f"Unknown variable name {attr}")
        try:
            return self.stan_variable(attr)
        except ValueError as e:
            # pylint: disable=raise-missing-from
            raise AttributeError(*e.args)

    def _set_variational_attrs(self, sample_csv_0: str) -> None:
        # Parse the CSV once; cache metadata and the estimate attributes
        # that the public properties below expose.
        meta = scan_variational_csv(sample_csv_0)
        self._metadata = InferenceMetadata(meta)
        # these assignments don't grant type information
        self._column_names: Tuple[str, ...] = meta['column_names']
        self._eta: float = meta['eta']
        self._variational_mean: np.ndarray = meta['variational_mean']
        self._variational_sample: np.ndarray = meta['variational_sample']

    @property
    def columns(self) -> int:
        """
        Total number of information items returned by sampler.
        Includes approximation information and names of model parameters
        and computed quantities.
        """
        return len(self._column_names)

    @property
    def column_names(self) -> Tuple[str, ...]:
        """
        Names of information items returned by sampler for each draw.
        Includes approximation information and names of model parameters
        and computed quantities.
        """
        return self._column_names

    @property
    def eta(self) -> float:
        """
        Step size scaling parameter 'eta'
        """
        return self._eta

    @property
    def variational_params_np(self) -> np.ndarray:
        """
        Returns inferred parameter means as numpy array.
        """
        return self._variational_mean

    @property
    def variational_params_pd(self) -> pd.DataFrame:
        """
        Returns inferred parameter means as pandas DataFrame.
        """
        return pd.DataFrame([self._variational_mean], columns=self.column_names)

    @property
    def variational_params_dict(self) -> Dict[str, np.ndarray]:
        """Returns inferred parameter means as Dict."""
        return OrderedDict(zip(self.column_names, self._variational_mean))

    @property
    def metadata(self) -> InferenceMetadata:
        """
        Returns object which contains CmdStan configuration as well as
        information about the names and structure of the inference method
        and model output variables.
        """
        return self._metadata

    def stan_variable(
        self, var: str, *, mean: Optional[bool] = None
    ) -> Union[np.ndarray, float]:
        """
        Return a numpy.ndarray which contains the estimates for the
        named Stan program variable where the dimensions of the
        numpy.ndarray match the shape of the Stan program variable, with
        a leading axis added for the number of draws from the variational
        approximation.

        * If the variable is a scalar variable, the return array has shape
          ( draws, ).
        * If the variable is a vector, the return array has shape
          ( draws, len(vector))
        * If the variable is a matrix, the return array has shape
          ( draws, size(dim 1), size(dim 2) )
        * If the variable is an array with N dimensions, the return array
          has shape ( draws, size(dim 1), ..., size(dim N))

        This functionality is also available via a shortcut using ``.`` -
        writing ``fit.a`` is a synonym for ``fit.stan_variable("a")``

        :param var: variable name

        :param mean: if True, return the variational mean. Otherwise,
            return the variational sample. The default behavior will
            change in a future release to return the variational sample.

        See Also
        --------
        CmdStanVB.stan_variables
        CmdStanMCMC.stan_variable
        CmdStanMLE.stan_variable
        CmdStanPathfinder.stan_variable
        CmdStanGQ.stan_variable
        CmdStanLaplace.stan_variable
        """
        # TODO(2.0): remove None case, make default `False`
        if mean is None:
            get_logger().warning(
                "The default behavior of CmdStanVB.stan_variable() "
                "will change in a future release to return the "
                "variational sample, rather than the mean.\n"
                "To maintain the current behavior, pass the argument "
                "mean=True"
            )
            mean = True
        if mean:
            draws = self._variational_mean
        else:
            draws = self._variational_sample

        try:
            out: np.ndarray = self._metadata.stan_vars[var].extract_reshape(
                draws
            )
            # scalars are currently unwrapped to a Python float
            # TODO(2.0): remove
            if out.shape == () or out.shape == (1,):
                if mean:
                    get_logger().warning(
                        "The default behavior of "
                        "CmdStanVB.stan_variable(mean=True) will change in a "
                        "future release to always return a numpy.ndarray, even "
                        "for scalar variables."
                    )
                return out.item()  # type: ignore
            return out
        except KeyError:
            # pylint: disable=raise-missing-from
            raise ValueError(
                f'Unknown variable name: {var}\n'
                'Available variables are '
                + ", ".join(self._metadata.stan_vars.keys())
            )

    def stan_variables(
        self, *, mean: Optional[bool] = None
    ) -> Dict[str, Union[np.ndarray, float]]:
        """
        Return a dictionary mapping Stan program variables names
        to the corresponding numpy.ndarray containing the inferred values.

        :param mean: forwarded to :meth:`stan_variable` for every variable;
            see that method for the pending default-behavior change.

        See Also
        --------
        CmdStanVB.stan_variable
        CmdStanMCMC.stan_variables
        CmdStanMLE.stan_variables
        CmdStanGQ.stan_variables
        CmdStanPathfinder.stan_variables
        CmdStanLaplace.stan_variables
        """
        result = {}
        for name in self._metadata.stan_vars:
            result[name] = self.stan_variable(name, mean=mean)
        return result

    @property
    def variational_sample(self) -> np.ndarray:
        """Returns the set of approximate posterior output draws."""
        return self._variational_sample

    @property
    def variational_sample_pd(self) -> pd.DataFrame:
        """
        Returns the set of approximate posterior output draws as
        a pandas DataFrame.
        """
        return pd.DataFrame(self._variational_sample, columns=self.column_names)

    def save_csvfiles(self, dir: Optional[str] = None) -> None:
        """
        Move output CSV files to specified directory. If files were
        written to the temporary session directory, clean filename.
        E.g., save 'bernoulli-201912081451-1-5nm6as7u.csv' as
        'bernoulli-201912081451-1.csv'.

        :param dir: directory path

        See Also
        --------
        stanfit.RunSet.save_csvfiles
        cmdstanpy.from_csv
        """
        self.runset.save_csvfiles(dir)
|
||||
147
.venv/lib/python3.12/site-packages/cmdstanpy/utils/__init__.py
Normal file
147
.venv/lib/python3.12/site-packages/cmdstanpy/utils/__init__.py
Normal file
@ -0,0 +1,147 @@
|
||||
"""
|
||||
Utility functions
|
||||
"""
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
|
||||
from .cmdstan import (
|
||||
EXTENSION,
|
||||
cmdstan_path,
|
||||
cmdstan_version,
|
||||
cmdstan_version_before,
|
||||
cxx_toolchain_path,
|
||||
get_latest_cmdstan,
|
||||
install_cmdstan,
|
||||
set_cmdstan_path,
|
||||
set_make_env,
|
||||
validate_cmdstan_path,
|
||||
validate_dir,
|
||||
wrap_url_progress_hook,
|
||||
)
|
||||
from .command import do_command, returncode_msg
|
||||
from .data_munging import build_xarray_data, flatten_chains
|
||||
from .filesystem import (
|
||||
SanitizedOrTmpFilePath,
|
||||
create_named_text_file,
|
||||
pushd,
|
||||
windows_short_path,
|
||||
)
|
||||
from .json import write_stan_json
|
||||
from .logging import get_logger
|
||||
from .stancsv import (
|
||||
check_sampler_csv,
|
||||
parse_rdump_value,
|
||||
read_metric,
|
||||
rload,
|
||||
scan_column_names,
|
||||
scan_config,
|
||||
scan_hmc_params,
|
||||
scan_optimize_csv,
|
||||
scan_sampler_csv,
|
||||
scan_sampling_iters,
|
||||
scan_variational_csv,
|
||||
scan_warmup_iters,
|
||||
)
|
||||
|
||||
|
||||
def show_versions(output: bool = True) -> str:
    """Prints out system and dependency information for debugging"""

    import importlib
    import locale
    import struct

    deps_info = []
    # platform / interpreter details; best effort only
    try:
        (sysname, _, release, _, machine, processor) = platform.uname()
        deps_info += [
            ("python", sys.version),
            ("python-bits", struct.calcsize("P") * 8),
            ("OS", f"{sysname}"),
            ("OS-release", f"{release}"),
            ("machine", f"{machine}"),
            ("processor", f"{processor}"),
            ("byteorder", f"{sys.byteorder}"),
            ("LC_ALL", f'{os.environ.get("LC_ALL", "None")}'),
            ("LANG", f'{os.environ.get("LANG", "None")}'),
            ("LOCALE", f"{locale.getlocale()}"),
        ]
    # pylint: disable=broad-except
    except Exception:
        pass

    # CmdStan toolchain location and version, when installed
    try:
        deps_info.append(('cmdstan_folder', cmdstan_path()))
        deps_info.append(('cmdstan', str(cmdstan_version())))
    # pylint: disable=broad-except
    except Exception:
        deps_info.append(('cmdstan', 'NOT FOUND'))

    for module in ('cmdstanpy', 'pandas', 'xarray', 'tqdm', 'numpy'):
        try:
            if module in sys.modules:
                mod = sys.modules[module]
            else:
                mod = importlib.import_module(module)
        # pylint: disable=broad-except
        except Exception:
            deps_info.append((module, None))
        else:
            try:
                deps_info.append((module, mod.__version__))  # type: ignore
            # pylint: disable=broad-except
            except Exception:
                deps_info.append((module, "installed"))

    out = 'INSTALLED VERSIONS\n---------------------\n'
    out += ''.join(f'{k}: {info}\n' for k, info in deps_info)
    if not output:
        return out
    print(out)
    return " "
|
||||
|
||||
|
||||
# Public API of cmdstanpy.utils; keep entries sorted alphabetically.
__all__ = [
    'EXTENSION',
    'SanitizedOrTmpFilePath',
    'build_xarray_data',
    'check_sampler_csv',
    'cmdstan_path',
    'cmdstan_version',
    'cmdstan_version_before',
    'create_named_text_file',
    'cxx_toolchain_path',
    'do_command',
    'flatten_chains',
    'get_latest_cmdstan',
    'get_logger',
    'install_cmdstan',
    'parse_rdump_value',
    'pushd',
    'read_metric',
    'returncode_msg',
    'rload',
    'scan_column_names',
    'scan_config',
    'scan_hmc_params',
    'scan_optimize_csv',
    'scan_sampler_csv',
    'scan_sampling_iters',
    'scan_variational_csv',
    'scan_warmup_iters',
    'set_cmdstan_path',
    'set_make_env',
    'show_versions',
    'validate_cmdstan_path',
    'validate_dir',
    'windows_short_path',
    'wrap_url_progress_hook',
    'write_stan_json',
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
554
.venv/lib/python3.12/site-packages/cmdstanpy/utils/cmdstan.py
Normal file
554
.venv/lib/python3.12/site-packages/cmdstanpy/utils/cmdstan.py
Normal file
@ -0,0 +1,554 @@
|
||||
"""
|
||||
Utilities for finding and installing CmdStan
|
||||
"""
|
||||
import os
|
||||
import platform
|
||||
import subprocess
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
from typing import Callable, Dict, Optional, Tuple, Union
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from cmdstanpy import _DOT_CMDSTAN
|
||||
|
||||
from .. import progress as progbar
|
||||
from .logging import get_logger
|
||||
|
||||
# Suffix for CmdStan executables: '.exe' on Windows, empty elsewhere.
EXTENSION = '.exe' if platform.system() == 'Windows' else ''
|
||||
|
||||
|
||||
def determine_linux_arch() -> str:
    """Map ``platform.machine()`` to the CmdStan Linux release arch tag.

    Returns '' for machines (e.g. x86_64) that use the default tarball.
    """
    machine = platform.machine()
    if machine == "armv7l":
        # Telling armel and armhf apart is nontrivial
        # c.f. https://forums.raspberrypi.com/viewtopic.php?t=20873
        readelf = subprocess.run(
            ["readelf", "-A", "/proc/self/exe"],
            check=True,
            stdout=subprocess.PIPE,
            text=True,
        )
        return "armel" if "Tag_ABI_VFP_args" in readelf.stdout else "armhf"
    arch_by_machine = {
        "aarch64": "arm64",
        "mips64": "mips64el",
        "ppc64el": "ppc64el",
        "ppc64le": "ppc64el",
        "s390x": "s390x",
    }
    return arch_by_machine.get(machine, "")
|
||||
|
||||
|
||||
def get_download_url(version: str) -> str:
    """Return the GitHub release-tarball URL for a CmdStan version.

    Honors the ``CMDSTAN_ARCH`` environment variable; on Linux the
    architecture is auto-detected when unset. ``CMDSTAN_ARCH=false``
    forces the default tarball.
    """
    arch = os.environ.get("CMDSTAN_ARCH", "")
    if not arch and platform.system() == "Linux":
        arch = determine_linux_arch()

    if arch and arch.lower() != "false":
        tarball = f'cmdstan-{version}-linux-{arch}.tar.gz'
    else:
        tarball = f'cmdstan-{version}.tar.gz'

    base = 'https://github.com/stan-dev/cmdstan/releases/download'
    return f'{base}/v{version}/{tarball}'
|
||||
|
||||
|
||||
def validate_dir(install_dir: str) -> None:
    """Check that specified install directory exists, can write.

    Creates the directory (including parents) when it does not exist.

    :param install_dir: target installation directory
    :raises ValueError: if the path exists but is not a directory, or if
        the directory cannot be created or written to.
    """
    if not os.path.exists(install_dir):
        try:
            os.makedirs(install_dir)
        except (IOError, OSError, PermissionError) as e:
            raise ValueError(
                'Cannot create directory: {}'.format(install_dir)
            ) from e
    else:
        if not os.path.isdir(install_dir):
            raise ValueError(
                'File exists, should be a directory: {}'.format(install_dir)
            )
        try:
            # fix: probe file must live inside install_dir; previously the
            # probe was created in the current working directory, so this
            # checked the wrong directory's writability.
            probe = os.path.join(install_dir, 'tmp_test_w')
            with open(probe, 'w'):
                pass
            os.remove(probe)  # cleanup
        except OSError as e:
            raise ValueError(
                'Cannot write files to directory {}'.format(install_dir)
            ) from e
|
||||
|
||||
|
||||
def get_latest_cmdstan(cmdstan_dir: str) -> Optional[str]:
    """
    Given a valid directory path, find all installed CmdStan versions
    and return highest (i.e., latest) version number.

    Assumes directory consists of CmdStan releases, created by
    function `install_cmdstan`, and therefore dirnames have format
    "cmdstan-<maj>.<min>.<patch>" or "cmdstan-<maj>.<min>.<patch>-rc<num>",
    which is CmdStan release practice as of v 2.24.

    :param cmdstan_dir: directory to scan for ``cmdstan-*`` subdirectories
    :return: name of the latest release directory, or None if none found
    """
    versions = [
        name[8:]
        for name in os.listdir(cmdstan_dir)
        if os.path.isdir(os.path.join(cmdstan_dir, name))
        and name.startswith('cmdstan-')
    ]
    if len(versions) == 0:
        return None
    if len(versions) == 1:
        return 'cmdstan-' + versions[0]
    # we can only compare numeric versions
    versions = [v for v in versions if v[0].isdigit() and v.count('.') == 2]
    # fix: the numeric filter may remove every candidate; previously this
    # fell through to an IndexError on the empty list below
    if not versions:
        return None
    # munge rc for sort, e.g. 2.25.0-rc1 -> 2.25.-99
    for i in range(len(versions)):  # pylint: disable=C0200
        if '-rc' in versions[i]:
            comps = versions[i].split('-rc')
            mmp = comps[0].split('.')
            rc_num = comps[1]
            patch = str(int(rc_num) - 100)
            versions[i] = '.'.join([mmp[0], mmp[1], patch])

    versions.sort(key=lambda s: list(map(int, s.split('.'))))
    latest = versions[-1]

    # unmunge as needed
    mmp = latest.split('.')
    if int(mmp[2]) < 0:
        rc_num = str(int(mmp[2]) + 100)
        mmp[2] = "0-rc" + rc_num
        latest = '.'.join(mmp)

    return 'cmdstan-' + latest
|
||||
|
||||
|
||||
def validate_cmdstan_path(path: str) -> None:
    """
    Validate that CmdStan directory exists and binaries have been built.

    :param path: path to a CmdStan installation directory.
    :raises ValueError: if the directory does not exist, or if the
        ``stanc`` compiler binary is missing from its ``bin`` subdir.
    """
    if not os.path.isdir(path):
        raise ValueError(f'No CmdStan directory, path {path} does not exist.')
    # presence of the stanc binary is used as the marker of a built install
    if not os.path.exists(os.path.join(path, 'bin', 'stanc' + EXTENSION)):
        # fixed typo in user-facing message: "installataion" -> "installation"
        raise ValueError(
            f'CmdStan installation missing binaries in {path}/bin. '
            'Re-install cmdstan by running command "install_cmdstan '
            '--overwrite", or Python code "import cmdstanpy; '
            'cmdstanpy.install_cmdstan(overwrite=True)"'
        )
|
||||
|
||||
|
||||
def set_cmdstan_path(path: str) -> None:
    """
    Set the CmdStan directory path after validating it.

    Records the path in environment variable ``CMDSTAN`` so that
    subsequent lookups use this installation.
    """
    # raises ValueError when path is not a built CmdStan installation
    validate_cmdstan_path(path)
    os.environ['CMDSTAN'] = path
|
||||
|
||||
|
||||
def set_make_env(make: str) -> None:
    """
    Record the make program to use in the MAKE environment variable.
    """
    os.environ['MAKE'] = make
|
||||
|
||||
|
||||
def cmdstan_path() -> str:
    """
    Validate, then return CmdStan directory path.

    Uses environment variable ``CMDSTAN`` when set and non-empty,
    otherwise the latest release found under ``$HOME/.cmdstan``; the
    discovered path is cached back into ``os.environ['CMDSTAN']``.

    :return: normalized path to a validated CmdStan installation.
    :raises ValueError: if no CmdStan installation can be found.
    """
    cmdstan = ''
    if 'CMDSTAN' in os.environ and len(os.environ['CMDSTAN']) > 0:
        cmdstan = os.environ['CMDSTAN']
    else:
        cmdstan_dir = os.path.expanduser(os.path.join('~', _DOT_CMDSTAN))
        if not os.path.exists(cmdstan_dir):
            # fixed missing space between concatenated message fragments,
            # which previously rendered as '"install_cmdstan"or ...'
            raise ValueError(
                'No CmdStan installation found, run command '
                '"install_cmdstan" or (re)activate your conda environment!'
            )
        latest_cmdstan = get_latest_cmdstan(cmdstan_dir)
        if latest_cmdstan is None:
            raise ValueError(
                'No CmdStan installation found, run command '
                '"install_cmdstan" or (re)activate your conda environment!'
            )
        cmdstan = os.path.join(cmdstan_dir, latest_cmdstan)
        os.environ['CMDSTAN'] = cmdstan
    validate_cmdstan_path(cmdstan)
    return os.path.normpath(cmdstan)
|
||||
|
||||
|
||||
def cmdstan_version() -> Optional[Tuple[int, ...]]:
    """
    Parses version string out of CmdStan makefile variable CMDSTAN_VERSION,
    returns Tuple(Major, minor).

    If CmdStan installation is not found or cannot parse version from makefile
    logs warning and returns None. Lenient behavior required for CI tests,
    per comment:
    https://github.com/stan-dev/cmdstanpy/pull/321#issuecomment-733817554

    :return: (major, minor) version tuple, or None when the installation
        or its makefile cannot be found or parsed.
    """
    try:
        makefile = os.path.join(cmdstan_path(), 'makefile')
    except ValueError as e:
        # cmdstan_path raises ValueError when no installation is found
        get_logger().info('No CmdStan installation found.')
        get_logger().debug("%s", e)
        return None

    if not os.path.exists(makefile):
        get_logger().info(
            'CmdStan installation %s missing makefile, cannot get version.',
            cmdstan_path(),
        )
        return None

    with open(makefile, 'r') as fd:
        contents = fd.read()

    # version is declared as a makefile variable, e.g.
    # "CMDSTAN_VERSION := 2.30.1"
    start_idx = contents.find('CMDSTAN_VERSION := ')
    if start_idx < 0:
        get_logger().info(
            'Cannot parse version from makefile: %s.',
            makefile,
        )
        return None

    start_idx += len('CMDSTAN_VERSION := ')
    end_idx = contents.find('\n', start_idx)

    version = contents[start_idx:end_idx]
    splits = version.split('.')
    if len(splits) != 3:
        get_logger().info(
            'Cannot parse version, expected "<major>.<minor>.<patch>", '
            'found: "%s".',
            version,
        )
        return None
    # patch component is deliberately discarded
    return tuple(int(x) for x in splits[0:2])
|
||||
|
||||
|
||||
def cmdstan_version_before(
    major: int, minor: int, info: Optional[Dict[str, str]] = None
) -> bool:
    """
    Check that CmdStan version is less than Major.minor version.

    :param major: Major version number
    :param minor: Minor version number
    :param info: optional mapping which may supply keys
        ``stan_version_major`` / ``stan_version_minor``; when absent,
        the version is read from the local CmdStan installation.

    :return: True if the version precedes ``major.minor``; False
        otherwise, and also False when no version can be determined.
    """
    if info is not None and 'stan_version_major' in info:
        current = (
            int(info['stan_version_major']),
            int(info['stan_version_minor']),
        )
    else:
        current = cmdstan_version()
    if current is None:
        get_logger().info(
            'Cannot determine whether version is before %d.%d.', major, minor
        )
        return False
    # lexicographic comparison on (major, minor)
    return current < (major, minor)
|
||||
|
||||
|
||||
def cxx_toolchain_path(
    version: Optional[str] = None, install_dir: Optional[str] = None
) -> Tuple[str, ...]:
    """
    Validate, then activate C++ toolchain directory path.

    Locates an RTools toolchain, either from environment variable
    ``CMDSTAN_TOOLCHAIN`` or by searching well-known install locations,
    prepends its compiler and tool directories to ``PATH``, and returns
    them.

    :param version: optional RTools version string; '35', '3.5' or '3'
        select the older RTools35 directory layout during the search.
    :param install_dir: optional additional directory to search.
    :return: tuple of (compiler_path, tool_path).
    :raises RuntimeError: when not running on Windows.
    :raises TypeError: when version is not a string.
    :raises ValueError: when no toolchain installation is found.
    """
    if platform.system() != 'Windows':
        raise RuntimeError(
            'Functionality is currently only supported on Windows'
        )
    if version is not None and not isinstance(version, str):
        raise TypeError('Format version number as a string')
    logger = get_logger()
    if 'CMDSTAN_TOOLCHAIN' in os.environ:
        # user-specified toolchain: validate its directory layout
        toolchain_root = os.environ['CMDSTAN_TOOLCHAIN']
        if os.path.exists(os.path.join(toolchain_root, 'mingw64')):
            # RTools40 layout: mingw64/mingw32 compilers, tools in usr/bin
            compiler_path = os.path.join(
                toolchain_root,
                'mingw64' if (sys.maxsize > 2**32) else 'mingw32',
                'bin',
            )
            if os.path.exists(compiler_path):
                tool_path = os.path.join(toolchain_root, 'usr', 'bin')
                if not os.path.exists(tool_path):
                    tool_path = ''
                    compiler_path = ''
                    # typo fixed: "installion" -> "installation"
                    logger.warning(
                        'Found invalid installation for RTools40 on %s',
                        toolchain_root,
                    )
                    toolchain_root = ''
            else:
                compiler_path = ''
                logger.warning(
                    'Found invalid installation for RTools40 on %s',
                    toolchain_root,
                )
                toolchain_root = ''

        elif os.path.exists(os.path.join(toolchain_root, 'mingw_64')):
            # RTools35 layout: mingw_64/mingw_32 compilers, tools in bin
            compiler_path = os.path.join(
                toolchain_root,
                'mingw_64' if (sys.maxsize > 2**32) else 'mingw_32',
                'bin',
            )
            if os.path.exists(compiler_path):
                tool_path = os.path.join(toolchain_root, 'bin')
                if not os.path.exists(tool_path):
                    tool_path = ''
                    compiler_path = ''
                    logger.warning(
                        'Found invalid installation for RTools35 on %s',
                        toolchain_root,
                    )
                    toolchain_root = ''
            else:
                compiler_path = ''
                logger.warning(
                    'Found invalid installation for RTools35 on %s',
                    toolchain_root,
                )
                toolchain_root = ''
    else:
        # no override: probe candidate locations in preference order
        rtools40_home = os.environ.get('RTOOLS40_HOME')
        cmdstan_dir = os.path.expanduser(os.path.join('~', _DOT_CMDSTAN))
        for toolchain_root in (
            ([rtools40_home] if rtools40_home is not None else [])
            + (
                [
                    os.path.join(install_dir, 'RTools40'),
                    os.path.join(install_dir, 'RTools35'),
                    os.path.join(install_dir, 'RTools30'),
                    os.path.join(install_dir, 'RTools'),
                ]
                if install_dir is not None
                else []
            )
            + [
                os.path.join(cmdstan_dir, 'RTools40'),
                os.path.join(os.path.abspath("/"), "RTools40"),
                os.path.join(cmdstan_dir, 'RTools35'),
                os.path.join(os.path.abspath("/"), "RTools35"),
                os.path.join(cmdstan_dir, 'RTools'),
                os.path.join(os.path.abspath("/"), "RTools"),
                os.path.join(os.path.abspath("/"), "RBuildTools"),
            ]
        ):
            compiler_path = ''
            tool_path = ''

            if os.path.exists(toolchain_root):
                if version not in ('35', '3.5', '3'):
                    # default: expect the RTools40 directory layout
                    compiler_path = os.path.join(
                        toolchain_root,
                        'mingw64' if (sys.maxsize > 2**32) else 'mingw32',
                        'bin',
                    )
                    if os.path.exists(compiler_path):
                        tool_path = os.path.join(toolchain_root, 'usr', 'bin')
                        if not os.path.exists(tool_path):
                            tool_path = ''
                            compiler_path = ''
                            logger.warning(
                                'Found invalid installation for RTools40 on %s',
                                toolchain_root,
                            )
                            toolchain_root = ''
                        else:
                            break
                    else:
                        compiler_path = ''
                        logger.warning(
                            'Found invalid installation for RTools40 on %s',
                            toolchain_root,
                        )
                        toolchain_root = ''
                else:
                    # explicit request for the older RTools35 layout
                    compiler_path = os.path.join(
                        toolchain_root,
                        'mingw_64' if (sys.maxsize > 2**32) else 'mingw_32',
                        'bin',
                    )
                    if os.path.exists(compiler_path):
                        tool_path = os.path.join(toolchain_root, 'bin')
                        if not os.path.exists(tool_path):
                            tool_path = ''
                            compiler_path = ''
                            logger.warning(
                                'Found invalid installation for RTools35 on %s',
                                toolchain_root,
                            )
                            toolchain_root = ''
                        else:
                            break
                    else:
                        compiler_path = ''
                        logger.warning(
                            'Found invalid installation for RTools35 on %s',
                            toolchain_root,
                        )
                        toolchain_root = ''
            else:
                toolchain_root = ''

    if not toolchain_root:
        raise ValueError(
            'no RTools toolchain installation found, '
            'run command line script '
            '"python -m cmdstanpy.install_cxx_toolchain"'
        )
    logger.info('Add C++ toolchain to $PATH: %s', toolchain_root)
    # de-duplicate while preserving order; toolchain dirs take precedence
    os.environ['PATH'] = ';'.join(
        list(
            OrderedDict.fromkeys(
                [compiler_path, tool_path] + os.getenv('PATH', '').split(';')
            )
        )
    )
    return compiler_path, tool_path
|
||||
|
||||
|
||||
def install_cmdstan(
    version: Optional[str] = None,
    dir: Optional[str] = None,
    overwrite: bool = False,
    compiler: bool = False,
    progress: bool = False,
    verbose: bool = False,
    cores: int = 1,
    *,
    interactive: bool = False,
) -> bool:
    """
    Download and install a CmdStan release from GitHub. Downloads the release
    tar.gz file to temporary storage. Retries GitHub requests in order
    to allow for transient network outages. Builds CmdStan executables
    and tests the compiler by building example model ``bernoulli.stan``.

    :param version: CmdStan version string, e.g. "2.29.2".
        Defaults to latest CmdStan release.
        If ``git`` is installed, a git tag or branch of stan-dev/cmdstan
        can be specified, e.g. "git:develop".

    :param dir: Path to install directory. Defaults to hidden directory
        ``$HOME/.cmdstan``.
        If no directory is specified and the above directory does not
        exist, directory ``$HOME/.cmdstan`` will be created and populated.

    :param overwrite: Boolean value; when ``True``, will overwrite and
        rebuild an existing CmdStan installation. Default is ``False``.

    :param compiler: Boolean value; when ``True`` on WINDOWS ONLY, use the
        C++ compiler from the ``install_cxx_toolchain`` command or install
        one if none is found.

    :param progress: Boolean value; when ``True``, show a progress bar for
        downloading and unpacking CmdStan. Default is ``False``.

    :param verbose: Boolean value; when ``True``, show console output from all
        installation steps, i.e., download, build, and test CmdStan release.
        Default is ``False``.
    :param cores: Integer, number of cores to use in the ``make`` command.
        Default is 1 core.

    :param interactive: Boolean value; if true, ignore all other arguments
        to this function and run in an interactive mode, prompting the user
        to provide the other information manually through the standard input.

        This flag should only be used in interactive environments,
        e.g. on the command line.

    :return: Boolean value; ``True`` for success.
    """
    logger = get_logger()
    try:
        # imported inside the function rather than at module level —
        # presumably to avoid an import cycle; confirm before moving
        from ..install_cmdstan import (
            InstallationSettings,
            InteractiveSettings,
            run_install,
        )

        args: Union[InstallationSettings, InteractiveSettings]

        if interactive:
            # interactive mode ignores the explicit arguments; warn if any
            # were supplied with non-default values
            if any(
                [
                    version,
                    dir,
                    overwrite,
                    compiler,
                    progress,
                    verbose,
                    cores != 1,
                ]
            ):
                logger.warning(
                    "Interactive installation requested but other arguments"
                    " were used.\n\tThese values will be ignored!"
                )
            args = InteractiveSettings()
        else:
            args = InstallationSettings(
                version=version,
                overwrite=overwrite,
                verbose=verbose,
                compiler=compiler,
                progress=progress,
                dir=dir,
                cores=cores,
            )
        run_install(args)
    # pylint: disable=broad-except
    except Exception as e:
        # failures are reported via the return value, never raised
        logger.warning('CmdStan installation failed.\n%s', str(e))
        return False

    # derive install folder name; git refs contain ':' and possibly '/',
    # which must be replaced to form a filesystem-safe directory name
    if 'git:' in args.version:
        folder = f"cmdstan-{args.version.replace(':', '-').replace('/', '_')}"
    else:
        folder = f"cmdstan-{args.version}"
    set_cmdstan_path(os.path.join(args.dir, folder))

    return True
|
||||
|
||||
|
||||
@progbar.wrap_callback
def wrap_url_progress_hook() -> Optional[Callable[[int, int, int], None]]:
    """Sets up tqdm callback for url downloads."""
    progress_bar: tqdm = tqdm(
        unit='B',
        unit_scale=True,
        unit_divisor=1024,
        colour='blue',
        leave=False,
    )

    def download_progress_hook(
        count: int, block_size: int, total_size: int
    ) -> None:
        # first invocation: the reported total becomes the bar's total
        if progress_bar.total is None:
            progress_bar.total = total_size
            progress_bar.reset()
        seen = count * block_size
        progress_bar.update(seen - progress_bar.n)
        if progress_bar.n >= total_size:
            progress_bar.close()

    return download_progress_hook
|
||||
@ -0,0 +1,94 @@
|
||||
"""
|
||||
Run commands and handle returncodes
|
||||
"""
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import Callable, List, Optional, TextIO
|
||||
|
||||
from .filesystem import pushd
|
||||
from .logging import get_logger
|
||||
|
||||
|
||||
def do_command(
    cmd: List[str],
    cwd: Optional[str] = None,
    *,
    fd_out: Optional[TextIO] = sys.stdout,
    pbar: Optional[Callable[[str], None]] = None,
) -> None:
    """
    Run command as subprocess, polls process output pipes and
    either streams outputs to supplied output stream or sends
    each line (stripped) to the supplied progress bar callback hook.

    Raises ``RuntimeError`` on non-zero return code or exception ``OSError``.

    :param cmd: command and args.
    :param cwd: directory in which to run command, if unspecified,
        run command in the current working directory.
    :param fd_out: when supplied, streams to this output stream,
        else writes to sys.stdout.
    :param pbar: optional callback hook to tqdm, which takes
        single ``str`` argument, see:
        https://github.com/tqdm/tqdm#hooks-and-callbacks.

    """
    get_logger().debug('cmd: %s\ncwd: %s', ' '.join(cmd), cwd)
    try:
        # NB: Using this rather than cwd arg to Popen due to windows behavior
        with pushd(cwd if cwd is not None else '.'):
            # TODO: replace with subprocess.run in later Python versions?
            proc = subprocess.Popen(
                cmd,
                bufsize=1,
                stdin=subprocess.DEVNULL,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,  # avoid buffer overflow
                env=os.environ,
                universal_newlines=True,
            )
            # stream output line by line while the process is alive
            while proc.poll() is None:
                if proc.stdout is not None:
                    line = proc.stdout.readline()
                    if fd_out is not None:
                        fd_out.write(line)
                    if pbar is not None:
                        pbar(line.strip())

            # drain anything left in the pipe after process exit
            stdout, _ = proc.communicate()
            if stdout:
                if len(stdout) > 0:
                    if fd_out is not None:
                        fd_out.write(stdout)
                    if pbar is not None:
                        pbar(stdout.strip())

            if proc.returncode != 0:  # throw RuntimeError + msg
                serror = ''
                try:
                    # os.strerror may reject out-of-range / unknown codes
                    serror = os.strerror(proc.returncode)
                except (ArithmeticError, ValueError):
                    pass
                msg = 'Command {}\n\t{} {}'.format(
                    cmd, returncode_msg(proc.returncode), serror
                )
                raise RuntimeError(msg)
    except OSError as e:
        # e.g. the executable itself could not be launched
        msg = 'Command: {}\nfailed with error {}\n'.format(cmd, str(e))
        raise RuntimeError(msg) from e
|
||||
|
||||
|
||||
def returncode_msg(retcode: int) -> str:
    """
    Map a process return code to a short human-readable description.

    Negative codes and codes above 128 are reported as signal
    terminations; 127 means the program was not found.
    """
    if retcode < 0:
        return f'terminated by signal {-retcode}'
    if retcode <= 125:
        return 'error during processing'
    if retcode == 126:  # shouldn't happen
        return ''
    if retcode == 127:
        return 'program not found'
    return f'terminated by signal {retcode - 128}'
|
||||
@ -0,0 +1,44 @@
|
||||
"""
|
||||
Common functions for reshaping numpy arrays
|
||||
"""
|
||||
from typing import Hashable, MutableMapping, Tuple
|
||||
|
||||
import numpy as np
|
||||
import stanio
|
||||
|
||||
|
||||
def flatten_chains(draws_array: np.ndarray) -> np.ndarray:
    """
    Flatten a 3D array of draws X chains X variable into 2D array
    where all chains are concatenated into a single column.

    :param draws_array: 3D array of draws
    :raises ValueError: if the input array is not 3-dimensional
    """
    if draws_array.ndim != 3:
        raise ValueError(
            'Expecting 3D array, found array with {} dims'.format(
                draws_array.ndim
            )
        )

    num_draws, num_chains, num_cols = draws_array.shape
    # column-major reshape stacks the chains within each variable column
    return draws_array.reshape((num_draws * num_chains, num_cols), order='F')
|
||||
|
||||
|
||||
def build_xarray_data(
    data: MutableMapping[Hashable, Tuple[Tuple[str, ...], np.ndarray]],
    var: stanio.Variable,
    drawset: np.ndarray,
) -> None:
    """
    Adds Stan variable name, labels, and values to a dictionary
    that will be used to construct an xarray DataSet.
    """
    # leading labels are always draw/chain, followed by one label per
    # dimension of the Stan variable
    labels: Tuple[str, ...] = ('draw', 'chain') + tuple(
        f"{var.name}_dim_{i}" for i in range(len(var.dimensions))
    )
    data[var.name] = (labels, var.extract_reshape(drawset))
|
||||
236
.venv/lib/python3.12/site-packages/cmdstanpy/utils/filesystem.py
Normal file
236
.venv/lib/python3.12/site-packages/cmdstanpy/utils/filesystem.py
Normal file
@ -0,0 +1,236 @@
|
||||
"""
|
||||
Utilities for interacting with the filesystem on multiple platforms
|
||||
"""
|
||||
import contextlib
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import shutil
|
||||
import tempfile
|
||||
from typing import Any, Iterator, List, Mapping, Optional, Tuple, Union
|
||||
|
||||
from cmdstanpy import _TMPDIR
|
||||
|
||||
from .json import write_stan_json
|
||||
from .logging import get_logger
|
||||
|
||||
EXTENSION = '.exe' if platform.system() == 'Windows' else ''
|
||||
|
||||
|
||||
def windows_short_path(path: str) -> str:
    """
    Gets the short path name of a given long path.
    http://stackoverflow.com/a/23598461/200291

    On non-Windows platforms, returns the path

    If (base)path does not exist, function raises RuntimeError

    :param path: file or directory path, possibly containing characters
        (e.g. spaces) that external tools cannot handle.
    :return: equivalent Windows 8.3 "short" path on Windows, otherwise
        the input path unchanged.
    """
    if platform.system() != 'Windows':
        return path

    # split off the final filename component when path looks like a file
    # (existing file, or non-directory with an extension)
    if os.path.isfile(path) or (
        not os.path.isdir(path) and os.path.splitext(path)[1] != ''
    ):
        base_path, file_name = os.path.split(path)
    else:
        base_path, file_name = path, ''

    if not os.path.exists(base_path):
        raise RuntimeError(
            'Windows short path function needs a valid directory. '
            'Base directory does not exist: "{}"'.format(base_path)
        )

    # Windows-only imports, deferred so the module loads on other platforms
    import ctypes
    from ctypes import wintypes

    # pylint: disable=invalid-name
    _GetShortPathNameW = (
        ctypes.windll.kernel32.GetShortPathNameW  # type: ignore
    )

    _GetShortPathNameW.argtypes = [
        wintypes.LPCWSTR,
        wintypes.LPWSTR,
        wintypes.DWORD,
    ]
    _GetShortPathNameW.restype = wintypes.DWORD

    # call twice: first call reports the required buffer size, second
    # call (with a large-enough buffer) fills it
    output_buf_size = 0
    while True:
        output_buf = ctypes.create_unicode_buffer(output_buf_size)
        needed = _GetShortPathNameW(base_path, output_buf, output_buf_size)
        if output_buf_size >= needed:
            short_base_path = output_buf.value
            break
        else:
            output_buf_size = needed

    short_path = (
        os.path.join(short_base_path, file_name)
        if file_name
        else short_base_path
    )
    return short_path
|
||||
|
||||
|
||||
def create_named_text_file(
    dir: str, prefix: str, suffix: str, name_only: bool = False
) -> str:
    """
    Create a named unique file, return filename.
    Flag 'name_only' will create then delete the tmp file;
    this lets us create filename args for commands which
    disallow overwriting existing files (e.g., 'stansummary').
    """
    tmp = tempfile.NamedTemporaryFile(
        mode='w+', prefix=prefix, suffix=suffix, dir=dir, delete=name_only
    )
    try:
        # when name_only is True, closing deletes the file, leaving
        # only the unique name behind
        return tmp.name
    finally:
        tmp.close()
|
||||
|
||||
|
||||
@contextlib.contextmanager
def pushd(new_dir: str) -> Iterator[None]:
    """Temporarily change the working directory (pushd/popd semantics)."""
    saved_dir = os.getcwd()
    os.chdir(new_dir)
    try:
        yield
    finally:
        # restore even when the body raises
        os.chdir(saved_dir)
|
||||
|
||||
|
||||
def _temp_single_json(
    data: Union[str, os.PathLike, Mapping[str, Any], None]
) -> Iterator[Optional[str]]:
    """Yield a path to JSON input data, writing a tempfile when needed."""
    # pass-through cases: nothing to write
    if data is None:
        yield None
        return
    if isinstance(data, (str, os.PathLike)):
        yield str(data)
        return

    # mapping: serialize to a temporary Stan JSON file, remove afterwards
    tmp_path = create_named_text_file(dir=_TMPDIR, prefix='', suffix='.json')
    get_logger().debug('input tempfile: %s', tmp_path)
    write_stan_json(tmp_path, data)
    try:
        yield tmp_path
    finally:
        with contextlib.suppress(PermissionError):
            os.remove(tmp_path)


# public context-manager form of the generator above
temp_single_json = contextlib.contextmanager(_temp_single_json)
|
||||
|
||||
|
||||
def _temp_multiinput(
    input: Union[str, os.PathLike, Mapping[str, Any], List[Any], None],
    base: int = 1,
) -> Iterator[Optional[str]]:
    """
    Yield a filename argument for possibly-multiple JSON inputs.

    :param input: a single path/mapping/None (delegated to
        ``_temp_single_json``) or a list of per-chain inits.
    :param base: starting index for the per-chain file-name suffix.
    """
    if isinstance(input, list):
        # most complicated case: list of inits
        # for multiple chains, we need to create multiple files
        # which look like somename_{i}.json and then pass somename.json
        # to CmdStan

        mother_file = create_named_text_file(
            dir=_TMPDIR, prefix='', suffix='.json', name_only=True
        )
        new_files = [
            os.path.splitext(mother_file)[0] + f'_{i+base}.json'
            for i in range(len(input))
        ]
        for init, file in zip(input, new_files):
            if isinstance(init, dict):
                write_stan_json(file, init)
            elif isinstance(init, str):
                # an existing file: copy it under the per-chain name
                shutil.copy(init, file)
            else:
                raise ValueError(
                    'A list of inits must contain dicts or strings, not'
                    + str(type(init))
                )
        try:
            yield mother_file
        finally:
            # best-effort cleanup of the per-chain files
            for file in new_files:
                with contextlib.suppress(PermissionError):
                    os.remove(file)
    else:
        yield from _temp_single_json(input)
|
||||
|
||||
|
||||
@contextlib.contextmanager
def temp_inits(
    inits: Union[
        str, os.PathLike, Mapping[str, Any], float, int, List[Any], None
    ],
    *,
    allow_multiple: bool = True,
    id: int = 1,
) -> Iterator[Union[str, float, int, None]]:
    """
    Context manager resolving initializations to a CmdStan argument.

    Numbers pass through unchanged; other inputs are turned into
    temporary JSON file(s) whose path is yielded.
    """
    if isinstance(inits, (float, int)):
        yield inits
        return
    if not allow_multiple:
        if isinstance(inits, list):
            raise ValueError('Expected single initialization, got list')
        yield from _temp_single_json(inits)
        return
    yield from _temp_multiinput(inits, base=id)
|
||||
|
||||
|
||||
class SanitizedOrTmpFilePath:
    """
    Context manager for tmpfiles, handles special characters in filepath.

    Yields a ``(path, is_changed)`` pair: when the given path contains
    whitespace (or '~' on non-Windows systems) the file is copied into a
    clean temporary directory which is removed on exit; otherwise the
    original path is passed through unchanged.
    """

    # characters considered unsafe in a path, per platform
    UNIXISH_PATTERN = re.compile(r"[\s~]")
    WINDOWS_PATTERN = re.compile(r"\s")

    @classmethod
    def _has_special_chars(cls, file_path: str) -> bool:
        # platform decides which pattern applies
        if platform.system() == "Windows":
            return bool(cls.WINDOWS_PATTERN.search(file_path))
        return bool(cls.UNIXISH_PATTERN.search(file_path))

    def __init__(self, file_path: str):
        # _tmpdir is non-None only when a sanitized copy was made
        self._tmpdir = None

        if self._has_special_chars(os.path.abspath(file_path)):
            base_path, file_name = os.path.split(os.path.abspath(file_path))
            os.makedirs(base_path, exist_ok=True)
            try:
                # on Windows, the 8.3 short name may already be clean
                short_base_path = windows_short_path(base_path)
                if os.path.exists(short_base_path):
                    file_path = os.path.join(short_base_path, file_name)
            except RuntimeError:
                pass

        # re-check: the short-path substitution may not have helped
        if self._has_special_chars(os.path.abspath(file_path)):
            tmpdir = tempfile.mkdtemp()
            if self._has_special_chars(tmpdir):
                raise RuntimeError(
                    'Unable to generate temporary path without spaces or '
                    'special characters! \n Please move your stan file to a '
                    'location without spaces or special characters.'
                )

            _, path = tempfile.mkstemp(suffix='.stan', dir=tmpdir)

            shutil.copy(file_path, path)
            self._path = path
            self._tmpdir = tmpdir
        else:
            self._path = file_path

    def __enter__(self) -> Tuple[str, bool]:
        # second element reports whether a temporary copy is in use
        return self._path, self._tmpdir is not None

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:  # type: ignore
        if self._tmpdir:
            shutil.rmtree(self._tmpdir, ignore_errors=True)
|
||||
@ -0,0 +1,6 @@
|
||||
"""
|
||||
Delegated to stanio - https://github.com/WardBrian/stanio
|
||||
"""
|
||||
from stanio import write_stan_json
|
||||
|
||||
__all__ = ['write_stan_json']
|
||||
@ -0,0 +1,25 @@
|
||||
"""
|
||||
CmdStanPy logging
|
||||
"""
|
||||
import functools
|
||||
import logging
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
def get_logger() -> logging.Logger:
    """Return the shared 'cmdstanpy' logger, configuring it on first use."""
    logger = logging.getLogger('cmdstanpy')
    if not logger.handlers:
        # the logger itself passes everything through; the default
        # console handler filters at INFO and above
        logger.setLevel(logging.DEBUG)
        console = logging.StreamHandler()
        console.setLevel(logging.INFO)
        console.setFormatter(
            logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                "%H:%M:%S",
            )
        )
        logger.addHandler(console)
    return logger
|
||||
486
.venv/lib/python3.12/site-packages/cmdstanpy/utils/stancsv.py
Normal file
486
.venv/lib/python3.12/site-packages/cmdstanpy/utils/stancsv.py
Normal file
@ -0,0 +1,486 @@
|
||||
"""
|
||||
Utility functions for reading the Stan CSV format
|
||||
"""
|
||||
import json
|
||||
import math
|
||||
import re
|
||||
from typing import Any, Dict, List, MutableMapping, Optional, TextIO, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from cmdstanpy import _CMDSTAN_SAMPLING, _CMDSTAN_THIN, _CMDSTAN_WARMUP
|
||||
|
||||
|
||||
def check_sampler_csv(
    path: str,
    is_fixed_param: bool = False,
    iter_sampling: Optional[int] = None,
    iter_warmup: Optional[int] = None,
    save_warmup: bool = False,
    thin: Optional[int] = None,
) -> Dict[str, Any]:
    """
    Capture essential config, shape from stan_csv file.

    Scans the file, then cross-checks the metadata found there against
    the expected run configuration, raising ValueError on any mismatch.

    :param path: Stan CSV output file.
    :param is_fixed_param: True when produced by the fixed-param sampler.
    :param iter_sampling: expected number of sampling iterations.
    :param iter_warmup: expected number of warmup iterations.
    :param save_warmup: True when warmup draws should be present.
    :param thin: expected thinning interval.
    :return: metadata dict produced by ``scan_sampler_csv``.
    :raises ValueError: when the file disagrees with the expected config.
    """
    meta = scan_sampler_csv(path, is_fixed_param)
    if thin is None:
        thin = _CMDSTAN_THIN
    elif thin > _CMDSTAN_THIN:
        # non-default thinning must be recorded in the file and match
        if 'thin' not in meta:
            raise ValueError(
                'bad Stan CSV file {}, '
                'config error, expected thin = {}'.format(path, thin)
            )
        if meta['thin'] != thin:
            raise ValueError(
                'bad Stan CSV file {}, '
                'config error, expected thin = {}, found {}'.format(
                    path, thin, meta['thin']
                )
            )
    draws_sampling = iter_sampling
    if draws_sampling is None:
        draws_sampling = _CMDSTAN_SAMPLING
    draws_warmup = iter_warmup
    if draws_warmup is None:
        draws_warmup = _CMDSTAN_WARMUP
    # thinning reduces the number of retained draws
    draws_warmup = int(math.ceil(draws_warmup / thin))
    draws_sampling = int(math.ceil(draws_sampling / thin))
    if meta['draws_sampling'] != draws_sampling:
        raise ValueError(
            'bad Stan CSV file {}, expected {} draws, found {}'.format(
                path, draws_sampling, meta['draws_sampling']
            )
        )
    if save_warmup:
        if not ('save_warmup' in meta and meta['save_warmup'] in (1, 'true')):
            raise ValueError(
                'bad Stan CSV file {}, '
                'config error, expected save_warmup = 1'.format(path)
            )
        if meta['draws_warmup'] != draws_warmup:
            raise ValueError(
                'bad Stan CSV file {}, '
                'expected {} warmup draws, found {}'.format(
                    path, draws_warmup, meta['draws_warmup']
                )
            )
    return meta
|
||||
|
||||
|
||||
def scan_sampler_csv(path: str, is_fixed_param: bool = False) -> Dict[str, Any]:
    """
    Process sampler stan_csv output file line by line.

    :param path: Stan CSV output file.
    :param is_fixed_param: when True, skip the warmup / HMC-adaptation
        sections which the fixed-param sampler does not emit.
    :return: metadata dict accumulated by the scan_* helpers.
    :raises ValueError: if the file cannot be parsed.
    """
    # renamed local from ``dict`` to avoid shadowing the builtin
    meta: Dict[str, Any] = {}
    lineno = 0
    with open(path, 'r') as fd:
        try:
            lineno = scan_config(fd, meta, lineno)
            lineno = scan_column_names(fd, meta, lineno)
            if not is_fixed_param:
                lineno = scan_warmup_iters(fd, meta, lineno)
                lineno = scan_hmc_params(fd, meta, lineno)
            lineno = scan_sampling_iters(fd, meta, lineno, is_fixed_param)
        except ValueError as e:
            raise ValueError("Error in reading csv file: " + path) from e
    return meta
|
||||
|
||||
|
||||
def scan_optimize_csv(path: str, save_iters: bool = False) -> Dict[str, Any]:
    """
    Process optimizer stan_csv output file line by line.

    :param path: Stan CSV output file.
    :param save_iters: when True, also collect every iteration's
        estimates under key 'all_iters'.
    :return: metadata dict, including the final estimate under key 'mle'.
    """
    dict: Dict[str, Any] = {}
    lineno = 0
    # scan to find config, header, num saved iters
    with open(path, 'r') as fd:
        lineno = scan_config(fd, dict, lineno)
        lineno = scan_column_names(fd, dict, lineno)
        iters = 0
        for line in fd:
            iters += 1
    if save_iters:
        all_iters: np.ndarray = np.empty(
            (iters, len(dict['column_names'])), dtype=float, order='F'
        )
    # rescan to capture estimates
    with open(path, 'r') as fd:
        # skip past the header lines consumed above
        for i in range(lineno):
            fd.readline()
        for i in range(iters):
            line = fd.readline().strip()
            if len(line) < 1:
                raise ValueError(
                    'cannot parse CSV file {}, error at line {}'.format(
                        path, lineno + i
                    )
                )
            xs = line.split(',')
            if save_iters:
                all_iters[i, :] = [float(x) for x in xs]
            # last row holds the final (maximum likelihood) estimate
            if i == iters - 1:
                mle: np.ndarray = np.array(xs, dtype=float)
    # pylint: disable=possibly-used-before-assignment
    # NOTE(review): if the file contains no iteration rows, ``mle`` is
    # never bound and the next line raises NameError — confirm upstream
    # guarantees at least one row.
    dict['mle'] = mle
    if save_iters:
        dict['all_iters'] = all_iters
    return dict
|
||||
|
||||
|
||||
def scan_generic_csv(path: str) -> Dict[str, Any]:
    """
    Process laplace stan_csv output file line by line.

    :param path: Stan CSV output file.
    :return: dict of config values and column names from the header.
    """
    # renamed local from ``dict`` to avoid shadowing the builtin
    meta: Dict[str, Any] = {}
    lineno = 0
    with open(path, 'r') as fd:
        lineno = scan_config(fd, meta, lineno)
        lineno = scan_column_names(fd, meta, lineno)
    return meta
|
||||
|
||||
|
||||
def scan_variational_csv(path: str) -> Dict[str, Any]:
    """Process advi stan_csv output file line by line.

    Returns a dict with the parsed config entries, 'column_names',
    'eta' (when stepsize adaptation ran), 'variational_mean' (the
    first output row), and 'variational_sample' (the remaining rows).
    """
    # NOTE(review): local name 'dict' shadows the builtin
    dict: Dict[str, Any] = {}
    lineno = 0
    with open(path, 'r') as fd:
        lineno = scan_config(fd, dict, lineno)
        lineno = scan_column_names(fd, dict, lineno)
        # next comment block reports the adapted stepsize, if any
        line = fd.readline().lstrip(' #\t').rstrip()
        lineno += 1
        if line.startswith('Stepsize adaptation complete.'):
            line = fd.readline().lstrip(' #\t\n')
            lineno += 1
            if not line.startswith('eta'):
                raise ValueError(
                    'line {}: expecting eta, found:\n\t "{}"'.format(
                        lineno, line
                    )
                )
            _, eta = line.split('=')
            dict['eta'] = float(eta)
            # advance to the first data row (the variational mean)
            line = fd.readline().lstrip(' #\t\n')
            lineno += 1
        xs = line.split(',')
        variational_mean = [float(x) for x in xs]
    dict['variational_mean'] = np.array(variational_mean)
    # lineno now counts past the mean row, so skiprows drops the
    # header region and the mean, leaving only the sampled draws
    dict['variational_sample'] = pd.read_csv(
        path,
        comment='#',
        skiprows=lineno,
        header=None,
        float_precision='high',
    ).to_numpy()
    return dict
|
||||
|
||||
|
||||
def scan_config(fd: TextIO, config_dict: Dict[str, Any], lineno: int) -> int:
    """
    Scan initial stan_csv file comments lines and
    save non-default configuration information to config_dict.

    Leaves the stream positioned at the first non-comment line and
    returns the updated line count.
    """

    def _coerce(raw: str) -> Union[int, float, str]:
        # int first, then float, then true/false flags, else raw string
        try:
            return int(raw)
        except ValueError:
            pass
        try:
            return float(raw)
        except ValueError:
            pass
        if raw == "true":
            return 1
        if raw == "false":
            return 0
        return raw

    mark = fd.tell()
    text = fd.readline().strip()
    while len(text) > 0 and text.startswith('#'):
        lineno += 1
        if text.endswith('(Default)'):
            text = text.replace('(Default)', '')
        text = text.lstrip(' #\t')
        parts = text.split('=')
        if len(parts) == 2:
            key = parts[0].strip()
            if key == 'file':
                # only record non-csv 'file' entries (the input data file)
                if not parts[1].endswith('csv'):
                    config_dict['data_file'] = parts[1].strip()
            else:
                config_dict[key] = _coerce(parts[1].strip())
        mark = fd.tell()
        text = fd.readline().strip()
    fd.seek(mark)
    return lineno
|
||||
|
||||
|
||||
def scan_warmup_iters(
    fd: TextIO, config_dict: Dict[str, Any], lineno: int
) -> int:
    """
    Check warmup iterations, if any.

    Counts consecutive non-comment, non-empty lines, records the count
    under 'draws_warmup', and leaves the stream positioned at the first
    line that is not a warmup draw.  No-op unless 'save_warmup' is set.
    """
    if 'save_warmup' not in config_dict:
        return lineno
    num_draws = 0
    mark = fd.tell()
    text = fd.readline().strip()
    while text and not text.startswith('#'):
        num_draws += 1
        lineno += 1
        mark = fd.tell()
        text = fd.readline().strip()
    fd.seek(mark)
    config_dict['draws_warmup'] = num_draws
    return lineno
|
||||
|
||||
|
||||
def scan_column_names(
    fd: TextIO, config_dict: MutableMapping[str, Any], lineno: int
) -> int:
    """
    Process columns header, add to config_dict as 'column_names'

    Also stores the unmodified header line under 'raw_header'; the
    names are run through munge_varnames to rewrite dotted index
    labels into bracket notation.
    """
    header = fd.readline().strip()
    lineno += 1
    config_dict['raw_header'] = header
    config_dict['column_names'] = tuple(munge_varnames(header.split(',')))
    return lineno
|
||||
|
||||
|
||||
def munge_varname(name: str) -> str:
    """Convert one Stan CSV column label to bracketed index notation.

    Dotted indices become brackets: ``y.2.4`` -> ``y[2,4]``.  Tuple
    element labels are separated by ':'; each element is munged
    independently, e.g. ``ps.1:2`` -> ``ps[1]:2``.
    """
    if '.' not in name and ':' not in name:
        return name

    tuple_parts = name.split(':')
    for i, part in enumerate(tuple_parts):
        if '.' not in part:
            continue
        part = part.replace('.', '[', 1)
        part = part.replace('.', ',')
        part += ']'
        tuple_parts[i] = part

    # rejoin tuple elements with ':' (was '.'), preserving Stan's
    # tuple label separator instead of corrupting it to a dot
    return ':'.join(tuple_parts)
|
||||
|
||||
|
||||
def munge_varnames(names: List[str]) -> List[str]:
    """
    Change formatting for indices of container var elements
    from use of dot separator to array-like notation, e.g.,
    rewrite label ``y_forecast.2.4`` to ``y_forecast[2,4]``.

    :raises ValueError: if names is None.
    """
    if names is None:
        raise ValueError('missing argument "names"')
    return list(map(munge_varname, names))
|
||||
|
||||
|
||||
def scan_hmc_params(
    fd: TextIO, config_dict: Dict[str, Any], lineno: int
) -> int:
    """
    Scan step size, metric from stan_csv file comment lines.

    Expects the stream positioned at the '# Adaptation terminated'
    comment.  Validates the step size and consumes the mass matrix
    lines appropriate for config_dict['metric'] ('unit_e', 'diag_e',
    or 'dense_e').  Returns the updated line count; raises ValueError
    on any malformed section.
    """
    metric = config_dict['metric']
    line = fd.readline().strip()
    lineno += 1
    if not line == '# Adaptation terminated':
        raise ValueError(
            'line {}: expecting metric, found:\n\t "{}"'.format(lineno, line)
        )
    line = fd.readline().strip()
    lineno += 1
    label, step_size = line.split('=')
    if not label.startswith('# Step size'):
        raise ValueError(
            'line {}: expecting step size, '
            'found:\n\t "{}"'.format(lineno, line)
        )
    # validate only; the parsed step size value is not stored
    try:
        float(step_size.strip())
    except ValueError as e:
        raise ValueError(
            'line {}: invalid step size: {}'.format(lineno, step_size)
        ) from e
    # read ahead one line; may need to rewind for unit_e below
    before_metric = fd.tell()
    line = fd.readline().strip()
    lineno += 1
    if metric == 'unit_e':
        # unit metric: either a '# No free parameters' comment or no
        # metric section at all, in which case rewind the read-ahead
        if line.startswith("# No free parameters"):
            return lineno
        else:
            fd.seek(before_metric)
            return lineno - 1

    if not (
        (
            metric == 'diag_e'
            and line == '# Diagonal elements of inverse mass matrix:'
        )
        or (
            metric == 'dense_e' and line == '# Elements of inverse mass matrix:'
        )
    ):
        raise ValueError(
            'line {}: invalid or missing mass matrix '
            'specification'.format(lineno)
        )
    # first row of the matrix determines the problem dimension
    line = fd.readline().lstrip(' #\t')
    lineno += 1
    num_unconstrained_params = len(line.split(','))
    if metric == 'diag_e':
        # diagonal metric is a single row; done
        return lineno
    else:
        # dense metric: expect a square matrix, one row per parameter
        for _ in range(1, num_unconstrained_params):
            line = fd.readline().lstrip(' #\t')
            lineno += 1
            if len(line.split(',')) != num_unconstrained_params:
                raise ValueError(
                    'line {}: invalid or missing mass matrix '
                    'specification'.format(lineno)
                )
        return lineno
|
||||
|
||||
|
||||
def scan_sampling_iters(
    fd: TextIO, config_dict: Dict[str, Any], lineno: int, is_fixed_param: bool
) -> int:
    """
    Parse sampling iteration, save number of iterations to config_dict.
    Also save number of divergences, max_treedepth hits

    Records 'draws_sampling' and, unless is_fixed_param, totals
    'ct_divergences' and 'ct_max_treedepth'.  Leaves the stream at the
    first non-draw line; raises ValueError on a malformed row.
    """
    num_cols = len(config_dict['column_names'])
    num_draws = 0
    n_divergent = 0
    n_max_depth = 0
    if not is_fixed_param:
        div_col = config_dict['column_names'].index('divergent__')
        depth_col = config_dict['column_names'].index('treedepth__')
        depth_limit = config_dict['max_depth']

    mark = fd.tell()
    text = fd.readline().strip()
    while text and not text.startswith('#'):
        lineno += 1
        num_draws += 1
        fields = text.split(',')
        if len(fields) != num_cols:
            raise ValueError(
                'line {}: bad draw, expecting {} items, found {}\n'.format(
                    lineno, num_cols, len(fields)
                )
                + 'This error could be caused by running out of disk space.\n'
                'Try clearing up TEMP or setting output_dir to a path'
                ' on another drive.',
            )
        if not is_fixed_param:
            n_divergent += int(fields[div_col])
            if int(fields[depth_col]) == depth_limit:
                n_max_depth += 1
        mark = fd.tell()
        text = fd.readline().strip()

    fd.seek(mark)
    config_dict['draws_sampling'] = num_draws
    if not is_fixed_param:
        config_dict['ct_divergences'] = n_divergent
        config_dict['ct_max_treedepth'] = n_max_depth
    return lineno
|
||||
|
||||
|
||||
def read_metric(path: str) -> List[int]:
    """
    Read metric file in JSON or Rdump format.
    Return dimensions of entry "inv_metric".

    :raises ValueError: if the file has no usable "inv_metric" entry.
    """
    if path.endswith('.json'):
        with open(path, 'r') as fd:
            metric_dict = json.load(fd)
        if 'inv_metric' in metric_dict:
            dims_np: np.ndarray = np.asarray(metric_dict['inv_metric'])
            return list(dims_np.shape)
        raise ValueError(
            'metric file {}, bad or missing'
            ' entry "inv_metric"'.format(path)
        )
    # Rdump path: read_rdump_metric itself raises on a bad or missing
    # entry, so the former `if dims is None` check — always false for
    # the result of list(...) — was dead code and is removed.
    return list(read_rdump_metric(path))
|
||||
|
||||
|
||||
def read_rdump_metric(path: str) -> List[int]:
    """
    Find dimensions of variable named 'inv_metric' in Rdump data file.

    :raises ValueError: if the file cannot be parsed or has no ndarray
        entry named 'inv_metric'.
    """
    data = rload(path)
    inv_metric = None if data is None else data.get('inv_metric')
    if not isinstance(inv_metric, np.ndarray):
        raise ValueError(
            'metric file {}, bad or missing entry "inv_metric"'.format(path)
        )
    return list(inv_metric.shape)
|
||||
|
||||
|
||||
def rload(fname: str) -> Optional[Dict[str, Union[int, float, np.ndarray]]]:
    """Parse data and parameter variable values from an R dump format file.
    This parser only supports the subset of R dump data as described
    in the "Dump Data Format" section of the CmdStan manual, i.e.,
    scalar, vector, matrix, and array data types.

    Returns None when the file contains no '<-' assignments.
    """
    data_dict = {}
    with open(fname, 'r') as fd:
        lines = fd.readlines()
    # Variable data may span multiple lines, parse accordingly
    # find the first assignment ('<-'); bail out if there is none
    idx = 0
    while idx < len(lines) and '<-' not in lines[idx]:
        idx += 1
    if idx == len(lines):
        return None
    start_idx = idx
    idx += 1
    while True:
        # scan ahead to the next assignment (or EOF); everything in
        # between belongs to the current variable
        while idx < len(lines) and '<-' not in lines[idx]:
            idx += 1
        next_var = idx
        # collapse the variable's lines into one statement, then split
        # into name and value around the '<-' operator
        var_data = ''.join(lines[start_idx:next_var]).replace('\n', '')
        lhs, rhs = [item.strip() for item in var_data.split('<-')]
        lhs = lhs.replace('"', '')  # strip optional Jags double quotes
        rhs = rhs.replace('L', '')  # strip R long int qualifier
        data_dict[lhs] = parse_rdump_value(rhs)
        if idx == len(lines):
            break
        start_idx = next_var
        idx += 1
    return data_dict
|
||||
|
||||
|
||||
def parse_rdump_value(rhs: str) -> Union[int, float, np.ndarray]:
    """Process right hand side of Rdump variable assignment statement.

    The value is either a scalar, a vector ``c(...)``, or a
    ``structure(...)`` with an optional ``.Dim`` attribute; structure
    data is laid out in column-major (Fortran) order.
    """
    structure_pat = re.compile(
        r'structure\(\s*c\((?P<vals>[^)]*)\)'
        r'(,\s*\.Dim\s*=\s*c\s*\((?P<dims>[^)]*)\s*\))?\)'
    )
    result: Union[int, float, np.ndarray]
    try:
        if rhs.startswith('structure'):
            match = structure_pat.match(rhs)
            if match is None or match.group('vals') is None:
                raise ValueError(rhs)
            entries = [float(tok) for tok in match.group('vals').split(',')]
            result = np.array(entries, order='F')
            if match.group('dims') is not None:
                shape = [int(tok) for tok in match.group('dims').split(',')]
                result = np.array(entries).reshape(shape, order='F')
        elif rhs.startswith('c(') and rhs.endswith(')'):
            result = np.array([float(tok) for tok in rhs[2:-1].split(',')])
        elif '.' in rhs or 'e' in rhs:
            result = float(rhs)
        else:
            result = int(rhs)
    except TypeError as err:
        raise ValueError('bad value in Rdump file: {}'.format(rhs)) from err
    return result
|
||||
@ -0,0 +1 @@
|
||||
pip
|
||||
1815
.venv/lib/python3.12/site-packages/holidays-0.77.dist-info/METADATA
Normal file
1815
.venv/lib/python3.12/site-packages/holidays-0.77.dist-info/METADATA
Normal file
File diff suppressed because it is too large
Load Diff
1111
.venv/lib/python3.12/site-packages/holidays-0.77.dist-info/RECORD
Normal file
1111
.venv/lib/python3.12/site-packages/holidays-0.77.dist-info/RECORD
Normal file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,5 @@
|
||||
Wheel-Version: 1.0
|
||||
Generator: setuptools (80.9.0)
|
||||
Root-Is-Purelib: true
|
||||
Tag: py3-none-any
|
||||
|
||||
@ -0,0 +1,162 @@
|
||||
Aaron Picht
|
||||
Aart Goossens
|
||||
Abdelkhalek Boukli Hacene
|
||||
Abheelash Mishra
|
||||
Akos Furton
|
||||
Alejandro Antunes
|
||||
Aleksei Zhuchkov
|
||||
Alexander Schulze
|
||||
Alexandre Carvalho
|
||||
Alexei Mikhailov
|
||||
Anders Wenhaug
|
||||
Andrei Klimenko
|
||||
Andres Marrugo
|
||||
Ankush Kapoor
|
||||
Anon Kangpanich
|
||||
Anthony Rose
|
||||
Anton Daitche
|
||||
Arjun Anandkumar
|
||||
Arkadii Yakovets
|
||||
Artem Tserekh
|
||||
Ary Hauffe Neto
|
||||
Bailey Thompson
|
||||
Ben Collerson
|
||||
Ben Letham
|
||||
Benjamin Lucas Wacha
|
||||
Bernhard M. Wiedemann
|
||||
Carlos Rocha
|
||||
Chanran Kim
|
||||
Chris McKeague
|
||||
Chris Turra
|
||||
Christian Alexander
|
||||
Colin Watson
|
||||
Dan Gentry
|
||||
Daniel Musketa
|
||||
Daniël Niemeijer
|
||||
David Hotham
|
||||
Devaraj K
|
||||
Diego Rosaperez
|
||||
Diogo Rosa
|
||||
Dorian Monnier
|
||||
Douglas Franklin
|
||||
Eden Juscelino
|
||||
Edison Robles
|
||||
Edward Betts
|
||||
Eldar Mustafayev
|
||||
Emmanuel Arias
|
||||
Eugenio Panadero Maciá
|
||||
Fabian Affolter
|
||||
Felix Lee
|
||||
Filip Bednárik
|
||||
Firas Kafri
|
||||
Gabriel L Martinez
|
||||
Gabriel Trabanco
|
||||
Giedrius Mauza
|
||||
Gordon Inggs
|
||||
Greg Rafferty
|
||||
Győző Papp
|
||||
Heikki Orsila
|
||||
Henrik Sozzi
|
||||
Hiroki Kawahara
|
||||
Hugh McNamara
|
||||
Hugo van Kemenade
|
||||
Isabelle COWAN-BERGMAN
|
||||
Jacky Han
|
||||
Jacob Punter
|
||||
Jaemin Kim
|
||||
Jahir Fiquitiva
|
||||
Jakob M. Kjær
|
||||
Jan Pipek
|
||||
Jason Jensen
|
||||
Jeremy Chrimes
|
||||
Jerry Agbesi
|
||||
John Laswell
|
||||
Joost van Driel
|
||||
Jorge Cadena Argote
|
||||
Jose Riha
|
||||
Joshua Adelman
|
||||
Joël van Amerongen
|
||||
Julian Broudou
|
||||
Jung Dong Ho
|
||||
Justin Asfour
|
||||
Kamil Leduchowski
|
||||
Kate Golovanova
|
||||
Kelsey Karin Hawley
|
||||
Koert van der Veer
|
||||
Koki Nomura
|
||||
Kriti Birda
|
||||
Laurent Comparet
|
||||
Lucca Augusto
|
||||
Maina Kamau
|
||||
Malthe Borch
|
||||
Marek Šuppa
|
||||
Martin Becker
|
||||
Martin Thurau
|
||||
Matheus Oliveira
|
||||
Maurizio Montel
|
||||
Max Härtwig
|
||||
Michael Thessel
|
||||
Mike Borsetti
|
||||
Mike Polyakovsky
|
||||
Miroslav Šedivý
|
||||
Monde Sinxi
|
||||
Nalin Gupta
|
||||
Nataliia Dmytriievska
|
||||
Nate Harris
|
||||
Nathan Ell
|
||||
Nicholas Spagnoletti
|
||||
Nico Albers
|
||||
Olivier Iffrig
|
||||
Ondřej Nový
|
||||
Osayd Abdu
|
||||
Oscar Romero
|
||||
Pablo Merino
|
||||
Panpakorn Siripanich
|
||||
Patrick Nicholson
|
||||
Paulo Orrock
|
||||
Pavel Sofroniev
|
||||
Pedro Baptista
|
||||
Peter Zsak
|
||||
Pieter van der Westhuizen
|
||||
Piotr Staniów
|
||||
Platon Supranovich
|
||||
Prateekshit Jaiswal
|
||||
Raphael Borg Ellul Vincenti
|
||||
Raychel Mattheeuw
|
||||
Reinaldo Ramos
|
||||
Robert Frazier
|
||||
Robert Schmidtke
|
||||
Robert Tran
|
||||
Robin Emeršič
|
||||
Roshan Pradhan
|
||||
Ryan McCrory
|
||||
Sam Tregar
|
||||
Samman Sarkar
|
||||
Santiago Feliu
|
||||
Sergi Almacellas Abellana
|
||||
Sergio Mayoral Martinez
|
||||
Serhii Murza
|
||||
Shalom Donga
|
||||
Shaurya Uppal
|
||||
Sho Hirose
|
||||
Shreyansh Pande
|
||||
Shreyas Smarth
|
||||
Simon Gurcke
|
||||
Sindhura Kumbakonam Subramanian
|
||||
Sugato Ray
|
||||
Sylvain Pasche
|
||||
Sylvia van Os
|
||||
Søren Klintrup
|
||||
Takeshi Osoekawa
|
||||
Tasnim Nishat Islam
|
||||
Tewodros Meshesha
|
||||
Thomas Bøvith
|
||||
Tommy Sparber
|
||||
Tudor Văran
|
||||
Victor Luna
|
||||
Victor Miti
|
||||
Ville Skyttä
|
||||
Vilmos Prokaj
|
||||
Vu Nhat Chuong
|
||||
Wasif Shahzad
|
||||
Youhei Sakurai
|
||||
@ -0,0 +1,23 @@
|
||||
Copyright (c) Vacanza Team and individual contributors (see CONTRIBUTORS file)
|
||||
Copyright (c) dr-prodigy <dr.prodigy.github@gmail.com>, 2017-2023
|
||||
Copyright (c) ryanss <ryanssdev@icloud.com>, 2014-2017
|
||||
|
||||
All rights reserved.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
||||
@ -0,0 +1 @@
|
||||
holidays
|
||||
22
.venv/lib/python3.12/site-packages/holidays/__init__.py
Normal file
22
.venv/lib/python3.12/site-packages/holidays/__init__.py
Normal file
@ -0,0 +1,22 @@
|
||||
# holidays
|
||||
# --------
|
||||
# A fast, efficient Python library for generating country, province and state
|
||||
# specific sets of holidays on the fly. It aims to make determining whether a
|
||||
# specific date is a holiday as fast and flexible as possible.
|
||||
#
|
||||
# Authors: Vacanza Team and individual contributors (see CONTRIBUTORS file)
|
||||
# dr-prodigy <dr.prodigy.github@gmail.com> (c) 2017-2023
|
||||
# ryanss <ryanssdev@icloud.com> (c) 2014-2017
|
||||
# Website: https://github.com/vacanza/holidays
|
||||
# License: MIT (see LICENSE file)
|
||||
|
||||
# ruff: noqa: F403
|
||||
|
||||
from holidays.constants import *
|
||||
from holidays.holiday_base import *
|
||||
from holidays.registry import EntityLoader
|
||||
from holidays.utils import *
|
||||
from holidays.version import __version__ # noqa: F401
|
||||
|
||||
EntityLoader.load("countries", globals())
|
||||
EntityLoader.load("financial", globals())
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,28 @@
|
||||
# holidays
|
||||
# --------
|
||||
# A fast, efficient Python library for generating country, province and state
|
||||
# specific sets of holidays on the fly. It aims to make determining whether a
|
||||
# specific date is a holiday as fast and flexible as possible.
|
||||
#
|
||||
# Authors: Vacanza Team and individual contributors (see CONTRIBUTORS file)
|
||||
# dr-prodigy <dr.prodigy.github@gmail.com> (c) 2017-2023
|
||||
# ryanss <ryanssdev@icloud.com> (c) 2014-2017
|
||||
# Website: https://github.com/vacanza/holidays
|
||||
# License: MIT (see LICENSE file)
|
||||
|
||||
# ruff: noqa: F401
|
||||
|
||||
from holidays.calendars.balinese_saka import _BalineseSakaLunar
|
||||
from holidays.calendars.buddhist import _BuddhistLunisolar, _CustomBuddhistHolidays
|
||||
from holidays.calendars.chinese import _ChineseLunisolar, _CustomChineseHolidays
|
||||
from holidays.calendars.custom import _CustomCalendar
|
||||
from holidays.calendars.gregorian import GREGORIAN_CALENDAR
|
||||
from holidays.calendars.hebrew import _HebrewLunisolar
|
||||
from holidays.calendars.hindu import _CustomHinduHolidays, _HinduLunisolar
|
||||
from holidays.calendars.islamic import _CustomIslamicHolidays, _IslamicLunar
|
||||
from holidays.calendars.julian import JULIAN_CALENDAR
|
||||
from holidays.calendars.julian_revised import JULIAN_REVISED_CALENDAR
|
||||
from holidays.calendars.mongolian import _CustomMongolianHolidays, _MongolianLunisolar
|
||||
from holidays.calendars.persian import _Persian
|
||||
from holidays.calendars.sinhala import _SinhalaLunar, _CustomSinhalaHolidays
|
||||
from holidays.calendars.thai import _ThaiLunisolar, KHMER_CALENDAR, THAI_CALENDAR
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,112 @@
|
||||
# holidays
|
||||
# --------
|
||||
# A fast, efficient Python library for generating country, province and state
|
||||
# specific sets of holidays on the fly. It aims to make determining whether a
|
||||
# specific date is a holiday as fast and flexible as possible.
|
||||
#
|
||||
# Authors: Vacanza Team and individual contributors (see CONTRIBUTORS file)
|
||||
# dr-prodigy <dr.prodigy.github@gmail.com> (c) 2017-2023
|
||||
# ryanss <ryanssdev@icloud.com> (c) 2014-2017
|
||||
# Website: https://github.com/vacanza/holidays
|
||||
# License: MIT (see LICENSE file)
|
||||
|
||||
from datetime import date
|
||||
from typing import Optional
|
||||
|
||||
from holidays.calendars.gregorian import MAR, APR
|
||||
|
||||
NYEPI = "NYEPI"
|
||||
|
||||
|
||||
class _BalineseSakaLunar:
|
||||
"""
|
||||
Balinese Saka lunar calendar.
|
||||
|
||||
The Balinese saka calendar is one of two calendars used on the Indonesian island
|
||||
of Bali. Unlike the 210-day pawukon calendar, it is based on the phases of the Moon,
|
||||
and is approximately the same length as the tropical year (solar year, Gregorian year).
|
||||
https://en.wikipedia.org/wiki/Balinese_saka_calendar
|
||||
"""
|
||||
|
||||
NYEPI_DATES = {
|
||||
1983: (MAR, 15),
|
||||
1984: (MAR, 4),
|
||||
1985: (MAR, 22),
|
||||
1986: (MAR, 12),
|
||||
1987: (MAR, 31),
|
||||
1988: (MAR, 19),
|
||||
1989: (MAR, 9),
|
||||
1990: (MAR, 27),
|
||||
1991: (MAR, 17),
|
||||
1992: (MAR, 5),
|
||||
1993: (MAR, 24),
|
||||
1994: (MAR, 12),
|
||||
1995: (APR, 1),
|
||||
1996: (MAR, 21),
|
||||
1997: (APR, 9),
|
||||
1998: (MAR, 29),
|
||||
1999: (MAR, 18),
|
||||
2000: (APR, 4),
|
||||
2001: (MAR, 25),
|
||||
2002: (APR, 13),
|
||||
2003: (APR, 2),
|
||||
2004: (MAR, 22),
|
||||
2005: (MAR, 11),
|
||||
2006: (MAR, 30),
|
||||
2007: (MAR, 19),
|
||||
2008: (MAR, 7),
|
||||
2009: (MAR, 26),
|
||||
2010: (MAR, 16),
|
||||
2011: (MAR, 5),
|
||||
2012: (MAR, 23),
|
||||
2013: (MAR, 12),
|
||||
2014: (MAR, 31),
|
||||
2015: (MAR, 21),
|
||||
2016: (MAR, 9),
|
||||
2017: (MAR, 28),
|
||||
2018: (MAR, 17),
|
||||
2019: (MAR, 7),
|
||||
2020: (MAR, 25),
|
||||
2021: (MAR, 14),
|
||||
2022: (MAR, 3),
|
||||
2023: (MAR, 22),
|
||||
2024: (MAR, 11),
|
||||
2025: (MAR, 29),
|
||||
2026: (MAR, 19),
|
||||
2027: (MAR, 8),
|
||||
2028: (MAR, 26),
|
||||
2029: (MAR, 15),
|
||||
2030: (MAR, 5),
|
||||
2031: (MAR, 24),
|
||||
2032: (MAR, 12),
|
||||
2033: (MAR, 31),
|
||||
2034: (MAR, 20),
|
||||
2035: (MAR, 10),
|
||||
2036: (MAR, 28),
|
||||
2037: (MAR, 17),
|
||||
2038: (MAR, 6),
|
||||
2039: (MAR, 25),
|
||||
2040: (MAR, 14),
|
||||
2041: (MAR, 3),
|
||||
2042: (MAR, 22),
|
||||
2043: (MAR, 11),
|
||||
2044: (MAR, 29),
|
||||
2045: (MAR, 19),
|
||||
2046: (MAR, 8),
|
||||
2047: (MAR, 27),
|
||||
2048: (MAR, 15),
|
||||
2049: (MAR, 5),
|
||||
2050: (MAR, 24),
|
||||
}
|
||||
|
||||
def _get_holiday(self, holiday: str, year: int) -> Optional[date]:
|
||||
dt = getattr(self, f"{holiday}_DATES", {}).get(year, ())
|
||||
return date(year, *dt) if dt else None
|
||||
|
||||
def nyepi_date(self, year: int) -> Optional[date]:
|
||||
"""
|
||||
Data References:
|
||||
* [1983-2025](https://id.wikipedia.org/wiki/Indonesia_dalam_tahun_1983)
|
||||
* [2020-2050](https://web.archive.org/web/20240718011857/https://www.balitrips.com/balinese-temples-ceremony)
|
||||
"""
|
||||
return self._get_holiday(NYEPI, year)
|
||||
@ -0,0 +1,444 @@
|
||||
# holidays
|
||||
# --------
|
||||
# A fast, efficient Python library for generating country, province and state
|
||||
# specific sets of holidays on the fly. It aims to make determining whether a
|
||||
# specific date is a holiday as fast and flexible as possible.
|
||||
#
|
||||
# Authors: Vacanza Team and individual contributors (see CONTRIBUTORS file)
|
||||
# dr-prodigy <dr.prodigy.github@gmail.com> (c) 2017-2023
|
||||
# ryanss <ryanssdev@icloud.com> (c) 2014-2017
|
||||
# Website: https://github.com/vacanza/holidays
|
||||
# License: MIT (see LICENSE file)
|
||||
|
||||
from datetime import date
|
||||
from typing import Optional
|
||||
|
||||
from holidays.calendars.custom import _CustomCalendar
|
||||
from holidays.calendars.gregorian import MAY, JUN
|
||||
|
||||
VESAK = "VESAK"
|
||||
VESAK_MAY = "VESAK_MAY"
|
||||
|
||||
|
||||
class _BuddhistLunisolar:
|
||||
VESAK_DATES = {
|
||||
1901: (JUN, 1),
|
||||
1902: (MAY, 22),
|
||||
1903: (MAY, 11),
|
||||
1904: (MAY, 29),
|
||||
1905: (MAY, 18),
|
||||
1906: (MAY, 8),
|
||||
1907: (MAY, 26),
|
||||
1908: (MAY, 14),
|
||||
1909: (JUN, 2),
|
||||
1910: (MAY, 23),
|
||||
1911: (MAY, 13),
|
||||
1912: (MAY, 31),
|
||||
1913: (MAY, 20),
|
||||
1914: (MAY, 9),
|
||||
1915: (MAY, 28),
|
||||
1916: (MAY, 16),
|
||||
1917: (JUN, 4),
|
||||
1918: (MAY, 24),
|
||||
1919: (MAY, 14),
|
||||
1920: (JUN, 1),
|
||||
1921: (MAY, 22),
|
||||
1922: (MAY, 11),
|
||||
1923: (MAY, 30),
|
||||
1924: (MAY, 18),
|
||||
1925: (MAY, 7),
|
||||
1926: (MAY, 26),
|
||||
1927: (MAY, 15),
|
||||
1928: (JUN, 2),
|
||||
1929: (MAY, 23),
|
||||
1930: (MAY, 13),
|
||||
1931: (MAY, 31),
|
||||
1932: (MAY, 20),
|
||||
1933: (MAY, 9),
|
||||
1934: (MAY, 27),
|
||||
1935: (MAY, 17),
|
||||
1936: (JUN, 4),
|
||||
1937: (MAY, 24),
|
||||
1938: (MAY, 14),
|
||||
1939: (JUN, 2),
|
||||
1940: (MAY, 21),
|
||||
1941: (MAY, 10),
|
||||
1942: (MAY, 29),
|
||||
1943: (MAY, 18),
|
||||
1944: (MAY, 7),
|
||||
1945: (MAY, 26),
|
||||
1946: (MAY, 15),
|
||||
1947: (JUN, 3),
|
||||
1948: (MAY, 23),
|
||||
1949: (MAY, 12),
|
||||
1950: (MAY, 31),
|
||||
1951: (MAY, 20),
|
||||
1952: (MAY, 8),
|
||||
1953: (MAY, 27),
|
||||
1954: (MAY, 17),
|
||||
1955: (JUN, 5),
|
||||
1956: (MAY, 24),
|
||||
1957: (MAY, 14),
|
||||
1958: (JUN, 2),
|
||||
1959: (MAY, 22),
|
||||
1960: (MAY, 10),
|
||||
1961: (MAY, 29),
|
||||
1962: (MAY, 18),
|
||||
1963: (MAY, 8),
|
||||
1964: (MAY, 26),
|
||||
1965: (MAY, 15),
|
||||
1966: (JUN, 3),
|
||||
1967: (MAY, 23),
|
||||
1968: (MAY, 11),
|
||||
1969: (MAY, 30),
|
||||
1970: (MAY, 19),
|
||||
1971: (MAY, 9),
|
||||
1972: (MAY, 27),
|
||||
1973: (MAY, 17),
|
||||
1974: (MAY, 6),
|
||||
1975: (MAY, 25),
|
||||
1976: (MAY, 13),
|
||||
1977: (JUN, 1),
|
||||
1978: (MAY, 21),
|
||||
1979: (MAY, 10),
|
||||
1980: (MAY, 28),
|
||||
1981: (MAY, 18),
|
||||
1982: (MAY, 8),
|
||||
1983: (MAY, 27),
|
||||
1984: (MAY, 15),
|
||||
1985: (JUN, 3),
|
||||
1986: (MAY, 23),
|
||||
1987: (MAY, 12),
|
||||
1988: (MAY, 30),
|
||||
1989: (MAY, 19),
|
||||
1990: (MAY, 9),
|
||||
1991: (MAY, 28),
|
||||
1992: (MAY, 17),
|
||||
1993: (JUN, 4),
|
||||
1994: (MAY, 25),
|
||||
1995: (MAY, 14),
|
||||
1996: (MAY, 31),
|
||||
1997: (MAY, 21),
|
||||
1998: (MAY, 10),
|
||||
1999: (MAY, 29),
|
||||
2000: (MAY, 18),
|
||||
2001: (MAY, 7),
|
||||
2002: (MAY, 26),
|
||||
2003: (MAY, 15),
|
||||
2004: (JUN, 2),
|
||||
2005: (MAY, 22),
|
||||
2006: (MAY, 12),
|
||||
2007: (MAY, 31),
|
||||
2008: (MAY, 19),
|
||||
2009: (MAY, 9),
|
||||
2010: (MAY, 28),
|
||||
2011: (MAY, 17),
|
||||
2012: (MAY, 5),
|
||||
2013: (MAY, 24),
|
||||
2014: (MAY, 13),
|
||||
2015: (JUN, 1),
|
||||
2016: (MAY, 21),
|
||||
2017: (MAY, 10),
|
||||
2018: (MAY, 29),
|
||||
2019: (MAY, 19),
|
||||
2020: (MAY, 7),
|
||||
2021: (MAY, 26),
|
||||
2022: (MAY, 15),
|
||||
2023: (JUN, 2),
|
||||
2024: (MAY, 22),
|
||||
2025: (MAY, 12),
|
||||
2026: (MAY, 31),
|
||||
2027: (MAY, 20),
|
||||
2028: (MAY, 9),
|
||||
2029: (MAY, 27),
|
||||
2030: (MAY, 16),
|
||||
2031: (JUN, 4),
|
||||
2032: (MAY, 23),
|
||||
2033: (MAY, 13),
|
||||
2034: (JUN, 1),
|
||||
2035: (MAY, 22),
|
||||
2036: (MAY, 10),
|
||||
2037: (MAY, 29),
|
||||
2038: (MAY, 18),
|
||||
2039: (MAY, 7),
|
||||
2040: (MAY, 25),
|
||||
2041: (MAY, 14),
|
||||
2042: (JUN, 2),
|
||||
2043: (MAY, 23),
|
||||
2044: (MAY, 12),
|
||||
2045: (MAY, 31),
|
||||
2046: (MAY, 20),
|
||||
2047: (MAY, 9),
|
||||
2048: (MAY, 27),
|
||||
2049: (MAY, 16),
|
||||
2050: (JUN, 4),
|
||||
2051: (MAY, 24),
|
||||
2052: (MAY, 13),
|
||||
2053: (JUN, 1),
|
||||
2054: (MAY, 22),
|
||||
2055: (MAY, 11),
|
||||
2056: (MAY, 29),
|
||||
2057: (MAY, 18),
|
||||
2058: (MAY, 7),
|
||||
2059: (MAY, 26),
|
||||
2060: (MAY, 14),
|
||||
2061: (JUN, 2),
|
||||
2062: (MAY, 23),
|
||||
2063: (MAY, 12),
|
||||
2064: (MAY, 30),
|
||||
2065: (MAY, 19),
|
||||
2066: (MAY, 8),
|
||||
2067: (MAY, 27),
|
||||
2068: (MAY, 16),
|
||||
2069: (MAY, 5),
|
||||
2070: (MAY, 24),
|
||||
2071: (MAY, 14),
|
||||
2072: (JUN, 1),
|
||||
2073: (MAY, 21),
|
||||
2074: (MAY, 10),
|
||||
2075: (MAY, 29),
|
||||
2076: (MAY, 17),
|
||||
2077: (MAY, 7),
|
||||
2078: (MAY, 26),
|
||||
2079: (MAY, 15),
|
||||
2080: (JUN, 2),
|
||||
2081: (MAY, 23),
|
||||
2082: (MAY, 12),
|
||||
2083: (MAY, 31),
|
||||
2084: (MAY, 19),
|
||||
2085: (MAY, 8),
|
||||
2086: (MAY, 27),
|
||||
2087: (MAY, 17),
|
||||
2088: (MAY, 5),
|
||||
2089: (MAY, 24),
|
||||
2090: (MAY, 14),
|
||||
2091: (JUN, 1),
|
||||
2092: (MAY, 20),
|
||||
2093: (MAY, 10),
|
||||
2094: (MAY, 28),
|
||||
2095: (MAY, 18),
|
||||
2096: (MAY, 7),
|
||||
2097: (MAY, 26),
|
||||
2098: (MAY, 15),
|
||||
2099: (JUN, 3),
|
||||
2100: (MAY, 23),
|
||||
}
|
||||
|
||||
VESAK_MAY_DATES = {
|
||||
1901: (MAY, 3),
|
||||
1902: (MAY, 22),
|
||||
1903: (MAY, 11),
|
||||
1904: (MAY, 29),
|
||||
1905: (MAY, 18),
|
||||
1906: (MAY, 8),
|
||||
1907: (MAY, 26),
|
||||
1908: (MAY, 14),
|
||||
1909: (MAY, 4),
|
||||
1910: (MAY, 23),
|
||||
1911: (MAY, 13),
|
||||
1912: (MAY, 1),
|
||||
1913: (MAY, 20),
|
||||
1914: (MAY, 9),
|
||||
1915: (MAY, 28),
|
||||
1916: (MAY, 16),
|
||||
1917: (MAY, 5),
|
||||
1918: (MAY, 24),
|
||||
1919: (MAY, 14),
|
||||
1920: (MAY, 3),
|
||||
1921: (MAY, 22),
|
||||
1922: (MAY, 11),
|
||||
1923: (MAY, 30),
|
||||
1924: (MAY, 18),
|
||||
1925: (MAY, 7),
|
||||
1926: (MAY, 26),
|
||||
1927: (MAY, 15),
|
||||
1928: (MAY, 4),
|
||||
1929: (MAY, 23),
|
||||
1930: (MAY, 13),
|
||||
1931: (MAY, 2),
|
||||
1932: (MAY, 20),
|
||||
1933: (MAY, 9),
|
||||
1934: (MAY, 27),
|
||||
1935: (MAY, 17),
|
||||
1936: (MAY, 5),
|
||||
1937: (MAY, 24),
|
||||
1938: (MAY, 14),
|
||||
1939: (MAY, 4),
|
||||
1940: (MAY, 21),
|
||||
1941: (MAY, 10),
|
||||
1942: (MAY, 29),
|
||||
1943: (MAY, 18),
|
||||
1944: (MAY, 7),
|
||||
1945: (MAY, 26),
|
||||
1946: (MAY, 15),
|
||||
1947: (MAY, 5),
|
||||
1948: (MAY, 23),
|
||||
1949: (MAY, 12),
|
||||
1950: (MAY, 1),
|
||||
1951: (MAY, 20),
|
||||
1952: (MAY, 8),
|
||||
1953: (MAY, 27),
|
||||
1954: (MAY, 17),
|
||||
1955: (MAY, 6),
|
||||
1956: (MAY, 24),
|
||||
1957: (MAY, 14),
|
||||
1958: (MAY, 3),
|
||||
1959: (MAY, 22),
|
||||
1960: (MAY, 10),
|
||||
1961: (MAY, 29),
|
||||
1962: (MAY, 18),
|
||||
1963: (MAY, 8),
|
||||
1964: (MAY, 26),
|
||||
1965: (MAY, 15),
|
||||
1966: (MAY, 5),
|
||||
1967: (MAY, 23),
|
||||
1968: (MAY, 11),
|
||||
1969: (MAY, 1),
|
||||
1970: (MAY, 19),
|
||||
1971: (MAY, 9),
|
||||
1972: (MAY, 27),
|
||||
1973: (MAY, 17),
|
||||
1974: (MAY, 6),
|
||||
1975: (MAY, 25),
|
||||
1976: (MAY, 13),
|
||||
1977: (MAY, 2),
|
||||
1978: (MAY, 21),
|
||||
1979: (MAY, 10),
|
||||
1980: (MAY, 28),
|
||||
1981: (MAY, 18),
|
||||
1982: (MAY, 8),
|
||||
1983: (MAY, 27),
|
||||
1984: (MAY, 15),
|
||||
1985: (MAY, 4),
|
||||
1986: (MAY, 23),
|
||||
1987: (MAY, 12),
|
||||
1988: (MAY, 30),
|
||||
1989: (MAY, 19),
|
||||
1990: (MAY, 9),
|
||||
1991: (MAY, 28),
|
||||
1992: (MAY, 17),
|
||||
1993: (MAY, 6),
|
||||
1994: (MAY, 25),
|
||||
1995: (MAY, 14),
|
||||
1996: (MAY, 2),
|
||||
1997: (MAY, 21),
|
||||
1998: (MAY, 10),
|
||||
1999: (MAY, 29),
|
||||
2000: (MAY, 18),
|
||||
2001: (MAY, 7),
|
||||
2002: (MAY, 26),
|
||||
2003: (MAY, 15),
|
||||
2004: (MAY, 3),
|
||||
2005: (MAY, 22),
|
||||
2006: (MAY, 12),
|
||||
2007: (MAY, 1),
|
||||
2008: (MAY, 19),
|
||||
2009: (MAY, 9),
|
||||
2010: (MAY, 28),
|
||||
2011: (MAY, 17),
|
||||
2012: (MAY, 5),
|
||||
2013: (MAY, 24),
|
||||
2014: (MAY, 13),
|
||||
2015: (MAY, 3),
|
||||
2016: (MAY, 21),
|
||||
2017: (MAY, 10),
|
||||
2018: (MAY, 29),
|
||||
2019: (MAY, 19),
|
||||
2020: (MAY, 7),
|
||||
2021: (MAY, 26),
|
||||
2022: (MAY, 15),
|
||||
2023: (MAY, 4),
|
||||
2024: (MAY, 22),
|
||||
2025: (MAY, 12),
|
||||
2026: (MAY, 1),
|
||||
2027: (MAY, 20),
|
||||
2028: (MAY, 9),
|
||||
2029: (MAY, 27),
|
||||
2030: (MAY, 16),
|
||||
2031: (MAY, 6),
|
||||
2032: (MAY, 23),
|
||||
2033: (MAY, 13),
|
||||
2034: (MAY, 3),
|
||||
2035: (MAY, 22),
|
||||
2036: (MAY, 10),
|
||||
2037: (MAY, 29),
|
||||
2038: (MAY, 18),
|
||||
2039: (MAY, 7),
|
||||
2040: (MAY, 25),
|
||||
2041: (MAY, 14),
|
||||
2042: (MAY, 4),
|
||||
2043: (MAY, 23),
|
||||
2044: (MAY, 12),
|
||||
2045: (MAY, 1),
|
||||
2046: (MAY, 20),
|
||||
2047: (MAY, 9),
|
||||
2048: (MAY, 27),
|
||||
2049: (MAY, 16),
|
||||
2050: (MAY, 5),
|
||||
2051: (MAY, 24),
|
||||
2052: (MAY, 13),
|
||||
2053: (MAY, 3),
|
||||
2054: (MAY, 22),
|
||||
2055: (MAY, 11),
|
||||
2056: (MAY, 29),
|
||||
2057: (MAY, 18),
|
||||
2058: (MAY, 7),
|
||||
2059: (MAY, 26),
|
||||
2060: (MAY, 14),
|
||||
2061: (MAY, 4),
|
||||
2062: (MAY, 23),
|
||||
2063: (MAY, 12),
|
||||
2064: (MAY, 1),
|
||||
2065: (MAY, 19),
|
||||
2066: (MAY, 8),
|
||||
2067: (MAY, 27),
|
||||
2068: (MAY, 16),
|
||||
2069: (MAY, 5),
|
||||
2070: (MAY, 24),
|
||||
2071: (MAY, 14),
|
||||
2072: (MAY, 2),
|
||||
2073: (MAY, 21),
|
||||
2074: (MAY, 10),
|
||||
2075: (MAY, 29),
|
||||
2076: (MAY, 17),
|
||||
2077: (MAY, 7),
|
||||
2078: (MAY, 26),
|
||||
2079: (MAY, 15),
|
||||
2080: (MAY, 4),
|
||||
2081: (MAY, 23),
|
||||
2082: (MAY, 12),
|
||||
2083: (MAY, 1),
|
||||
2084: (MAY, 19),
|
||||
2085: (MAY, 8),
|
||||
2086: (MAY, 27),
|
||||
2087: (MAY, 17),
|
||||
2088: (MAY, 5),
|
||||
2089: (MAY, 24),
|
||||
2090: (MAY, 14),
|
||||
2091: (MAY, 3),
|
||||
2092: (MAY, 20),
|
||||
2093: (MAY, 10),
|
||||
2094: (MAY, 28),
|
||||
2095: (MAY, 18),
|
||||
2096: (MAY, 7),
|
||||
2097: (MAY, 26),
|
||||
2098: (MAY, 15),
|
||||
2099: (MAY, 4),
|
||||
2100: (MAY, 23),
|
||||
}
|
||||
|
||||
def _get_holiday(self, holiday: str, year: int) -> tuple[Optional[date], bool]:
    """Return the date of *holiday* in *year* together with an estimation flag.

    Exact (customized) dates — stored under the `CUSTOM_CALENDAR`-postfixed
    attribute — take precedence over the estimated ones. The boolean is True
    when no exact date exists for *year*, i.e. the result is only an estimate.
    Returns (None, flag) when neither mapping has an entry for *year*.
    """
    estimated = getattr(self, f"{holiday}_DATES", {})
    exact = getattr(self, f"{holiday}_DATES_{_CustomCalendar.CUSTOM_ATTR_POSTFIX}", {})

    if year in exact:
        month_day = exact[year]
    else:
        month_day = estimated.get(year, ())

    holiday_date = date(year, *month_day) if month_day else None
    return holiday_date, year not in exact
|
||||
|
||||
def vesak_date(self, year: int) -> tuple[Optional[date], bool]:
    """Return the date of Vesak for *year* plus an is-estimated flag.

    Delegates to `_get_holiday`; the boolean is True when no exact
    (customized) date is available for *year*. The date is None when
    neither an exact nor an estimated entry exists.
    """
    return self._get_holiday(VESAK, year)
|
||||
|
||||
def vesak_may_date(self, year: int) -> tuple[Optional[date], bool]:
    """Return the May-observed Vesak date for *year* plus an is-estimated flag.

    Delegates to `_get_holiday`; the boolean is True when no exact
    (customized) date is available for *year*. The date is None when
    neither an exact nor an estimated entry exists.
    """
    return self._get_holiday(VESAK_MAY, year)
|
||||
|
||||
|
||||
class _CustomBuddhistHolidays(_CustomCalendar, _BuddhistLunisolar):
    """Buddhist lunisolar calendar open to per-country date customization.

    Mixing in `_CustomCalendar` routes class creation through its renaming
    metaclass, so public `*_DATES` attributes declared on subclasses are
    stored under a `CUSTOM_CALENDAR`-postfixed name and treated as exact
    (non-estimated) dates by the `_get_holiday` lookup.
    """

    pass
|
||||
1354
.venv/lib/python3.12/site-packages/holidays/calendars/chinese.py
Normal file
1354
.venv/lib/python3.12/site-packages/holidays/calendars/chinese.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,34 @@
|
||||
# holidays
|
||||
# --------
|
||||
# A fast, efficient Python library for generating country, province and state
|
||||
# specific sets of holidays on the fly. It aims to make determining whether a
|
||||
# specific date is a holiday as fast and flexible as possible.
|
||||
#
|
||||
# Authors: Vacanza Team and individual contributors (see CONTRIBUTORS file)
|
||||
# dr-prodigy <dr.prodigy.github@gmail.com> (c) 2017-2023
|
||||
# ryanss <ryanssdev@icloud.com> (c) 2014-2017
|
||||
# Website: https://github.com/vacanza/holidays
|
||||
# License: MIT (see LICENSE file)
|
||||
|
||||
|
||||
class _CustomCalendarType(type):
|
||||
"""Helper class for simple calendar customization.
|
||||
|
||||
Renames child class public attributes keeping the original data under a new
|
||||
name with a `CUSTOM_ATTR_POSTFIX` postfix.
|
||||
|
||||
Allows for better readability of customized lunisolar calendar dates.
|
||||
"""
|
||||
|
||||
CUSTOM_ATTR_POSTFIX = "CUSTOM_CALENDAR"
|
||||
|
||||
def __new__(cls, name, bases, namespace):
|
||||
for attr in (key for key in tuple(namespace.keys()) if key[0] != "_"):
|
||||
namespace[f"{attr}_{_CustomCalendar.CUSTOM_ATTR_POSTFIX}"] = namespace[attr]
|
||||
del namespace[attr]
|
||||
|
||||
return super().__new__(cls, name, bases, namespace)
|
||||
|
||||
|
||||
class _CustomCalendar(metaclass=_CustomCalendarType):
    """Marker base class enabling calendar customization.

    Inheriting from it applies `_CustomCalendarType`, which renames a
    subclass's public attributes to carry the `CUSTOM_ATTR_POSTFIX` postfix.
    Deliberately declares no public attributes of its own.
    """

    pass
|
||||
@ -0,0 +1,95 @@
|
||||
# holidays
|
||||
# --------
|
||||
# A fast, efficient Python library for generating country, province and state
|
||||
# specific sets of holidays on the fly. It aims to make determining whether a
|
||||
# specific date is a holiday as fast and flexible as possible.
|
||||
#
|
||||
# Authors: Vacanza Team and individual contributors (see CONTRIBUTORS file)
|
||||
# dr-prodigy <dr.prodigy.github@gmail.com> (c) 2017-2023
|
||||
# ryanss <ryanssdev@icloud.com> (c) 2014-2017
|
||||
# Website: https://github.com/vacanza/holidays
|
||||
# License: MIT (see LICENSE file)
|
||||
|
||||
from datetime import date
|
||||
|
||||
# Calendar identifier constant; semantics defined by its consumers elsewhere
# in the package — NOTE(review): confirm against callers.
GREGORIAN_CALENDAR = "GREGORIAN_CALENDAR"

# Weekday numbers, matching `datetime.date.weekday()` (Monday == 0).
MON, TUE, WED, THU, FRI, SAT, SUN = range(7)
WEEKEND = (SAT, SUN)

# Month numbers, matching `datetime.date.month` (January == 1).
JAN, FEB, MAR, APR, MAY, JUN, JUL, AUG, SEP, OCT, NOV, DEC = range(1, 13)

# String forms of the valid day-of-month numbers ("1".."31").
DAYS = {str(d) for d in range(1, 32)}
# Lowercase three-letter month abbreviations mapped to month numbers (1-12).
MONTHS = {
    m: i
    for i, m in enumerate(
        ("jan", "feb", "mar", "apr", "may", "jun", "jul", "aug", "sep", "oct", "nov", "dec"), 1
    )
}
# Lowercase three-letter weekday abbreviations mapped to weekday numbers (0-6).
WEEKDAYS = {w: i for i, w in enumerate(("mon", "tue", "wed", "thu", "fri", "sat", "sun"))}


# Holiday names.
CHRISTMAS = "christmas"
WINTER_SOLSTICE = "winter_solstice"
|
||||
|
||||
|
||||
def _timedelta(dt: date, days: int = 0) -> date:
|
||||
"""
|
||||
Return date that is `days` days after (days > 0) or before (days < 0) specified date.
|
||||
"""
|
||||
|
||||
return date.fromordinal(dt.toordinal() + days)
|
||||
|
||||
|
||||
def _get_nth_weekday_from(n: int, weekday: int, from_dt: date) -> date:
|
||||
"""
|
||||
Return date of a n-th weekday before a specific date
|
||||
if n is negative.
|
||||
Return date of n-th weekday after (including) a specific date
|
||||
if n is positive.
|
||||
Examples: 1st Monday, 2nd Saturday, etc).
|
||||
"""
|
||||
|
||||
return _timedelta(
|
||||
from_dt,
|
||||
(
|
||||
(n - 1) * 7 + (weekday - from_dt.weekday()) % 7
|
||||
if n > 0
|
||||
else (n + 1) * 7 - (from_dt.weekday() - weekday) % 7
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _get_nth_weekday_of_month(n: int, weekday: int, month: int, year: int) -> date:
|
||||
"""
|
||||
Return date of n-th weekday of month for a specific year
|
||||
(e.g. 1st Monday of Apr, 2nd Friday of June, etc).
|
||||
If n is negative the countdown starts at the end of month
|
||||
(i.e. -1 is last).
|
||||
"""
|
||||
|
||||
requested_year_month = (year, month)
|
||||
|
||||
if n < 0:
|
||||
month += 1
|
||||
if month > 12:
|
||||
month = 1
|
||||
year += 1
|
||||
start_date = _timedelta(date(year, month, 1), -1)
|
||||
else:
|
||||
start_date = date(year, month, 1)
|
||||
|
||||
dt = _get_nth_weekday_from(n, weekday, start_date)
|
||||
dt_year_month = (dt.year, dt.month)
|
||||
|
||||
if dt_year_month != requested_year_month:
|
||||
raise ValueError(f"{dt_year_month} returned for {requested_year_month}")
|
||||
|
||||
return dt
|
||||
|
||||
|
||||
def _get_all_sundays(year):
|
||||
first_sunday = _get_nth_weekday_of_month(1, SUN, JAN, year)
|
||||
for n in range(0, (date(year, DEC, 31) - first_sunday).days + 1, 7):
|
||||
yield _timedelta(first_sunday, n)
|
||||
1633
.venv/lib/python3.12/site-packages/holidays/calendars/hebrew.py
Normal file
1633
.venv/lib/python3.12/site-packages/holidays/calendars/hebrew.py
Normal file
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user