Files
Time-Series-Analysis/.venv/lib/python3.12/site-packages/cmdstanpy/utils/filesystem.py
2025-07-30 18:53:50 +03:00

237 lines
6.9 KiB
Python

"""
Utilities for interacting with the filesystem on multiple platforms
"""
import contextlib
import os
import platform
import re
import shutil
import tempfile
from typing import Any, Iterator, List, Mapping, Optional, Tuple, Union
from cmdstanpy import _TMPDIR
from .json import write_stan_json
from .logging import get_logger
EXTENSION = '.exe' if platform.system() == 'Windows' else ''
def windows_short_path(path: str) -> str:
"""
Gets the short path name of a given long path.
http://stackoverflow.com/a/23598461/200291
On non-Windows platforms, returns the path
If (base)path does not exist, function raises RuntimeError
"""
if platform.system() != 'Windows':
return path
if os.path.isfile(path) or (
not os.path.isdir(path) and os.path.splitext(path)[1] != ''
):
base_path, file_name = os.path.split(path)
else:
base_path, file_name = path, ''
if not os.path.exists(base_path):
raise RuntimeError(
'Windows short path function needs a valid directory. '
'Base directory does not exist: "{}"'.format(base_path)
)
import ctypes
from ctypes import wintypes
# pylint: disable=invalid-name
_GetShortPathNameW = (
ctypes.windll.kernel32.GetShortPathNameW # type: ignore
)
_GetShortPathNameW.argtypes = [
wintypes.LPCWSTR,
wintypes.LPWSTR,
wintypes.DWORD,
]
_GetShortPathNameW.restype = wintypes.DWORD
output_buf_size = 0
while True:
output_buf = ctypes.create_unicode_buffer(output_buf_size)
needed = _GetShortPathNameW(base_path, output_buf, output_buf_size)
if output_buf_size >= needed:
short_base_path = output_buf.value
break
else:
output_buf_size = needed
short_path = (
os.path.join(short_base_path, file_name)
if file_name
else short_base_path
)
return short_path
def create_named_text_file(
dir: str, prefix: str, suffix: str, name_only: bool = False
) -> str:
"""
Create a named unique file, return filename.
Flag 'name_only' will create then delete the tmp file;
this lets us create filename args for commands which
disallow overwriting existing files (e.g., 'stansummary').
"""
fd = tempfile.NamedTemporaryFile(
mode='w+', prefix=prefix, suffix=suffix, dir=dir, delete=name_only
)
path = fd.name
fd.close()
return path
@contextlib.contextmanager
def pushd(new_dir: str) -> Iterator[None]:
"""Acts like pushd/popd."""
previous_dir = os.getcwd()
os.chdir(new_dir)
try:
yield
finally:
os.chdir(previous_dir)
def _temp_single_json(
data: Union[str, os.PathLike, Mapping[str, Any], None]
) -> Iterator[Optional[str]]:
"""Context manager for json files."""
if data is None:
yield None
return
if isinstance(data, (str, os.PathLike)):
yield str(data)
return
data_file = create_named_text_file(dir=_TMPDIR, prefix='', suffix='.json')
get_logger().debug('input tempfile: %s', data_file)
write_stan_json(data_file, data)
try:
yield data_file
finally:
with contextlib.suppress(PermissionError):
os.remove(data_file)
temp_single_json = contextlib.contextmanager(_temp_single_json)
def _temp_multiinput(
input: Union[str, os.PathLike, Mapping[str, Any], List[Any], None],
base: int = 1,
) -> Iterator[Optional[str]]:
if isinstance(input, list):
# most complicated case: list of inits
# for multiple chains, we need to create multiple files
# which look like somename_{i}.json and then pass somename.json
# to CmdStan
mother_file = create_named_text_file(
dir=_TMPDIR, prefix='', suffix='.json', name_only=True
)
new_files = [
os.path.splitext(mother_file)[0] + f'_{i+base}.json'
for i in range(len(input))
]
for init, file in zip(input, new_files):
if isinstance(init, dict):
write_stan_json(file, init)
elif isinstance(init, str):
shutil.copy(init, file)
else:
raise ValueError(
'A list of inits must contain dicts or strings, not'
+ str(type(init))
)
try:
yield mother_file
finally:
for file in new_files:
with contextlib.suppress(PermissionError):
os.remove(file)
else:
yield from _temp_single_json(input)
@contextlib.contextmanager
def temp_inits(
inits: Union[
str, os.PathLike, Mapping[str, Any], float, int, List[Any], None
],
*,
allow_multiple: bool = True,
id: int = 1,
) -> Iterator[Union[str, float, int, None]]:
if isinstance(inits, (float, int)):
yield inits
return
if allow_multiple:
yield from _temp_multiinput(inits, base=id)
else:
if isinstance(inits, list):
raise ValueError('Expected single initialization, got list')
yield from _temp_single_json(inits)
class SanitizedOrTmpFilePath:
"""
Context manager for tmpfiles, handles special characters in filepath.
"""
UNIXISH_PATTERN = re.compile(r"[\s~]")
WINDOWS_PATTERN = re.compile(r"\s")
@classmethod
def _has_special_chars(cls, file_path: str) -> bool:
if platform.system() == "Windows":
return bool(cls.WINDOWS_PATTERN.search(file_path))
return bool(cls.UNIXISH_PATTERN.search(file_path))
def __init__(self, file_path: str):
self._tmpdir = None
if self._has_special_chars(os.path.abspath(file_path)):
base_path, file_name = os.path.split(os.path.abspath(file_path))
os.makedirs(base_path, exist_ok=True)
try:
short_base_path = windows_short_path(base_path)
if os.path.exists(short_base_path):
file_path = os.path.join(short_base_path, file_name)
except RuntimeError:
pass
if self._has_special_chars(os.path.abspath(file_path)):
tmpdir = tempfile.mkdtemp()
if self._has_special_chars(tmpdir):
raise RuntimeError(
'Unable to generate temporary path without spaces or '
'special characters! \n Please move your stan file to a '
'location without spaces or special characters.'
)
_, path = tempfile.mkstemp(suffix='.stan', dir=tmpdir)
shutil.copy(file_path, path)
self._path = path
self._tmpdir = tmpdir
else:
self._path = file_path
def __enter__(self) -> Tuple[str, bool]:
return self._path, self._tmpdir is not None
def __exit__(self, exc_type, exc_val, exc_tb) -> None: # type: ignore
if self._tmpdir:
shutil.rmtree(self._tmpdir, ignore_errors=True)