some new features
This commit is contained in:
579
.venv/lib/python3.12/site-packages/cmdstanpy/compilation.py
Normal file
579
.venv/lib/python3.12/site-packages/cmdstanpy/compilation.py
Normal file
@ -0,0 +1,579 @@
|
||||
"""
|
||||
Makefile options for stanc and C++ compilers
|
||||
"""
|
||||
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
import subprocess
|
||||
from copy import copy
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, Optional, Union
|
||||
|
||||
from cmdstanpy.utils import get_logger
|
||||
from cmdstanpy.utils.cmdstan import (
|
||||
EXTENSION,
|
||||
cmdstan_path,
|
||||
cmdstan_version,
|
||||
cmdstan_version_before,
|
||||
)
|
||||
from cmdstanpy.utils.command import do_command
|
||||
from cmdstanpy.utils.filesystem import SanitizedOrTmpFilePath
|
||||
|
||||
STANC_OPTS = [
|
||||
'O',
|
||||
'O0',
|
||||
'O1',
|
||||
'Oexperimental',
|
||||
'allow-undefined',
|
||||
'use-opencl',
|
||||
'warn-uninitialized',
|
||||
'include-paths',
|
||||
'name',
|
||||
'warn-pedantic',
|
||||
]
|
||||
|
||||
# TODO(2.0): remove
|
||||
STANC_DEPRECATED_OPTS = {
|
||||
'allow_undefined': 'allow-undefined',
|
||||
'include_paths': 'include-paths',
|
||||
}
|
||||
|
||||
STANC_IGNORE_OPTS = [
|
||||
'debug-lex',
|
||||
'debug-parse',
|
||||
'debug-ast',
|
||||
'debug-decorated-ast',
|
||||
'debug-generate-data',
|
||||
'debug-mir',
|
||||
'debug-mir-pretty',
|
||||
'debug-optimized-mir',
|
||||
'debug-optimized-mir-pretty',
|
||||
'debug-transformed-mir',
|
||||
'debug-transformed-mir-pretty',
|
||||
'dump-stan-math-signatures',
|
||||
'auto-format',
|
||||
'print-canonical',
|
||||
'print-cpp',
|
||||
'o',
|
||||
'help',
|
||||
'version',
|
||||
]
|
||||
|
||||
OptionalPath = Union[str, os.PathLike, None]
|
||||
|
||||
|
||||
# TODO(2.0): can remove add function and other logic
|
||||
class CompilerOptions:
|
||||
"""
|
||||
User-specified flags for stanc and C++ compiler.
|
||||
|
||||
Attributes:
|
||||
stanc_options - stanc compiler flags, options
|
||||
cpp_options - makefile options (NAME=value)
|
||||
user_header - path to a user .hpp file to include during compilation
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
stanc_options: Optional[Dict[str, Any]] = None,
|
||||
cpp_options: Optional[Dict[str, Any]] = None,
|
||||
user_header: OptionalPath = None,
|
||||
) -> None:
|
||||
"""Initialize object."""
|
||||
self._stanc_options = stanc_options if stanc_options is not None else {}
|
||||
self._cpp_options = cpp_options if cpp_options is not None else {}
|
||||
self._user_header = str(user_header) if user_header is not None else ''
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return 'stanc_options={}, cpp_options={}'.format(
|
||||
self._stanc_options, self._cpp_options
|
||||
)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Overrides the default implementation"""
|
||||
if self.is_empty() and other is None: # equiv w/r/t compiler
|
||||
return True
|
||||
if not isinstance(other, CompilerOptions):
|
||||
return False
|
||||
return (
|
||||
self._stanc_options == other.stanc_options
|
||||
and self._cpp_options == other.cpp_options
|
||||
and self._user_header == other.user_header
|
||||
)
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
"""True if no options specified."""
|
||||
return (
|
||||
self._stanc_options == {}
|
||||
and self._cpp_options == {}
|
||||
and self._user_header == ''
|
||||
)
|
||||
|
||||
@property
|
||||
def stanc_options(self) -> Dict[str, Union[bool, int, str, Iterable[str]]]:
|
||||
"""Stanc compiler options."""
|
||||
return self._stanc_options
|
||||
|
||||
@property
|
||||
def cpp_options(self) -> Dict[str, Union[bool, int]]:
|
||||
"""C++ compiler options."""
|
||||
return self._cpp_options
|
||||
|
||||
@property
|
||||
def user_header(self) -> str:
|
||||
"""user header."""
|
||||
return self._user_header
|
||||
|
||||
def validate(self) -> None:
|
||||
"""
|
||||
Check compiler args.
|
||||
Raise ValueError if invalid options are found.
|
||||
"""
|
||||
self.validate_stanc_opts()
|
||||
self.validate_cpp_opts()
|
||||
self.validate_user_header()
|
||||
|
||||
def validate_stanc_opts(self) -> None:
|
||||
"""
|
||||
Check stanc compiler args and consistency between stanc and C++ options.
|
||||
Raise ValueError if bad config is found.
|
||||
"""
|
||||
# pylint: disable=no-member
|
||||
if self._stanc_options is None:
|
||||
return
|
||||
ignore = []
|
||||
paths = None
|
||||
has_o_flag = False
|
||||
|
||||
for deprecated, replacement in STANC_DEPRECATED_OPTS.items():
|
||||
if deprecated in self._stanc_options:
|
||||
if replacement:
|
||||
get_logger().warning(
|
||||
'compiler option "%s" is deprecated, use "%s" instead',
|
||||
deprecated,
|
||||
replacement,
|
||||
)
|
||||
self._stanc_options[replacement] = copy(
|
||||
self._stanc_options[deprecated]
|
||||
)
|
||||
del self._stanc_options[deprecated]
|
||||
else:
|
||||
get_logger().warning(
|
||||
'compiler option "%s" is deprecated and '
|
||||
'should not be used',
|
||||
deprecated,
|
||||
)
|
||||
for key, val in self._stanc_options.items():
|
||||
if key in STANC_IGNORE_OPTS:
|
||||
get_logger().info('ignoring compiler option: %s', key)
|
||||
ignore.append(key)
|
||||
elif key not in STANC_OPTS:
|
||||
raise ValueError(f'unknown stanc compiler option: {key}')
|
||||
elif key == 'include-paths':
|
||||
paths = val
|
||||
if isinstance(val, str):
|
||||
paths = val.split(',')
|
||||
elif not isinstance(val, list):
|
||||
raise ValueError(
|
||||
'Invalid include-paths, expecting list or '
|
||||
f'string, found type: {type(val)}.'
|
||||
)
|
||||
elif key == 'use-opencl':
|
||||
if self._cpp_options is None:
|
||||
self._cpp_options = {'STAN_OPENCL': 'TRUE'}
|
||||
else:
|
||||
self._cpp_options['STAN_OPENCL'] = 'TRUE'
|
||||
elif key.startswith('O'):
|
||||
if has_o_flag:
|
||||
get_logger().warning(
|
||||
'More than one of (O, O1, O2, Oexperimental)'
|
||||
'optimizations passed. Only the last one will'
|
||||
'be used'
|
||||
)
|
||||
else:
|
||||
has_o_flag = True
|
||||
|
||||
for opt in ignore:
|
||||
del self._stanc_options[opt]
|
||||
if paths is not None:
|
||||
bad_paths = [dir for dir in paths if not os.path.exists(dir)]
|
||||
if any(bad_paths):
|
||||
raise ValueError(
|
||||
'invalid include paths: {}'.format(', '.join(bad_paths))
|
||||
)
|
||||
|
||||
self._stanc_options['include-paths'] = [
|
||||
os.path.abspath(os.path.expanduser(path)) for path in paths
|
||||
]
|
||||
|
||||
def validate_cpp_opts(self) -> None:
|
||||
"""
|
||||
Check cpp compiler args.
|
||||
Raise ValueError if bad config is found.
|
||||
"""
|
||||
if self._cpp_options is None:
|
||||
return
|
||||
for key in ['OPENCL_DEVICE_ID', 'OPENCL_PLATFORM_ID']:
|
||||
if key in self._cpp_options:
|
||||
self._cpp_options['STAN_OPENCL'] = 'TRUE'
|
||||
val = self._cpp_options[key]
|
||||
if not isinstance(val, int) or val < 0:
|
||||
raise ValueError(
|
||||
f'{key} must be a non-negative integer value,'
|
||||
f' found {val}.'
|
||||
)
|
||||
|
||||
def validate_user_header(self) -> None:
|
||||
"""
|
||||
User header exists.
|
||||
Raise ValueError if bad config is found.
|
||||
"""
|
||||
if self._user_header != "":
|
||||
if not (
|
||||
os.path.exists(self._user_header)
|
||||
and os.path.isfile(self._user_header)
|
||||
):
|
||||
raise ValueError(
|
||||
f"User header file {self._user_header} cannot be found"
|
||||
)
|
||||
if self._user_header[-4:] != '.hpp':
|
||||
raise ValueError(
|
||||
f"Header file must end in .hpp, got {self._user_header}"
|
||||
)
|
||||
if "allow-undefined" not in self._stanc_options:
|
||||
self._stanc_options["allow-undefined"] = True
|
||||
# set full path
|
||||
self._user_header = os.path.abspath(self._user_header)
|
||||
|
||||
if ' ' in self._user_header:
|
||||
raise ValueError(
|
||||
"User header must be in a location with no spaces in path!"
|
||||
)
|
||||
|
||||
if (
|
||||
'USER_HEADER' in self._cpp_options
|
||||
and self._user_header != self._cpp_options['USER_HEADER']
|
||||
):
|
||||
raise ValueError(
|
||||
"Disagreement in user_header C++ options found!\n"
|
||||
f"{self._user_header}, {self._cpp_options['USER_HEADER']}"
|
||||
)
|
||||
|
||||
self._cpp_options['USER_HEADER'] = self._user_header
|
||||
|
||||
def add(self, new_opts: "CompilerOptions") -> None: # noqa: disable=Q000
|
||||
"""Adds options to existing set of compiler options."""
|
||||
if new_opts.stanc_options is not None:
|
||||
if self._stanc_options is None:
|
||||
self._stanc_options = new_opts.stanc_options
|
||||
else:
|
||||
for key, val in new_opts.stanc_options.items():
|
||||
if key == 'include-paths':
|
||||
if isinstance(val, Iterable) and not isinstance(
|
||||
val, str
|
||||
):
|
||||
for path in val:
|
||||
self.add_include_path(str(path))
|
||||
else:
|
||||
self.add_include_path(str(val))
|
||||
else:
|
||||
self._stanc_options[key] = val
|
||||
if new_opts.cpp_options is not None:
|
||||
for key, val in new_opts.cpp_options.items():
|
||||
self._cpp_options[key] = val
|
||||
if new_opts._user_header != '' and self._user_header == '':
|
||||
self._user_header = new_opts._user_header
|
||||
|
||||
def add_include_path(self, path: str) -> None:
|
||||
"""Adds include path to existing set of compiler options."""
|
||||
path = os.path.abspath(os.path.expanduser(path))
|
||||
if 'include-paths' not in self._stanc_options:
|
||||
self._stanc_options['include-paths'] = [path]
|
||||
elif path not in self._stanc_options['include-paths']:
|
||||
self._stanc_options['include-paths'].append(path)
|
||||
|
||||
def compose_stanc(self, filename_in_msg: Optional[str]) -> List[str]:
|
||||
opts = []
|
||||
|
||||
if filename_in_msg is not None:
|
||||
opts.append(f'--filename-in-msg={filename_in_msg}')
|
||||
|
||||
if self._stanc_options is not None and len(self._stanc_options) > 0:
|
||||
for key, val in self._stanc_options.items():
|
||||
if key == 'include-paths':
|
||||
opts.append(
|
||||
'--include-paths='
|
||||
+ ','.join(
|
||||
(
|
||||
Path(p).as_posix()
|
||||
for p in self._stanc_options['include-paths']
|
||||
)
|
||||
)
|
||||
)
|
||||
elif key == 'name':
|
||||
opts.append(f'--name={val}')
|
||||
else:
|
||||
opts.append(f'--{key}')
|
||||
return opts
|
||||
|
||||
def compose(self, filename_in_msg: Optional[str] = None) -> List[str]:
|
||||
"""
|
||||
Format makefile options as list of strings.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
filename_in_msg : str, optional
|
||||
filename to be displayed in stanc3 error messages
|
||||
(if different from actual filename on disk), by default None
|
||||
"""
|
||||
opts = [
|
||||
'STANCFLAGS+=' + flag.replace(" ", "\\ ")
|
||||
for flag in self.compose_stanc(filename_in_msg)
|
||||
]
|
||||
if self._cpp_options is not None and len(self._cpp_options) > 0:
|
||||
for key, val in self._cpp_options.items():
|
||||
opts.append(f'{key}={val}')
|
||||
return opts
|
||||
|
||||
|
||||
def src_info(
|
||||
stan_file: str, compiler_options: CompilerOptions
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get source info for Stan program file.
|
||||
|
||||
This function is used in the implementation of
|
||||
:meth:`CmdStanModel.src_info`, and should not be called directly.
|
||||
"""
|
||||
cmd = (
|
||||
[os.path.join(cmdstan_path(), 'bin', 'stanc' + EXTENSION)]
|
||||
# handle include-paths, allow-undefined etc
|
||||
+ compiler_options.compose_stanc(None)
|
||||
+ ['--info', str(stan_file)]
|
||||
)
|
||||
proc = subprocess.run(cmd, capture_output=True, text=True, check=False)
|
||||
if proc.returncode:
|
||||
raise ValueError(
|
||||
f"Failed to get source info for Stan model "
|
||||
f"'{stan_file}'. Console:\n{proc.stderr}"
|
||||
)
|
||||
result: Dict[str, Any] = json.loads(proc.stdout)
|
||||
return result
|
||||
|
||||
|
||||
def compile_stan_file(
|
||||
src: Union[str, Path],
|
||||
force: bool = False,
|
||||
stanc_options: Optional[Dict[str, Any]] = None,
|
||||
cpp_options: Optional[Dict[str, Any]] = None,
|
||||
user_header: OptionalPath = None,
|
||||
) -> str:
|
||||
"""
|
||||
Compile the given Stan program file. Translates the Stan code to
|
||||
C++, then calls the C++ compiler.
|
||||
|
||||
By default, this function compares the timestamps on the source and
|
||||
executable files; if the executable is newer than the source file, it
|
||||
will not recompile the file, unless argument ``force`` is ``True``
|
||||
or unless the compiler options have been changed.
|
||||
|
||||
:param src: Path to Stan program file.
|
||||
|
||||
:param force: When ``True``, always compile, even if the executable file
|
||||
is newer than the source file. Used for Stan models which have
|
||||
``#include`` directives in order to force recompilation when changes
|
||||
are made to the included files.
|
||||
|
||||
:param stanc_options: Options for stanc compiler.
|
||||
:param cpp_options: Options for C++ compiler.
|
||||
:param user_header: A path to a header file to include during C++
|
||||
compilation.
|
||||
"""
|
||||
|
||||
src = Path(src).resolve()
|
||||
if not src.exists():
|
||||
raise ValueError(f'stan file does not exist: {src}')
|
||||
|
||||
compiler_options = CompilerOptions(
|
||||
stanc_options=stanc_options,
|
||||
cpp_options=cpp_options,
|
||||
user_header=user_header,
|
||||
)
|
||||
compiler_options.validate()
|
||||
|
||||
exe_target = src.with_suffix(EXTENSION)
|
||||
if exe_target.exists():
|
||||
exe_time = os.path.getmtime(exe_target)
|
||||
included_files = [src]
|
||||
included_files.extend(
|
||||
src_info(str(src), compiler_options).get('included_files', [])
|
||||
)
|
||||
out_of_date = any(
|
||||
os.path.getmtime(included_file) > exe_time
|
||||
for included_file in included_files
|
||||
)
|
||||
if not out_of_date and not force:
|
||||
get_logger().debug('found newer exe file, not recompiling')
|
||||
return str(exe_target)
|
||||
|
||||
compilation_failed = False
|
||||
# if target path has spaces or special characters, use a copy in a
|
||||
# temporary directory (GNU-Make constraint)
|
||||
with SanitizedOrTmpFilePath(str(src)) as (stan_file, is_copied):
|
||||
exe_file = os.path.splitext(stan_file)[0] + EXTENSION
|
||||
|
||||
hpp_file = os.path.splitext(exe_file)[0] + '.hpp'
|
||||
if os.path.exists(hpp_file):
|
||||
os.remove(hpp_file)
|
||||
if os.path.exists(exe_file):
|
||||
get_logger().debug('Removing %s', exe_file)
|
||||
os.remove(exe_file)
|
||||
|
||||
get_logger().info(
|
||||
'compiling stan file %s to exe file %s',
|
||||
stan_file,
|
||||
exe_target,
|
||||
)
|
||||
|
||||
make = os.getenv(
|
||||
'MAKE',
|
||||
'make' if platform.system() != 'Windows' else 'mingw32-make',
|
||||
)
|
||||
cmd = [make]
|
||||
cmd.extend(compiler_options.compose(filename_in_msg=src.name))
|
||||
cmd.append(Path(exe_file).as_posix())
|
||||
|
||||
sout = io.StringIO()
|
||||
try:
|
||||
do_command(cmd=cmd, cwd=cmdstan_path(), fd_out=sout)
|
||||
except RuntimeError as e:
|
||||
sout.write(f'\n{str(e)}\n')
|
||||
compilation_failed = True
|
||||
finally:
|
||||
console = sout.getvalue()
|
||||
|
||||
get_logger().debug('Console output:\n%s', console)
|
||||
if not compilation_failed:
|
||||
if is_copied:
|
||||
shutil.copy(exe_file, exe_target)
|
||||
get_logger().info('compiled model executable: %s', exe_target)
|
||||
if 'Warning' in console:
|
||||
lines = console.split('\n')
|
||||
warnings = [x for x in lines if x.startswith('Warning')]
|
||||
get_logger().warning(
|
||||
'Stan compiler has produced %d warnings:',
|
||||
len(warnings),
|
||||
)
|
||||
get_logger().warning(console)
|
||||
if compilation_failed:
|
||||
if 'PCH' in console or 'precompiled header' in console:
|
||||
get_logger().warning(
|
||||
"CmdStan's precompiled header (PCH) files "
|
||||
"may need to be rebuilt."
|
||||
"Please run cmdstanpy.rebuild_cmdstan().\n"
|
||||
"If the issue persists please open a bug report"
|
||||
)
|
||||
raise ValueError(
|
||||
f"Failed to compile Stan model '{src}'. " f"Console:\n{console}"
|
||||
)
|
||||
return str(exe_target)
|
||||
|
||||
|
||||
def format_stan_file(
|
||||
stan_file: Union[str, os.PathLike],
|
||||
*,
|
||||
overwrite_file: bool = False,
|
||||
canonicalize: Union[bool, str, Iterable[str]] = False,
|
||||
max_line_length: int = 78,
|
||||
backup: bool = True,
|
||||
stanc_options: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Run stanc's auto-formatter on the model code. Either saves directly
|
||||
back to the file or prints for inspection
|
||||
|
||||
:param stan_file: Path to Stan program file.
|
||||
:param overwrite_file: If True, save the updated code to disk, rather
|
||||
than printing it. By default False
|
||||
:param canonicalize: Whether or not the compiler should 'canonicalize'
|
||||
the Stan model, removing things like deprecated syntax. Default is
|
||||
False. If True, all canonicalizations are run. If it is a list of
|
||||
strings, those options are passed to stanc (new in Stan 2.29)
|
||||
:param max_line_length: Set the wrapping point for the formatter. The
|
||||
default value is 78, which wraps most lines by the 80th character.
|
||||
:param backup: If True, create a stanfile.bak backup before
|
||||
writing to the file. Only disable this if you're sure you have other
|
||||
copies of the file or are using a version control system like Git.
|
||||
:param stanc_options: Additional options to pass to the stanc compiler.
|
||||
"""
|
||||
stan_file = Path(stan_file).resolve()
|
||||
|
||||
if not stan_file.exists():
|
||||
raise ValueError(f'File does not exist: {stan_file}')
|
||||
|
||||
try:
|
||||
cmd = (
|
||||
[os.path.join(cmdstan_path(), 'bin', 'stanc' + EXTENSION)]
|
||||
# handle include-paths, allow-undefined etc
|
||||
+ CompilerOptions(stanc_options=stanc_options).compose_stanc(None)
|
||||
+ [str(stan_file)]
|
||||
)
|
||||
|
||||
if canonicalize:
|
||||
if cmdstan_version_before(2, 29):
|
||||
if isinstance(canonicalize, bool):
|
||||
cmd.append('--print-canonical')
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid arguments passed for current CmdStan"
|
||||
+ " version({})\n".format(
|
||||
cmdstan_version() or "Unknown"
|
||||
)
|
||||
+ "--canonicalize requires 2.29 or higher"
|
||||
)
|
||||
else:
|
||||
if isinstance(canonicalize, str):
|
||||
cmd.append('--canonicalize=' + canonicalize)
|
||||
elif isinstance(canonicalize, Iterable):
|
||||
cmd.append('--canonicalize=' + ','.join(canonicalize))
|
||||
else:
|
||||
cmd.append('--print-canonical')
|
||||
|
||||
# before 2.29, having both --print-canonical
|
||||
# and --auto-format printed twice
|
||||
if not (cmdstan_version_before(2, 29) and canonicalize):
|
||||
cmd.append('--auto-format')
|
||||
|
||||
if not cmdstan_version_before(2, 29):
|
||||
cmd.append(f'--max-line-length={max_line_length}')
|
||||
elif max_line_length != 78:
|
||||
raise ValueError(
|
||||
"Invalid arguments passed for current CmdStan version"
|
||||
+ " ({})\n".format(cmdstan_version() or "Unknown")
|
||||
+ "--max-line-length requires 2.29 or higher"
|
||||
)
|
||||
|
||||
out = subprocess.run(cmd, capture_output=True, text=True, check=True)
|
||||
if out.stderr:
|
||||
get_logger().warning(out.stderr)
|
||||
result = out.stdout
|
||||
if overwrite_file:
|
||||
if result:
|
||||
if backup:
|
||||
shutil.copyfile(
|
||||
stan_file,
|
||||
str(stan_file)
|
||||
+ '.bak-'
|
||||
+ datetime.now().strftime("%Y%m%d%H%M%S"),
|
||||
)
|
||||
stan_file.write_text(result)
|
||||
else:
|
||||
print(result)
|
||||
|
||||
except (ValueError, RuntimeError) as e:
|
||||
raise RuntimeError("Stanc formatting failed") from e
|
||||
Reference in New Issue
Block a user