提交 eaf05be6 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Virgile Andreani

Replace os.path with pathlib

And add type hints to configparser.py
上级 10f285a1
...@@ -12,11 +12,12 @@ ...@@ -12,11 +12,12 @@
# serve to show the default value. # serve to show the default value.
# If your extensions are in another directory, add it here. If the directory # If your extensions are in another directory, add it here. If the directory
# is relative to the documentation root, use os.path.abspath to make it # is relative to the documentation root, use Path.absolute to make it
# absolute, like shown here. # absolute, like shown here.
# sys.path.append(os.path.abspath('some/directory')) # sys.path.append(str(Path("some/directory").absolute()))
import os import os
import inspect
import sys import sys
import pytensor import pytensor
...@@ -236,11 +237,9 @@ def linkcode_resolve(domain, info): ...@@ -236,11 +237,9 @@ def linkcode_resolve(domain, info):
obj = sys.modules[info["module"]] obj = sys.modules[info["module"]]
for part in info["fullname"].split("."): for part in info["fullname"].split("."):
obj = getattr(obj, part) obj = getattr(obj, part)
import inspect
import os
fn = inspect.getsourcefile(obj) fn = Path(inspect.getsourcefile(obj))
fn = os.path.relpath(fn, start=os.path.dirname(pytensor.__file__)) fn = fn.relative_to(Path(__file__).parent)
source, lineno = inspect.getsourcelines(obj) source, lineno = inspect.getsourcelines(obj)
return fn, lineno, lineno + len(source) - 1 return fn, lineno, lineno + len(source) - 1
......
...@@ -76,10 +76,9 @@ visualize it with :py:func:`pytensor.printing.pydotprint` as follows: ...@@ -76,10 +76,9 @@ visualize it with :py:func:`pytensor.printing.pydotprint` as follows:
.. code:: python .. code:: python
from pytensor.printing import pydotprint from pytensor.printing import pydotprint
import os from pathlib import Path
if not os.path.exists('examples'): Path("examples").mkdir(exist_ok=True)
os.makedirs('examples')
pydotprint(predict, 'examples/mlp.png') pydotprint(predict, 'examples/mlp.png')
......
...@@ -259,7 +259,7 @@ PyTensor/BLAS speed test: ...@@ -259,7 +259,7 @@ PyTensor/BLAS speed test:
.. code-block:: bash .. code-block:: bash
python `python -c "import os, pytensor; print(os.path.dirname(pytensor.__file__))"`/misc/check_blas.py python $(python -c "import pathlib, pytensor; print(pathlib.Path(pytensor.__file__).parent / 'misc/check_blas.py')")
This will print a table with different versions of BLAS/numbers of This will print a table with different versions of BLAS/numbers of
threads on multiple CPUs. It will also print some PyTensor/NumPy threads on multiple CPUs. It will also print some PyTensor/NumPy
......
...@@ -11,7 +11,7 @@ dependencies: ...@@ -11,7 +11,7 @@ dependencies:
- compilers - compilers
- numpy>=1.17.0,<2 - numpy>=1.17.0,<2
- scipy>=0.14,<1.14.0 - scipy>=0.14,<1.14.0
- filelock - filelock>=3.15
- etuples - etuples
- logical-unification - logical-unification
- miniKanren - miniKanren
...@@ -27,7 +27,6 @@ dependencies: ...@@ -27,7 +27,6 @@ dependencies:
- coveralls - coveralls
- diff-cover - diff-cover
- mypy - mypy
- types-filelock
- types-setuptools - types-setuptools
- pytest - pytest
- pytest-cov - pytest-cov
......
...@@ -49,7 +49,7 @@ dependencies = [ ...@@ -49,7 +49,7 @@ dependencies = [
"setuptools>=59.0.0", "setuptools>=59.0.0",
"scipy>=0.14,<1.14", "scipy>=0.14,<1.14",
"numpy>=1.17.0,<2", "numpy>=1.17.0,<2",
"filelock", "filelock>=3.15",
"etuples", "etuples",
"logical-unification", "logical-unification",
"miniKanren", "miniKanren",
......
...@@ -23,9 +23,9 @@ __docformat__ = "restructuredtext en" ...@@ -23,9 +23,9 @@ __docformat__ = "restructuredtext en"
# Set a default logger. It is important to do this before importing some other # Set a default logger. It is important to do this before importing some other
# pytensor code, since this code may want to log some messages. # pytensor code, since this code may want to log some messages.
import logging import logging
import os
import sys import sys
from functools import singledispatch from functools import singledispatch
from pathlib import Path
from typing import Any, NoReturn, Optional from typing import Any, NoReturn, Optional
from pytensor.version import version as __version__ from pytensor.version import version as __version__
...@@ -52,10 +52,8 @@ def disable_log_handler(logger=pytensor_logger, handler=logging_default_handler) ...@@ -52,10 +52,8 @@ def disable_log_handler(logger=pytensor_logger, handler=logging_default_handler)
# Raise a meaningful warning/error if the pytensor directory is in the Python # Raise a meaningful warning/error if the pytensor directory is in the Python
# path. # path.
rpath = os.path.realpath(__path__[0]) rpath = Path(__file__).parent.resolve()
for p in sys.path: if any(rpath == Path(p).resolve() for p in sys.path):
if os.path.realpath(p) != rpath:
continue
raise RuntimeError("You have the pytensor directory in your Python path.") raise RuntimeError("You have the pytensor directory in your Python path.")
from pytensor.configdefaults import config from pytensor.configdefaults import config
......
...@@ -4,7 +4,6 @@ It is used by the "pytensor-cache" CLI tool, located in the /bin folder of the r ...@@ -4,7 +4,6 @@ It is used by the "pytensor-cache" CLI tool, located in the /bin folder of the r
""" """
import logging import logging
import os
import pickle import pickle
import shutil import shutil
from collections import Counter from collections import Counter
...@@ -33,12 +32,11 @@ def cleanup(): ...@@ -33,12 +32,11 @@ def cleanup():
If there is no key left for a compiled module, we delete the module. If there is no key left for a compiled module, we delete the module.
""" """
compiledir = config.compiledir for directory in config.compiledir.iterdir():
for directory in os.listdir(compiledir):
try: try:
filename = os.path.join(compiledir, directory, "key.pkl") filename = directory / "key.pkl"
# print file # print file
with open(filename, "rb") as file: with filename.open("rb") as file:
try: try:
keydata = pickle.load(file) keydata = pickle.load(file)
...@@ -79,7 +77,7 @@ def cleanup(): ...@@ -79,7 +77,7 @@ def cleanup():
"the directory containing it." "the directory containing it."
) )
if len(keydata.keys) == 0: if len(keydata.keys) == 0:
shutil.rmtree(os.path.join(compiledir, directory)) shutil.rmtree(directory)
except (EOFError, AttributeError): except (EOFError, AttributeError):
_logger.error( _logger.error(
...@@ -117,11 +115,11 @@ def print_compiledir_content(): ...@@ -117,11 +115,11 @@ def print_compiledir_content():
big_key_files = [] big_key_files = []
total_key_sizes = 0 total_key_sizes = 0
nb_keys = Counter() nb_keys = Counter()
for dir in os.listdir(compiledir): for dir in config.compiledir.iterdir():
filename = os.path.join(compiledir, dir, "key.pkl") filename = dir / "key.pkl"
if not os.path.exists(filename): if not filename.exists():
continue continue
with open(filename, "rb") as file: with filename.open("rb") as file:
try: try:
keydata = pickle.load(file) keydata = pickle.load(file)
ops = list({x for x in flatten(keydata.keys) if isinstance(x, Op)}) ops = list({x for x in flatten(keydata.keys) if isinstance(x, Op)})
...@@ -134,15 +132,11 @@ def print_compiledir_content(): ...@@ -134,15 +132,11 @@ def print_compiledir_content():
{x for x in flatten(keydata.keys) if isinstance(x, CType)} {x for x in flatten(keydata.keys) if isinstance(x, CType)}
) )
compile_start = compile_end = float("nan") compile_start = compile_end = float("nan")
for fn in os.listdir(os.path.join(compiledir, dir)): for fn in dir.iterdir():
if fn.startswith("mod.c"): if fn.name == "mod.c":
compile_start = os.path.getmtime( compile_start = fn.stat().st_mtime
os.path.join(compiledir, dir, fn) elif fn.suffix == ".so":
) compile_end = fn.stat().st_mtime
elif fn.endswith(".so"):
compile_end = os.path.getmtime(
os.path.join(compiledir, dir, fn)
)
compile_time = compile_end - compile_start compile_time = compile_end - compile_start
if len(ops) == 1: if len(ops) == 1:
table.append((dir, ops[0], types, compile_time)) table.append((dir, ops[0], types, compile_time))
...@@ -153,7 +147,7 @@ def print_compiledir_content(): ...@@ -153,7 +147,7 @@ def print_compiledir_content():
(dir, ops_to_str, types_to_str, compile_time) (dir, ops_to_str, types_to_str, compile_time)
) )
size = os.path.getsize(filename) size = filename.stat().st_size
total_key_sizes += size total_key_sizes += size
if size > max_key_file_size: if size > max_key_file_size:
big_key_files.append((dir, size, ops)) big_key_files.append((dir, size, ops))
...@@ -239,8 +233,8 @@ def basecompiledir_ls(): ...@@ -239,8 +233,8 @@ def basecompiledir_ls():
""" """
subdirs = [] subdirs = []
others = [] others = []
for f in os.listdir(config.base_compiledir): for f in config.base_compiledir.iterdir():
if os.path.isdir(os.path.join(config.base_compiledir, f)): if f.is_dir():
subdirs.append(f) subdirs.append(f)
else: else:
others.append(f) others.append(f)
......
...@@ -6,6 +6,7 @@ in the same compilation directory (which can cause crashes). ...@@ -6,6 +6,7 @@ in the same compilation directory (which can cause crashes).
import os import os
import threading import threading
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path
import filelock import filelock
...@@ -35,7 +36,7 @@ def force_unlock(lock_dir: os.PathLike): ...@@ -35,7 +36,7 @@ def force_unlock(lock_dir: os.PathLike):
Path to a directory that was locked with `lock_ctx`. Path to a directory that was locked with `lock_ctx`.
""" """
fl = filelock.FileLock(os.path.join(lock_dir, ".lock")) fl = filelock.FileLock(Path(lock_dir) / ".lock")
fl.release(force=True) fl.release(force=True)
dir_key = f"{lock_dir}-{os.getpid()}" dir_key = f"{lock_dir}-{os.getpid()}"
...@@ -72,7 +73,7 @@ def lock_ctx( ...@@ -72,7 +73,7 @@ def lock_ctx(
if dir_key not in local_mem._locks: if dir_key not in local_mem._locks:
local_mem._locks[dir_key] = True local_mem._locks[dir_key] = True
fl = filelock.FileLock(os.path.join(lock_dir, ".lock")) fl = filelock.FileLock(Path(lock_dir) / ".lock")
fl.acquire(timeout=timeout) fl.acquire(timeout=timeout)
try: try:
yield yield
......
...@@ -16,6 +16,7 @@ import sys ...@@ -16,6 +16,7 @@ import sys
import time import time
from collections import Counter, defaultdict from collections import Counter, defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path
from typing import Any from typing import Any
import numpy as np import numpy as np
...@@ -34,7 +35,7 @@ def extended_open(filename, mode="r"): ...@@ -34,7 +35,7 @@ def extended_open(filename, mode="r"):
elif filename == "<stderr>": elif filename == "<stderr>":
yield sys.stderr yield sys.stderr
else: else:
with open(filename, mode=mode) as f: with Path(filename).open(mode=mode) as f:
yield f yield f
......
...@@ -6,6 +6,7 @@ import re ...@@ -6,6 +6,7 @@ import re
import socket import socket
import sys import sys
import textwrap import textwrap
from pathlib import Path
import numpy as np import numpy as np
from setuptools._distutils.spawn import find_executable from setuptools._distutils.spawn import find_executable
...@@ -33,64 +34,60 @@ from pytensor.utils import ( ...@@ -33,64 +34,60 @@ from pytensor.utils import (
_logger = logging.getLogger("pytensor.configdefaults") _logger = logging.getLogger("pytensor.configdefaults")
def get_cuda_root(): def get_cuda_root() -> Path | None:
# We look for the cuda path since we need headers from there # We look for the cuda path since we need headers from there
v = os.getenv("CUDA_ROOT", "") if (v := os.getenv("CUDA_ROOT")) is not None:
if v: return Path(v)
return v if (v := os.getenv("CUDA_PATH")) is not None:
v = os.getenv("CUDA_PATH", "") return Path(v)
if v: if (s := os.getenv("PATH")) is None:
return v return Path()
s = os.getenv("PATH") for dir in s.split(os.pathsep):
if not s: if (Path(dir) / "nvcc").exists():
return "" return Path(dir).absolute().parent
for dir in s.split(os.path.pathsep): return None
if os.path.exists(os.path.join(dir, "nvcc")):
return os.path.dirname(os.path.abspath(dir))
return "" def default_cuda_include() -> Path | None:
def default_cuda_include():
if config.cuda__root: if config.cuda__root:
return os.path.join(config.cuda__root, "include") return config.cuda__root / "include"
return "" return None
def default_dnn_base_path(): def default_dnn_base_path() -> Path | None:
# We want to default to the cuda root if cudnn is installed there # We want to default to the cuda root if cudnn is installed there
root = config.cuda__root if config.cuda__root:
# The include doesn't change location between OS. # The include doesn't change location between OS.
if root and os.path.exists(os.path.join(root, "include", "cudnn.h")): if (config.cuda__root / "include/cudnn.h").exists():
return root return config.cuda__root
return "" return None
def default_dnn_inc_path(): def default_dnn_inc_path() -> Path | None:
if config.dnn__base_path != "": if config.dnn__base_path:
return os.path.join(config.dnn__base_path, "include") return config.dnn__base_path / "include"
return "" return None
def default_dnn_lib_path(): def default_dnn_lib_path() -> Path | None:
if config.dnn__base_path != "": if config.dnn__base_path:
if sys.platform == "win32": if sys.platform == "win32":
path = os.path.join(config.dnn__base_path, "lib", "x64") path = config.dnn__base_path / "lib/x64"
elif sys.platform == "darwin": elif sys.platform == "darwin":
path = os.path.join(config.dnn__base_path, "lib") path = config.dnn__base_path / "lib"
else: else:
# This is linux # This is linux
path = os.path.join(config.dnn__base_path, "lib64") path = config.dnn__base_path / "lib64"
return path return path
return "" return None
def default_dnn_bin_path(): def default_dnn_bin_path() -> Path | None:
if config.dnn__base_path != "": if config.dnn__base_path:
if sys.platform == "win32": if sys.platform == "win32":
return os.path.join(config.dnn__base_path, "bin") return config.dnn__base_path / "bin"
else:
return config.dnn__library_path return config.dnn__library_path
return "" return None
def _filter_mode(val): def _filter_mode(val):
...@@ -405,15 +402,11 @@ def add_compile_configvars(): ...@@ -405,15 +402,11 @@ def add_compile_configvars():
# Anaconda on Windows has mingw-w64 packages including GCC, but it may not be on PATH. # Anaconda on Windows has mingw-w64 packages including GCC, but it may not be on PATH.
if rc != 0: if rc != 0:
if sys.platform == "win32": if sys.platform == "win32":
mingw_w64_gcc = os.path.join( mingw_w64_gcc = Path(sys.executable).parent / "Library/mingw-w64/bin/g++"
os.path.dirname(sys.executable), "Library", "mingw-w64", "bin", "g++"
)
try: try:
rc = call_subprocess_Popen([mingw_w64_gcc, "-v"]) rc = call_subprocess_Popen([mingw_w64_gcc, "-v"])
if rc == 0: if rc == 0:
maybe_add_to_os_environ_pathlist( maybe_add_to_os_environ_pathlist("PATH", mingw_w64_gcc.parent)
"PATH", os.path.dirname(mingw_w64_gcc)
)
except OSError: except OSError:
rc = 1 rc = 1
if rc != 0: if rc != 0:
...@@ -1221,27 +1214,27 @@ def add_numba_configvars(): ...@@ -1221,27 +1214,27 @@ def add_numba_configvars():
) )
def _default_compiledirname(): def _default_compiledirname() -> str:
formatted = config.compiledir_format % _compiledir_format_dict formatted = config.compiledir_format % _compiledir_format_dict
safe = re.sub(r"[\(\)\s,]+", "_", formatted) safe = re.sub(r"[\(\)\s,]+", "_", formatted)
return safe return safe
def _filter_base_compiledir(path): def _filter_base_compiledir(path: Path) -> Path:
# Expand '~' in path # Expand '~' in path
return os.path.expanduser(str(path)) return path.expanduser()
def _filter_compiledir(path): def _filter_compiledir(path: Path) -> Path:
# Expand '~' in path # Expand '~' in path
path = os.path.expanduser(path) path = path.expanduser()
# Turn path into the 'real' path. This ensures that: # Turn path into the 'real' path. This ensures that:
# 1. There is no relative path, which would fail e.g. when trying to # 1. There is no relative path, which would fail e.g. when trying to
# import modules from the compile dir. # import modules from the compile dir.
# 2. The path is stable w.r.t. e.g. symlinks (which makes it easier # 2. The path is stable w.r.t. e.g. symlinks (which makes it easier
# to re-use compiled modules). # to re-use compiled modules).
path = os.path.realpath(path) path = path.resolve()
if os.access(path, os.F_OK): # Do it exist? if path.exists():
if not os.access(path, os.R_OK | os.W_OK | os.X_OK): if not os.access(path, os.R_OK | os.W_OK | os.X_OK):
# If it exist we need read, write and listing access # If it exist we need read, write and listing access
raise ValueError( raise ValueError(
...@@ -1250,7 +1243,8 @@ def _filter_compiledir(path): ...@@ -1250,7 +1243,8 @@ def _filter_compiledir(path):
) )
else: else:
try: try:
os.makedirs(path, 0o770) # read-write-execute for user and group # 0o770 = read-write-execute for user and group
path.mkdir(mode=0o770, parents=True, exist_ok=True)
except OSError as e: except OSError as e:
# Maybe another parallel execution of pytensor was trying to create # Maybe another parallel execution of pytensor was trying to create
# the same directory at the same time. # the same directory at the same time.
...@@ -1264,36 +1258,38 @@ def _filter_compiledir(path): ...@@ -1264,36 +1258,38 @@ def _filter_compiledir(path):
# os.system('touch') returned -1 for an unknown reason; the # os.system('touch') returned -1 for an unknown reason; the
# alternate approach here worked in all cases... it was weird. # alternate approach here worked in all cases... it was weird.
# No error should happen as we checked the permissions. # No error should happen as we checked the permissions.
init_file = os.path.join(path, "__init__.py") init_file = path / "__init__.py"
if not os.path.exists(init_file): if not init_file.exists():
try: try:
with open(init_file, "w"): with init_file.open("w"):
pass pass
except OSError as e: except OSError as e:
if os.path.exists(init_file): if init_file.exists():
pass # has already been created pass # has already been created
else: else:
e.args += (f"{path} exist? {os.path.exists(path)}",) e.args += (f"{path} exist? {path.exists()}",)
raise raise
return path return path
def _get_home_dir(): def _get_home_dir() -> Path:
""" """
Return location of the user's home directory. Return location of the user's home directory.
""" """
home = os.getenv("HOME") if (env_home := os.getenv("HOME")) is not None:
if home is None: return Path(env_home)
# This expanduser usually works on Windows (see discussion on
# theano-users, July 13 2010). # This usually works on Windows (see discussion on theano-users, July 13 2010).
home = os.path.expanduser("~") path_home = Path.home()
if home == "~": if str(path_home) != "~":
return path_home
# This might happen when expanduser fails. Although the cause of # This might happen when expanduser fails. Although the cause of
# failure is a mystery, it has been seen on some Windows system. # failure is a mystery, it has been seen on some Windows system.
home = os.getenv("USERPROFILE") windowsfail_home = os.getenv("USERPROFILE")
assert home is not None assert windowsfail_home is not None
return home return Path(windowsfail_home)
_compiledir_format_dict = { _compiledir_format_dict = {
...@@ -1309,8 +1305,8 @@ _compiledir_format_dict = { ...@@ -1309,8 +1305,8 @@ _compiledir_format_dict = {
} }
def _default_compiledir(): def _default_compiledir() -> Path:
return os.path.join(config.base_compiledir, _default_compiledirname()) return config.base_compiledir / _default_compiledirname()
def add_caching_dir_configvars(): def add_caching_dir_configvars():
...@@ -1343,9 +1339,9 @@ def add_caching_dir_configvars(): ...@@ -1343,9 +1339,9 @@ def add_caching_dir_configvars():
# part of the roaming part of the user profile. Instead we use the local part # part of the roaming part of the user profile. Instead we use the local part
# of the user profile, when available. # of the user profile, when available.
if sys.platform == "win32" and os.getenv("LOCALAPPDATA") is not None: if sys.platform == "win32" and os.getenv("LOCALAPPDATA") is not None:
default_base_compiledir = os.path.join(os.getenv("LOCALAPPDATA"), "PyTensor") default_base_compiledir = Path(os.getenv("LOCALAPPDATA")) / "PyTensor"
else: else:
default_base_compiledir = os.path.join(_get_home_dir(), ".pytensor") default_base_compiledir = _get_home_dir() / ".pytensor"
config.add( config.add(
"base_compiledir", "base_compiledir",
......
...@@ -13,6 +13,7 @@ from configparser import ( ...@@ -13,6 +13,7 @@ from configparser import (
) )
from functools import wraps from functools import wraps
from io import StringIO from io import StringIO
from pathlib import Path
from pytensor.utils import hash_from_code from pytensor.utils import hash_from_code
...@@ -22,7 +23,7 @@ _logger = logging.getLogger("pytensor.configparser") ...@@ -22,7 +23,7 @@ _logger = logging.getLogger("pytensor.configparser")
class PyTensorConfigWarning(Warning): class PyTensorConfigWarning(Warning):
@classmethod @classmethod
def warn(cls, message, stacklevel=0): def warn(cls, message: str, stacklevel: int = 0):
warnings.warn(message, cls, stacklevel=stacklevel + 3) warnings.warn(message, cls, stacklevel=stacklevel + 3)
...@@ -68,7 +69,123 @@ class _ChangeFlagsDecorator: ...@@ -68,7 +69,123 @@ class _ChangeFlagsDecorator:
class PyTensorConfigParser: class PyTensorConfigParser:
"""Object that holds configuration settings.""" """Object that holds configuration settings."""
def __init__(self, flags_dict: dict, pytensor_cfg, pytensor_raw_cfg): # add_basic_configvars
floatX: str
warn_float64: str
pickle_test_value: bool
cast_policy: str
deterministic: str
device: str
force_device: bool
conv__assert_shape: bool
print_global_stats: bool
assert_no_cpu_op: str
unpickle_function: bool
# add_compile_configvars
mode: str
cxx: str
linker: str
allow_gc: bool
optimizer: str
optimizer_verbose: bool
on_opt_error: str
nocleanup: bool
on_unused_import: str
gcc__cxxflags: str
cmodule__warn_no_version: bool
cmodule__remove_gxx_opt: bool
cmodule__compilation_warning: bool
cmodule__preload_cache: bool
cmodule__age_thresh_use: int
cmodule__debug: bool
compile__wait: int
compile__timeout: int
ctc__root: str
# add_tensor_configvars
tensor__cmp_sloppy: int
lib__amblibm: bool
tensor__insert_inplace_optimizer_validate_nb: int
# add_traceback_configvars
traceback__limit: int
traceback__compile_limit: int
# add_experimental_configvars
# add_error_and_warning_configvars
warn__ignore_bug_before: int
exception_verbosity: str
# add_testvalue_and_checking_configvars
print_test_value: bool
compute_test_value: str
compute_test_value_opt: str
check_input: bool
NanGuardMode__nan_is_error: bool
NanGuardMode__inf_is_error: bool
NanGuardMode__big_is_error: bool
NanGuardMode__action: str
DebugMode__patience: int
DebugMode__check_c: bool
DebugMode__check_py: bool
DebugMode__check_finite: bool
DebugMode__check_strides: int
DebugMode__warn_input_not_reused: bool
DebugMode__check_preallocated_output: str
DebugMode__check_preallocated_output_ndim: int
profiling__time_thunks: bool
profiling__n_apply: int
profiling__n_ops: int
profiling__output_line_width: int
profiling__min_memory_size: int
profiling__min_peak_memory: bool
profiling__destination: str
profiling__debugprint: bool
profiling__ignore_first_call: bool
on_shape_error: str
# add_multiprocessing_configvars
openmp: bool
openmp_elemwise_minsize: int
# add_optimizer_configvars
optimizer_excluding: str
optimizer_including: str
optimizer_requiring: str
optdb__position_cutoff: float
optdb__max_use_ratio: float
cycle_detection: str
check_stack_trace: str
metaopt__verbose: int
metaopt__optimizer_excluding: str
metaopt__optimizer_including: str
# add_vm_configvars
profile: bool
profile_optimizer: bool
profile_memory: bool
vm__lazy: bool | None
# add_deprecated_configvars
unittests__rseed: str
warn__round: bool
# add_scan_configvars
scan__allow_gc: bool
scan__allow_output_prealloc: bool
# add_numba_configvars
numba__vectorize_target: str
numba__fastmath: bool
numba__cache: bool
# add_caching_dir_configvars
compiledir_format: str
base_compiledir: Path
compiledir: Path
# add_blas_configvars
blas__ldflags: str
blas__check_openmp: bool
# add CUDA (?)
cuda__root: Path | None
dnn__base_path: Path | None
dnn__library_path: Path | None
def __init__(
self,
flags_dict: dict,
pytensor_cfg: ConfigParser,
pytensor_raw_cfg: RawConfigParser,
):
self._flags_dict = flags_dict self._flags_dict = flags_dict
self._pytensor_cfg = pytensor_cfg self._pytensor_cfg = pytensor_cfg
self._pytensor_raw_cfg = pytensor_raw_cfg self._pytensor_raw_cfg = pytensor_raw_cfg
...@@ -80,7 +197,7 @@ class PyTensorConfigParser: ...@@ -80,7 +197,7 @@ class PyTensorConfigParser:
self.config_print(buf=sio, print_doc=print_doc) self.config_print(buf=sio, print_doc=print_doc)
return sio.getvalue() return sio.getvalue()
def config_print(self, buf, print_doc=True): def config_print(self, buf, print_doc: bool = True):
for cv in self._config_var_dict.values(): for cv in self._config_var_dict.values():
print(cv, file=buf) print(cv, file=buf)
if print_doc: if print_doc:
...@@ -108,7 +225,9 @@ class PyTensorConfigParser: ...@@ -108,7 +225,9 @@ class PyTensorConfigParser:
) )
) )
def add(self, name, doc, configparam, in_c_key=True): def add(
self, name: str, doc: str, configparam: "ConfigParam", in_c_key: bool = True
):
"""Add a new variable to PyTensorConfigParser. """Add a new variable to PyTensorConfigParser.
This method performs some of the work of initializing `ConfigParam` instances. This method performs some of the work of initializing `ConfigParam` instances.
...@@ -168,7 +287,7 @@ class PyTensorConfigParser: ...@@ -168,7 +287,7 @@ class PyTensorConfigParser:
# the ConfigParam implements __get__/__set__, enabling us to create a property: # the ConfigParam implements __get__/__set__, enabling us to create a property:
setattr(self.__class__, name, configparam) setattr(self.__class__, name, configparam)
def fetch_val_for_key(self, key, delete_key=False): def fetch_val_for_key(self, key, delete_key: bool = False):
"""Return the overriding config value for a key. """Return the overriding config value for a key.
A successful search returns a string value. A successful search returns a string value.
An unsuccessful search raises a KeyError An unsuccessful search raises a KeyError
...@@ -260,9 +379,9 @@ class ConfigParam: ...@@ -260,9 +379,9 @@ class ConfigParam:
self._mutable = mutable self._mutable = mutable
self.is_default = True self.is_default = True
# set by PyTensorConfigParser.add: # set by PyTensorConfigParser.add:
self.name = None self.name: str = "unnamed"
self.doc = None self.doc: str = "undocumented"
self.in_c_key = None self.in_c_key: bool
# Note that we do not call `self.filter` on the default value: this # Note that we do not call `self.filter` on the default value: this
# will be done automatically in PyTensorConfigParser.add, potentially with a # will be done automatically in PyTensorConfigParser.add, potentially with a
...@@ -288,7 +407,7 @@ class ConfigParam: ...@@ -288,7 +407,7 @@ class ConfigParam:
return self._apply(value) return self._apply(value)
return value return value
def validate(self, value) -> bool | None: def validate(self, value) -> bool:
"""Validates that a parameter values falls into a supported set or range. """Validates that a parameter values falls into a supported set or range.
Raises Raises
...@@ -336,7 +455,7 @@ class ConfigParam: ...@@ -336,7 +455,7 @@ class ConfigParam:
class EnumStr(ConfigParam): class EnumStr(ConfigParam):
def __init__( def __init__(
self, default: str, options: Sequence[str], validate=None, mutable=True self, default: str, options: Sequence[str], validate=None, mutable: bool = True
): ):
"""Creates a str-based parameter that takes a predefined set of options. """Creates a str-based parameter that takes a predefined set of options.
...@@ -400,7 +519,7 @@ class BoolParam(TypedParam): ...@@ -400,7 +519,7 @@ class BoolParam(TypedParam):
True, 1, "true", "True", "1" True, 1, "true", "True", "1"
""" """
def __init__(self, default, validate=None, mutable=True): def __init__(self, default, validate=None, mutable: bool = True):
super().__init__(default, apply=self._apply, validate=validate, mutable=mutable) super().__init__(default, apply=self._apply, validate=validate, mutable=mutable)
def _apply(self, value): def _apply(self, value):
...@@ -454,7 +573,9 @@ class ContextsParam(ConfigParam): ...@@ -454,7 +573,9 @@ class ContextsParam(ConfigParam):
return val return val
def parse_config_string(config_string, issue_warnings=True): def parse_config_string(
config_string: str, issue_warnings: bool = True
) -> dict[str, str]:
""" """
Parses a config string (comma-separated key=value components) into a dict. Parses a config string (comma-separated key=value components) into a dict.
""" """
...@@ -480,7 +601,7 @@ def parse_config_string(config_string, issue_warnings=True): ...@@ -480,7 +601,7 @@ def parse_config_string(config_string, issue_warnings=True):
return config_dict return config_dict
def config_files_from_pytensorrc(): def config_files_from_pytensorrc() -> list[Path]:
""" """
PYTENSORRC can contain a colon-delimited list of config files, like PYTENSORRC can contain a colon-delimited list of config files, like
...@@ -489,17 +610,17 @@ def config_files_from_pytensorrc(): ...@@ -489,17 +610,17 @@ def config_files_from_pytensorrc():
In that case, definitions in files on the right (here, ``~/.pytensorrc``) In that case, definitions in files on the right (here, ``~/.pytensorrc``)
have precedence over those in files on the left. have precedence over those in files on the left.
""" """
rval = [ paths = [
os.path.expanduser(s) Path(s).expanduser()
for s in os.getenv("PYTENSORRC", "~/.pytensorrc").split(os.pathsep) for s in os.getenv("PYTENSORRC", "~/.pytensorrc").split(os.pathsep)
] ]
if os.getenv("PYTENSORRC") is None and sys.platform == "win32": if os.getenv("PYTENSORRC") is None and sys.platform == "win32":
# to don't need to change the filename and make it open easily # to don't need to change the filename and make it open easily
rval.append(os.path.expanduser("~/.pytensorrc.txt")) paths.append(Path("~/.pytensorrc.txt").expanduser())
return rval return paths
def _create_default_config(): def _create_default_config() -> PyTensorConfigParser:
# The PYTENSOR_FLAGS environment variable should be a list of comma-separated # The PYTENSOR_FLAGS environment variable should be a list of comma-separated
# [section__]option=value entries. If the section part is omitted, there should # [section__]option=value entries. If the section part is omitted, there should
# be only one section that contains the given option. # be only one section that contains the given option.
...@@ -509,7 +630,7 @@ def _create_default_config(): ...@@ -509,7 +630,7 @@ def _create_default_config():
config_files = config_files_from_pytensorrc() config_files = config_files_from_pytensorrc()
pytensor_cfg = ConfigParser( pytensor_cfg = ConfigParser(
{ {
"USER": os.getenv("USER", os.path.split(os.path.expanduser("~"))[-1]), "USER": os.getenv("USER", Path("~").expanduser().name),
"LSCRATCH": os.getenv("LSCRATCH", ""), "LSCRATCH": os.getenv("LSCRATCH", ""),
"TMPDIR": os.getenv("TMPDIR", ""), "TMPDIR": os.getenv("TMPDIR", ""),
"TEMP": os.getenv("TEMP", ""), "TEMP": os.getenv("TEMP", ""),
......
...@@ -4,13 +4,13 @@ Author: Christof Angermueller <cangermueller@gmail.com> ...@@ -4,13 +4,13 @@ Author: Christof Angermueller <cangermueller@gmail.com>
""" """
import json import json
import os
import shutil import shutil
from pathlib import Path
from pytensor.d3viz.formatting import PyDotFormatter from pytensor.d3viz.formatting import PyDotFormatter
__path__ = os.path.dirname(os.path.realpath(__file__)) __path__ = Path(__file__).parent
def replace_patterns(x, replace): def replace_patterns(x, replace):
...@@ -40,7 +40,7 @@ def safe_json(obj): ...@@ -40,7 +40,7 @@ def safe_json(obj):
return json.dumps(obj).replace("<", "\\u003c") return json.dumps(obj).replace("<", "\\u003c")
def d3viz(fct, outfile, copy_deps=True, *args, **kwargs): def d3viz(fct, outfile: Path | str, copy_deps: bool = True, *args, **kwargs):
"""Create HTML file with dynamic visualizing of an PyTensor function graph. """Create HTML file with dynamic visualizing of an PyTensor function graph.
In the HTML file, the whole graph or single nodes can be moved by drag and In the HTML file, the whole graph or single nodes can be moved by drag and
...@@ -59,7 +59,7 @@ def d3viz(fct, outfile, copy_deps=True, *args, **kwargs): ...@@ -59,7 +59,7 @@ def d3viz(fct, outfile, copy_deps=True, *args, **kwargs):
---------- ----------
fct : pytensor.compile.function.types.Function fct : pytensor.compile.function.types.Function
A compiled PyTensor function, variable, apply or a list of variables. A compiled PyTensor function, variable, apply or a list of variables.
outfile : str outfile : Path | str
Path to output HTML file. Path to output HTML file.
copy_deps : bool, optional copy_deps : bool, optional
Copy javascript and CSS dependencies to output directory. Copy javascript and CSS dependencies to output directory.
...@@ -78,37 +78,34 @@ def d3viz(fct, outfile, copy_deps=True, *args, **kwargs): ...@@ -78,37 +78,34 @@ def d3viz(fct, outfile, copy_deps=True, *args, **kwargs):
dot_graph = dot_graph.decode("utf8") dot_graph = dot_graph.decode("utf8")
# Create output directory if not existing # Create output directory if not existing
outdir = os.path.dirname(outfile) outdir = Path(outfile).parent
if outdir != "" and not os.path.exists(outdir): outdir.mkdir(exist_ok=True)
os.makedirs(outdir)
# Read template HTML file # Read template HTML file
template_file = os.path.join(__path__, "html", "template.html") template_file = __path__ / "html/template.html"
with open(template_file) as f: template = template_file.read_text(encoding="utf-8")
template = f.read()
# Copy dependencies to output directory # Copy dependencies to output directory
src_deps = __path__ src_deps = __path__
if copy_deps: if copy_deps:
dst_deps = "d3viz" dst_deps = outdir / "d3viz"
for d in ("js", "css"): for d in ("js", "css"):
dep = os.path.join(outdir, dst_deps, d) dep = dst_deps / d
if not os.path.exists(dep): if not dep.exists():
shutil.copytree(os.path.join(src_deps, d), dep) shutil.copytree(src_deps / d, dep)
else: else:
dst_deps = src_deps dst_deps = src_deps
# Replace patterns in template # Replace patterns in template
replace = { replace = {
"%% JS_DIR %%": os.path.join(dst_deps, "js"), "%% JS_DIR %%": dst_deps / "js",
"%% CSS_DIR %%": os.path.join(dst_deps, "css"), "%% CSS_DIR %%": dst_deps / "css",
"%% DOT_GRAPH %%": safe_json(dot_graph), "%% DOT_GRAPH %%": safe_json(dot_graph),
} }
html = replace_patterns(template, replace) html = replace_patterns(template, replace)
# Write HTML file # Write HTML file
with open(outfile, "w") as f: Path(outfile).write_text(html)
f.write(html)
def d3write(fct, path, *args, **kwargs): def d3write(fct, path, *args, **kwargs):
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
Author: Christof Angermueller <cangermueller@gmail.com> Author: Christof Angermueller <cangermueller@gmail.com>
""" """
import os
from functools import reduce from functools import reduce
from pathlib import Path
import numpy as np import numpy as np
...@@ -285,7 +285,7 @@ def var_tag(var): ...@@ -285,7 +285,7 @@ def var_tag(var):
path, line, _, src = tag.trace[0][-1] path, line, _, src = tag.trace[0][-1]
else: else:
path, line, _, src = tag.trace[0] path, line, _, src = tag.trace[0]
path = os.path.basename(path) path = Path(path).name
path = path.replace("<", "") path = path.replace("<", "")
path = path.replace(">", "") path = path.replace(">", "")
src = src.encode() src = src.encode()
......
...@@ -7,7 +7,6 @@ import atexit ...@@ -7,7 +7,6 @@ import atexit
import importlib import importlib
import logging import logging
import os import os
import pathlib
import pickle import pickle
import platform import platform
import re import re
...@@ -23,6 +22,7 @@ import warnings ...@@ -23,6 +22,7 @@ import warnings
from collections.abc import Callable from collections.abc import Callable
from contextlib import AbstractContextManager, nullcontext from contextlib import AbstractContextManager, nullcontext
from io import BytesIO, StringIO from io import BytesIO, StringIO
from pathlib import Path
from typing import TYPE_CHECKING, Protocol, cast from typing import TYPE_CHECKING, Protocol, cast
import numpy as np import numpy as np
...@@ -688,7 +688,7 @@ class ModuleCache: ...@@ -688,7 +688,7 @@ class ModuleCache:
""" """
dirname: str = "" dirname: Path
""" """
The working directory that is managed by this interface. The working directory that is managed by this interface.
...@@ -725,8 +725,13 @@ class ModuleCache: ...@@ -725,8 +725,13 @@ class ModuleCache:
""" """
def __init__(self, dirname, check_for_broken_eq=True, do_refresh=True): def __init__(
self.dirname = dirname self,
dirname: Path | str,
check_for_broken_eq: bool = True,
do_refresh: bool = True,
):
self.dirname = Path(dirname)
self.module_from_name = dict(self.module_from_name) self.module_from_name = dict(self.module_from_name)
self.entry_from_key = dict(self.entry_from_key) self.entry_from_key = dict(self.entry_from_key)
self.module_hash_to_key_data = dict(self.module_hash_to_key_data) self.module_hash_to_key_data = dict(self.module_hash_to_key_data)
...@@ -1637,12 +1642,12 @@ def _rmtree( ...@@ -1637,12 +1642,12 @@ def _rmtree(
_module_cache: ModuleCache | None = None _module_cache: ModuleCache | None = None
def get_module_cache(dirname: str, init_args=None) -> ModuleCache: def get_module_cache(dirname: Path | str, init_args=None) -> ModuleCache:
"""Create a new module_cache. """Create a new module_cache.
Parameters Parameters
---------- ----------
dirname dirname : Path | str
The name of the directory used by the cache. The name of the directory used by the cache.
init_args init_args
Keyword arguments passed to the `ModuleCache` constructor. Keyword arguments passed to the `ModuleCache` constructor.
...@@ -2753,7 +2758,7 @@ def default_blas_ldflags(): ...@@ -2753,7 +2758,7 @@ def default_blas_ldflags():
return [] return []
maybe_lib_dirs = [ maybe_lib_dirs = [
[pathlib.Path(p).resolve() for p in line[len("libraries: =") :].split(":")] [Path(p).resolve() for p in line[len("libraries: =") :].split(":")]
for line in stdout.decode(sys.getdefaultencoding()).splitlines() for line in stdout.decode(sys.getdefaultencoding()).splitlines()
if line.startswith("libraries: =") if line.startswith("libraries: =")
] ]
...@@ -2828,9 +2833,9 @@ def default_blas_ldflags(): ...@@ -2828,9 +2833,9 @@ def default_blas_ldflags():
all_libs = [ all_libs = [
l l
for path in [ for path in [
pathlib.Path(library_dir) Path(library_dir)
for library_dir in searched_library_dirs for library_dir in searched_library_dirs
if pathlib.Path(library_dir).exists() if Path(library_dir).exists()
] ]
for l in path.iterdir() for l in path.iterdir()
if l.suffix in {".so", ".dll", ".dylib"} if l.suffix in {".so", ".dll", ".dylib"}
......
import errno import errno
import os
import sys import sys
from pathlib import Path
from pytensor.compile.compilelock import lock_ctx from pytensor.compile.compilelock import lock_ctx
from pytensor.configdefaults import config from pytensor.configdefaults import config
...@@ -9,8 +9,7 @@ from pytensor.link.c import cmodule ...@@ -9,8 +9,7 @@ from pytensor.link.c import cmodule
# TODO These two lines may be removed in the future, when we are 100% sure # TODO These two lines may be removed in the future, when we are 100% sure
# no one has an old cutils_ext.so lying around anymore. # no one has an old cutils_ext.so lying around anymore.
if os.path.exists(os.path.join(config.compiledir, "cutils_ext.so")): (config.compiledir / "cutils_ext.so").unlink(missing_ok=True)
os.remove(os.path.join(config.compiledir, "cutils_ext.so"))
def compile_cutils(): def compile_cutils():
...@@ -68,13 +67,13 @@ def compile_cutils(): ...@@ -68,13 +67,13 @@ def compile_cutils():
} }
""" """
loc = os.path.join(config.compiledir, "cutils_ext") loc = config.compiledir / "cutils_ext"
if not os.path.exists(loc): if not loc.exists():
try: try:
os.mkdir(loc) loc.mkdir()
except OSError as e: except OSError as e:
assert e.errno == errno.EEXIST assert e.errno == errno.EEXIST
assert os.path.exists(loc), loc assert loc.exists(), loc
args = cmodule.GCC_compiler.compile_args(march_flags=False) args = cmodule.GCC_compiler.compile_args(march_flags=False)
cmodule.GCC_compiler.compile_str("cutils_ext", code, location=loc, preargs=args) cmodule.GCC_compiler.compile_str("cutils_ext", code, location=loc, preargs=args)
...@@ -87,17 +86,15 @@ try: ...@@ -87,17 +86,15 @@ try:
# for the same reason. Note that these 5 lines may seem redundant (they are # for the same reason. Note that these 5 lines may seem redundant (they are
# repeated in compile_str()) but if another cutils_ext does exist then it # repeated in compile_str()) but if another cutils_ext does exist then it
# will be imported and compile_str won't get called at all. # will be imported and compile_str won't get called at all.
sys.path.insert(0, config.compiledir) sys.path.insert(0, str(config.compiledir))
location = os.path.join(config.compiledir, "cutils_ext") location = config.compiledir / "cutils_ext"
if not os.path.exists(location): if not location.exists():
try: try:
os.mkdir(location) location.mkdir()
except OSError as e: except OSError as e:
assert e.errno == errno.EEXIST assert e.errno == errno.EEXIST
assert os.path.exists(location), location assert location.exists(), location
if not os.path.exists(os.path.join(location, "__init__.py")): (location / "__init__.py").touch(exist_ok=True)
with open(os.path.join(location, "__init__.py"), "w"):
pass
try: try:
from cutils_ext.cutils_ext import * # noqa from cutils_ext.cutils_ext import * # noqa
...@@ -115,5 +112,5 @@ try: ...@@ -115,5 +112,5 @@ try:
compile_cutils() compile_cutils()
from cutils_ext.cutils_ext import * # noqa from cutils_ext.cutils_ext import * # noqa
finally: finally:
if sys.path[0] == config.compiledir: if config.compiledir.resolve() == Path(sys.path[0]).resolve():
del sys.path[0] del sys.path[0]
...@@ -4,6 +4,7 @@ import os ...@@ -4,6 +4,7 @@ import os
import sys import sys
import warnings import warnings
from importlib import reload from importlib import reload
from pathlib import Path
from types import ModuleType from types import ModuleType
import pytensor import pytensor
...@@ -21,14 +22,14 @@ lazylinker_ext: ModuleType | None = None ...@@ -21,14 +22,14 @@ lazylinker_ext: ModuleType | None = None
def try_import(): def try_import():
global lazylinker_ext global lazylinker_ext
sys.path[0:0] = [config.compiledir] sys.path[0:0] = [str(config.compiledir)]
import lazylinker_ext import lazylinker_ext
del sys.path[0] del sys.path[0]
def try_reload(): def try_reload():
sys.path[0:0] = [config.compiledir] sys.path[0:0] = [str(config.compiledir)]
reload(lazylinker_ext) reload(lazylinker_ext)
del sys.path[0] del sys.path[0]
...@@ -41,8 +42,8 @@ try: ...@@ -41,8 +42,8 @@ try:
# Note that these lines may seem redundant (they are repeated in # Note that these lines may seem redundant (they are repeated in
# compile_str()) but if another lazylinker_ext does exist then it will be # compile_str()) but if another lazylinker_ext does exist then it will be
# imported and compile_str won't get called at all. # imported and compile_str won't get called at all.
location = os.path.join(config.compiledir, "lazylinker_ext") location = config.compiledir / "lazylinker_ext"
if not os.path.exists(location): if not location.exists():
try: try:
# Try to make the location # Try to make the location
os.mkdir(location) os.mkdir(location)
...@@ -53,18 +54,18 @@ try: ...@@ -53,18 +54,18 @@ try:
# are not holding the lock right now, so we could race # are not holding the lock right now, so we could race
# another process and get error 17 if we lose the race # another process and get error 17 if we lose the race
assert e.errno == errno.EEXIST assert e.errno == errno.EEXIST
assert os.path.isdir(location) assert location.is_dir()
init_file = os.path.join(location, "__init__.py") init_file = location / "__init__.py"
if not os.path.exists(init_file): if not init_file.exists():
try: try:
with open(init_file, "w"): with open(init_file, "w"):
pass pass
except OSError as e: except OSError as e:
if os.path.exists(init_file): if init_file.exists():
pass # has already been created pass # has already been created
else: else:
e.args += (f"{location} exist? {os.path.exists(location)}",) e.args += (f"{location} exist? {location.exists()}",)
raise raise
_need_reload = False _need_reload = False
...@@ -109,10 +110,8 @@ except ImportError: ...@@ -109,10 +110,8 @@ except ImportError:
raise raise
_logger.info("Compiling new CVM") _logger.info("Compiling new CVM")
dirname = "lazylinker_ext" dirname = "lazylinker_ext"
cfile = os.path.join( cfile = Path(pytensor.__path__[0]) / "link/c/c_code/lazylinker_c.c"
pytensor.__path__[0], "link", "c", "c_code", "lazylinker_c.c" if not cfile.exists():
)
if not os.path.exists(cfile):
# This can happen in not normal case. We just # This can happen in not normal case. We just
# disable the c clinker. If we are here the user # disable the c clinker. If we are here the user
# didn't disable the compiler, so print a warning. # didn't disable the compiler, so print a warning.
...@@ -127,30 +126,28 @@ except ImportError: ...@@ -127,30 +126,28 @@ except ImportError:
) )
raise ImportError("The file lazylinker_c.c is not available.") raise ImportError("The file lazylinker_c.c is not available.")
with open(cfile) as f: code = cfile.read_text("utf-8")
code = f.read()
loc = os.path.join(config.compiledir, dirname) loc = config.compiledir / dirname
if not os.path.exists(loc): if not loc.exists():
try: try:
os.mkdir(loc) os.mkdir(loc)
except OSError as e: except OSError as e:
assert e.errno == errno.EEXIST assert e.errno == errno.EEXIST
assert os.path.exists(loc) assert loc.exists()
args = GCC_compiler.compile_args() args = GCC_compiler.compile_args()
GCC_compiler.compile_str(dirname, code, location=loc, preargs=args) GCC_compiler.compile_str(dirname, code, location=loc, preargs=args)
# Save version into the __init__.py file. # Save version into the __init__.py file.
init_py = os.path.join(loc, "__init__.py") init_py = loc / "__init__.py"
with open(init_py, "w") as f: init_py.write_text(f"_version = {version}\n")
f.write(f"_version = {version}\n")
# If we just compiled the module for the first time, then it was # If we just compiled the module for the first time, then it was
# imported at the same time: we need to make sure we do not # imported at the same time: we need to make sure we do not
# reload the now outdated __init__.pyc below. # reload the now outdated __init__.pyc below.
init_pyc = os.path.join(loc, "__init__.pyc") init_pyc = loc / "__init__.pyc"
if os.path.isfile(init_pyc): if init_pyc.is_file():
os.remove(init_pyc) os.remove(init_pyc)
try_import() try_import()
......
import inspect import inspect
import os
import re import re
import warnings import warnings
from collections.abc import Callable, Collection from collections.abc import Callable, Collection, Iterable
from pathlib import Path
from re import Pattern from re import Pattern
from typing import TYPE_CHECKING, Any, ClassVar, cast from typing import TYPE_CHECKING, Any, ClassVar, cast
...@@ -279,28 +279,32 @@ class ExternalCOp(COp): ...@@ -279,28 +279,32 @@ class ExternalCOp(COp):
_cop_num_outputs: int | None = None _cop_num_outputs: int | None = None
@classmethod @classmethod
def get_path(cls, f: str) -> str: def get_path(cls, f: Path) -> Path:
"""Convert a path relative to the location of the class file into an absolute path. """Convert a path relative to the location of the class file into an absolute path.
Paths that are already absolute are passed through unchanged. Paths that are already absolute are passed through unchanged.
""" """
if not os.path.isabs(f): if not f.is_absolute():
class_file = inspect.getfile(cls) class_file = inspect.getfile(cls)
class_dir = os.path.dirname(class_file) class_dir = Path(class_file).parent
f = os.path.realpath(os.path.join(class_dir, f)) f = (class_dir / f).resolve()
return f return f
def __init__(self, func_files: str | list[str], func_name: str | None = None): def __init__(
self,
func_files: str | Path | list[str] | list[Path],
func_name: str | None = None,
):
""" """
Sections are loaded from files in order with sections in later Sections are loaded from files in order with sections in later
files overriding sections in previous files. files overriding sections in previous files.
""" """
if not isinstance(func_files, list): if not isinstance(func_files, list):
self.func_files = [func_files] self.func_files = [Path(func_files)]
else: else:
self.func_files = func_files self.func_files = [Path(func_file) for func_file in func_files]
self.func_codes: list[str] = [] self.func_codes: list[str] = []
# Keep the original name. If we reload old pickle, we want to # Keep the original name. If we reload old pickle, we want to
...@@ -325,22 +329,20 @@ class ExternalCOp(COp): ...@@ -325,22 +329,20 @@ class ExternalCOp(COp):
"Cannot have an `op_code_cleanup` section and specify `func_name`" "Cannot have an `op_code_cleanup` section and specify `func_name`"
) )
def load_c_code(self, func_files: list[str]) -> None: def load_c_code(self, func_files: Iterable[Path]) -> None:
"""Loads the C code to perform the `Op`.""" """Loads the C code to perform the `Op`."""
func_files = [self.get_path(f) for f in func_files]
for func_file in func_files: for func_file in func_files:
with open(func_file) as f: func_file = self.get_path(func_file)
self.func_codes.append(f.read()) self.func_codes.append(func_file.read_text(encoding="utf-8"))
# If both the old section markers and the new section markers are # If both the old section markers and the new section markers are
# present, raise an error because we don't know which ones to follow. # present, raise an error because we don't know which ones to follow.
old_markers_present = False old_markers_present = any(
new_markers_present = False self.backward_re.search(code) for code in self.func_codes
for code in self.func_codes: )
if self.backward_re.search(code): new_markers_present = any(
old_markers_present = True self.section_re.search(code) for code in self.func_codes
if self.section_re.search(code): )
new_markers_present = True
if old_markers_present and new_markers_present: if old_markers_present and new_markers_present:
raise ValueError( raise ValueError(
...@@ -350,7 +352,7 @@ class ExternalCOp(COp): ...@@ -350,7 +352,7 @@ class ExternalCOp(COp):
"be used at the same time." "be used at the same time."
) )
for i, code in enumerate(self.func_codes): for func_file, code in zip(func_files, self.func_codes):
if self.backward_re.search(code): if self.backward_re.search(code):
# This is backward compat code that will go away in a while # This is backward compat code that will go away in a while
...@@ -371,7 +373,7 @@ class ExternalCOp(COp): ...@@ -371,7 +373,7 @@ class ExternalCOp(COp):
if split[0].strip() != "": if split[0].strip() != "":
raise ValueError( raise ValueError(
"Stray code before first #section " "Stray code before first #section "
f"statement (in file {func_files[i]}): {split[0]}" f"statement (in file {func_file}): {split[0]}"
) )
# Separate the code into the proper sections # Separate the code into the proper sections
...@@ -379,7 +381,7 @@ class ExternalCOp(COp): ...@@ -379,7 +381,7 @@ class ExternalCOp(COp):
while n < len(split): while n < len(split):
if split[n] not in self.SECTIONS: if split[n] not in self.SECTIONS:
raise ValueError( raise ValueError(
f"Unknown section type (in file {func_files[i]}): {split[n]}" f"Unknown section type (in file {func_file}): {split[n]}"
) )
if split[n] not in self.code_sections: if split[n] not in self.code_sections:
self.code_sections[split[n]] = "" self.code_sections[split[n]] = ""
...@@ -388,7 +390,7 @@ class ExternalCOp(COp): ...@@ -388,7 +390,7 @@ class ExternalCOp(COp):
else: else:
raise ValueError( raise ValueError(
f"No valid section marker was found in file {func_files[i]}" f"No valid section marker was found in file {func_file}"
) )
def __get_op_params(self) -> list[tuple[str, Any]]: def __get_op_params(self) -> list[tuple[str, Any]]:
......
...@@ -3,6 +3,7 @@ import subprocess ...@@ -3,6 +3,7 @@ import subprocess
import sys import sys
from locale import getpreferredencoding from locale import getpreferredencoding
from optparse import OptionParser from optparse import OptionParser
from pathlib import Path
from pytensor.configdefaults import config from pytensor.configdefaults import config
...@@ -25,7 +26,7 @@ parser.add_option( ...@@ -25,7 +26,7 @@ parser.add_option(
def runScript(N): def runScript(N):
script = "elemwise_time_test.py" script = "elemwise_time_test.py"
path = os.path.dirname(os.path.abspath(__file__)) path = Path(__file__).parent
proc = subprocess.Popen( proc = subprocess.Popen(
["python", script, "--script", "-N", str(N)], ["python", script, "--script", "-N", str(N)],
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
import hashlib import hashlib
import logging import logging
import os
import sys import sys
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
...@@ -10,6 +9,7 @@ from contextlib import contextmanager ...@@ -10,6 +9,7 @@ from contextlib import contextmanager
from copy import copy from copy import copy
from functools import reduce, singledispatch from functools import reduce, singledispatch
from io import StringIO from io import StringIO
from pathlib import Path
from typing import Any, Literal, TextIO from typing import Any, Literal, TextIO
import numpy as np import numpy as np
...@@ -1201,7 +1201,7 @@ default_colorCodes = { ...@@ -1201,7 +1201,7 @@ default_colorCodes = {
def pydotprint( def pydotprint(
fct, fct,
outfile: str | None = None, outfile: Path | str | None = None,
compact: bool = True, compact: bool = True,
format: str = "png", format: str = "png",
with_ids: bool = False, with_ids: bool = False,
...@@ -1296,9 +1296,9 @@ def pydotprint( ...@@ -1296,9 +1296,9 @@ def pydotprint(
colorCodes = default_colorCodes colorCodes = default_colorCodes
if outfile is None: if outfile is None:
outfile = os.path.join( outfile = config.compiledir / f"pytensor.pydotprint.{config.device}.{format}"
config.compiledir, "pytensor.pydotprint." + config.device + "." + format elif isinstance(outfile, str):
) outfile = Path(outfile)
if isinstance(fct, Function): if isinstance(fct, Function):
profile = getattr(fct, "profile", None) profile = getattr(fct, "profile", None)
...@@ -1607,23 +1607,19 @@ def pydotprint( ...@@ -1607,23 +1607,19 @@ def pydotprint(
g.add_subgraph(c2) g.add_subgraph(c2)
g.add_subgraph(c3) g.add_subgraph(c3)
if not outfile.endswith("." + format): if outfile.suffix != f".{format}":
outfile += "." + format outfile = outfile.with_suffix(f".{format}")
if scan_graphs: if scan_graphs:
scan_ops = [(idx, x) for idx, x in enumerate(topo) if isinstance(x.op, Scan)] scan_ops = [(idx, x) for idx, x in enumerate(topo) if isinstance(x.op, Scan)]
path, fn = os.path.split(outfile)
basename = ".".join(fn.split(".")[:-1])
# Safe way of doing things .. a file name may contain multiple .
ext = fn[len(basename) :]
for idx, scan_op in scan_ops: for idx, scan_op in scan_ops:
# is there a chance that name is not defined? # is there a chance that name is not defined?
if hasattr(scan_op.op, "name"): if hasattr(scan_op.op, "name"):
new_name = basename + "_" + scan_op.op.name + "_" + str(idx) new_name = outfile.stem + "_" + scan_op.op.name + "_" + str(idx)
else: else:
new_name = basename + "_" + str(idx) new_name = outfile.stem + "_" + str(idx)
new_name = os.path.join(path, new_name + ext) new_name = outfile.with_stem(new_name)
if hasattr(scan_op.op, "_fn"): if hasattr(scan_op.op, "_fn"):
to_print = scan_op.op.fn to_print = scan_op.op.fn
else: else:
......
...@@ -4,8 +4,8 @@ r""" ...@@ -4,8 +4,8 @@ r"""
As SciPy is not always available, we treat them separately. As SciPy is not always available, we treat them separately.
""" """
import os
from functools import reduce from functools import reduce
from pathlib import Path
from textwrap import dedent from textwrap import dedent
import numpy as np import numpy as np
...@@ -47,6 +47,9 @@ from pytensor.scalar.basic import abs as scalar_abs ...@@ -47,6 +47,9 @@ from pytensor.scalar.basic import abs as scalar_abs
from pytensor.scalar.loop import ScalarLoop from pytensor.scalar.loop import ScalarLoop
C_CODE_PATH = Path(__file__).parent / "c_code"
class Erf(UnaryScalarOp): class Erf(UnaryScalarOp):
nfunc_spec = ("scipy.special.erf", 1, 1) nfunc_spec = ("scipy.special.erf", 1, 1)
...@@ -154,19 +157,12 @@ class Erfcx(UnaryScalarOp): ...@@ -154,19 +157,12 @@ class Erfcx(UnaryScalarOp):
def c_header_dirs(self, **kwargs): def c_header_dirs(self, **kwargs):
# Using the Faddeeva.hh (c++) header for Faddeevva.cc # Using the Faddeeva.hh (c++) header for Faddeevva.cc
res = [ res = [*super().c_header_dirs(**kwargs), str(C_CODE_PATH)]
*super().c_header_dirs(**kwargs),
os.path.join(os.path.dirname(__file__), "c_code"),
]
return res return res
def c_support_code(self, **kwargs): def c_support_code(self, **kwargs):
# Using Faddeeva.cc source file from: http://ab-initio.mit.edu/wiki/index.php/Faddeeva_Package # Using Faddeeva.cc source file from: http://ab-initio.mit.edu/wiki/index.php/Faddeeva_Package
with open( return (C_CODE_PATH / "Faddeeva.cc").read_text(encoding="utf-8")
os.path.join(os.path.dirname(__file__), "c_code", "Faddeeva.cc")
) as f:
raw = f.read()
return raw
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
(x,) = inp (x,) = inp
...@@ -612,9 +608,7 @@ class Chi2SF(BinaryScalarOp): ...@@ -612,9 +608,7 @@ class Chi2SF(BinaryScalarOp):
return Chi2SF.st_impl(x, k) return Chi2SF.st_impl(x, k)
def c_support_code(self, **kwargs): def c_support_code(self, **kwargs):
with open(os.path.join(os.path.dirname(__file__), "c_code", "gamma.c")) as f: return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8")
raw = f.read()
return raw
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
x, k = inp x, k = inp
...@@ -665,9 +659,7 @@ class GammaInc(BinaryScalarOp): ...@@ -665,9 +659,7 @@ class GammaInc(BinaryScalarOp):
] ]
def c_support_code(self, **kwargs): def c_support_code(self, **kwargs):
with open(os.path.join(os.path.dirname(__file__), "c_code", "gamma.c")) as f: return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8")
raw = f.read()
return raw
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
k, x = inp k, x = inp
...@@ -718,9 +710,7 @@ class GammaIncC(BinaryScalarOp): ...@@ -718,9 +710,7 @@ class GammaIncC(BinaryScalarOp):
] ]
def c_support_code(self, **kwargs): def c_support_code(self, **kwargs):
with open(os.path.join(os.path.dirname(__file__), "c_code", "gamma.c")) as f: return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8")
raw = f.read()
return raw
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
k, x = inp k, x = inp
...@@ -1031,9 +1021,7 @@ class GammaU(BinaryScalarOp): ...@@ -1031,9 +1021,7 @@ class GammaU(BinaryScalarOp):
return GammaU.st_impl(k, x) return GammaU.st_impl(k, x)
def c_support_code(self, **kwargs): def c_support_code(self, **kwargs):
with open(os.path.join(os.path.dirname(__file__), "c_code", "gamma.c")) as f: return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8")
raw = f.read()
return raw
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
k, x = inp k, x = inp
...@@ -1069,9 +1057,7 @@ class GammaL(BinaryScalarOp): ...@@ -1069,9 +1057,7 @@ class GammaL(BinaryScalarOp):
return GammaL.st_impl(k, x) return GammaL.st_impl(k, x)
def c_support_code(self, **kwargs): def c_support_code(self, **kwargs):
with open(os.path.join(os.path.dirname(__file__), "c_code", "gamma.c")) as f: return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8")
raw = f.read()
return raw
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
k, x = inp k, x = inp
...@@ -1496,9 +1482,7 @@ class BetaInc(ScalarOp): ...@@ -1496,9 +1482,7 @@ class BetaInc(ScalarOp):
] ]
def c_support_code(self, **kwargs): def c_support_code(self, **kwargs):
with open(os.path.join(os.path.dirname(__file__), "c_code", "incbet.c")) as f: return (C_CODE_PATH / "incbet.c").read_text(encoding="utf-8")
raw = f.read()
return raw
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
(a, b, x) = inp (a, b, x) = inp
......
...@@ -9,7 +9,7 @@ import logging ...@@ -9,7 +9,7 @@ import logging
import os import os
import sys import sys
import textwrap import textwrap
from os.path import dirname from pathlib import Path
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.link.c.cmodule import GCC_compiler from pytensor.link.c.cmodule import GCC_compiler
...@@ -743,20 +743,15 @@ def blas_header_text(): ...@@ -743,20 +743,15 @@ def blas_header_text():
blas_code = "" blas_code = ""
if not config.blas__ldflags: if not config.blas__ldflags:
# Include the Numpy version implementation of [sd]gemm_. # Include the Numpy version implementation of [sd]gemm_.
current_filedir = dirname(__file__) current_filedir = Path(__file__).parent
blas_common_filepath = os.path.join( blas_common_filepath = current_filedir / "c_code/alt_blas_common.h"
current_filedir, "c_code", "alt_blas_common.h" blas_template_filepath = current_filedir / "c_code/alt_blas_template.c"
) try:
blas_template_filepath = os.path.join( common_code = blas_common_filepath.read_text(encoding="utf-8")
current_filedir, "c_code", "alt_blas_template.c" template_code = blas_template_filepath.read_text(encoding="utf-8")
) except OSError as err:
common_code = "" msg = "Unable to load NumPy implementation of BLAS functions from C source files."
sblas_code = "" raise OSError(msg) from err
dblas_code = ""
with open(blas_common_filepath) as code:
common_code = code.read()
with open(blas_template_filepath) as code:
template_code = code.read()
sblas_code = template_code % { sblas_code = template_code % {
"float_type": "float", "float_type": "float",
"float_size": 4, "float_size": 4,
...@@ -769,10 +764,6 @@ def blas_header_text(): ...@@ -769,10 +764,6 @@ def blas_header_text():
"npy_float": "NPY_FLOAT64", "npy_float": "NPY_FLOAT64",
"precision": "d", "precision": "d",
} }
if not (common_code and template_code):
raise OSError(
"Unable to load NumPy implementation of BLAS functions from C source files."
)
blas_code += common_code blas_code += common_code
blas_code += sblas_code blas_code += sblas_code
blas_code += dblas_code blas_code += dblas_code
......
from pathlib import Path
import numpy as np import numpy as np
from pytensor.graph.basic import Apply, Constant from pytensor.graph.basic import Apply, Constant
...@@ -38,8 +40,8 @@ class LoadFromDisk(Op): ...@@ -38,8 +40,8 @@ class LoadFromDisk(Op):
return Apply(self, [path], [tensor(dtype=self.dtype, shape=self.shape)]) return Apply(self, [path], [tensor(dtype=self.dtype, shape=self.shape)])
def perform(self, node, inp, out): def perform(self, node, inp, out):
path = inp[0] path = Path(inp[0])
if path.split(".")[-1] == "npz": if path.suffix != ".npy":
raise ValueError(f"Expected a .npy file, got {path} instead") raise ValueError(f"Expected a .npy file, got {path} instead")
result = np.load(path, mmap_mode=self.mmap_mode) result = np.load(path, mmap_mode=self.mmap_mode)
if result.dtype != self.dtype: if result.dtype != self.dtype:
......
...@@ -174,7 +174,7 @@ def call_subprocess_Popen(command, **params): ...@@ -174,7 +174,7 @@ def call_subprocess_Popen(command, **params):
""" """
if "stdout" in params or "stderr" in params: if "stdout" in params or "stderr" in params:
raise TypeError("don't use stderr or stdout with call_subprocess_Popen") raise TypeError("don't use stderr or stdout with call_subprocess_Popen")
with open(os.devnull, "wb") as null: with Path(os.devnull).open("wb") as null:
# stdin to devnull is a workaround for a crash in a weird Windows # stdin to devnull is a workaround for a crash in a weird Windows
# environment where sys.stdin was None # environment where sys.stdin was None
params.setdefault("stdin", null) params.setdefault("stdin", null)
......
...@@ -11,27 +11,25 @@ python scripts/run_mypy.py [--verbose] ...@@ -11,27 +11,25 @@ python scripts/run_mypy.py [--verbose]
import argparse import argparse
import importlib import importlib
import os
import pathlib
import subprocess import subprocess
import sys import sys
from collections.abc import Iterator from collections.abc import Iterable
from pathlib import Path
import pandas import pandas as pd
DP_ROOT = pathlib.Path(__file__).absolute().parent.parent DP_ROOT = Path(__file__).absolute().parent.parent
FAILING = [ FAILING = [
line.strip() Path(line.strip()).absolute()
for line in (DP_ROOT / "scripts" / "mypy-failing.txt").read_text().splitlines() for line in (DP_ROOT / "scripts" / "mypy-failing.txt").read_text().splitlines()
if line.strip()
] ]
def enforce_pep561(module_name): def enforce_pep561(module_name):
try: try:
module = importlib.import_module(module_name) module = importlib.import_module(module_name)
fp = pathlib.Path(module.__path__[0], "py.typed") fp = Path(module.__path__[0], "py.typed")
if not fp.exists(): if not fp.exists():
fp.touch() fp.touch()
except ModuleNotFoundError: except ModuleNotFoundError:
...@@ -39,13 +37,13 @@ def enforce_pep561(module_name): ...@@ -39,13 +37,13 @@ def enforce_pep561(module_name):
return return
def mypy_to_pandas(input_lines: Iterator[str]) -> pandas.DataFrame: def mypy_to_pandas(input_lines: Iterable[str]) -> pd.DataFrame:
"""Reformats mypy output with error codes to a DataFrame. """Reformats mypy output with error codes to a DataFrame.
Adapted from: https://gist.github.com/michaelosthege/24d0703e5f37850c9e5679f69598930a Adapted from: https://gist.github.com/michaelosthege/24d0703e5f37850c9e5679f69598930a
""" """
current_section = None current_section = None
data = { data: dict[str, list] = {
"file": [], "file": [],
"line": [], "line": [],
"type": [], "type": [],
...@@ -65,7 +63,7 @@ def mypy_to_pandas(input_lines: Iterator[str]) -> pandas.DataFrame: ...@@ -65,7 +63,7 @@ def mypy_to_pandas(input_lines: Iterator[str]) -> pandas.DataFrame:
message = line.replace(f"{file}:{lineno}: {message_type}: ", "").replace( message = line.replace(f"{file}:{lineno}: {message_type}: ", "").replace(
f" [{current_section}]", "" f" [{current_section}]", ""
) )
data["file"].append(file) data["file"].append(Path(file))
data["line"].append(lineno) data["line"].append(lineno)
data["type"].append(message_type) data["type"].append(message_type)
data["errorcode"].append(current_section) data["errorcode"].append(current_section)
...@@ -73,21 +71,18 @@ def mypy_to_pandas(input_lines: Iterator[str]) -> pandas.DataFrame: ...@@ -73,21 +71,18 @@ def mypy_to_pandas(input_lines: Iterator[str]) -> pandas.DataFrame:
except Exception as ex: except Exception as ex:
print(elems) print(elems)
print(ex) print(ex)
return pandas.DataFrame(data=data).set_index(["file", "line"]) return pd.DataFrame(data=data).set_index(["file", "line"])
def check_no_unexpected_results(mypy_lines: Iterator[str]): def check_no_unexpected_results(mypy_lines: Iterable[str]):
"""Compares mypy results with list of known FAILING files. """Compares mypy results with list of known FAILING files.
Exits the process with non-zero exit code upon unexpected results. Exits the process with non-zero exit code upon unexpected results.
""" """
df = mypy_to_pandas(mypy_lines) df = mypy_to_pandas(mypy_lines)
all_files = { all_files = {fp.absolute() for fp in DP_ROOT.glob("pytensor/**/*.py")}
str(fp).replace(str(DP_ROOT), "").strip(os.sep).replace(os.sep, "/") failing = {f.absolute() for f in df.reset_index().file}
for fp in DP_ROOT.glob("pytensor/**/*.py")
}
failing = set(df.reset_index().file.str.replace(os.sep, "/", regex=False))
if not failing.issubset(all_files): if not failing.issubset(all_files):
raise Exception( raise Exception(
"Mypy should have ignored these files:\n" "Mypy should have ignored these files:\n"
...@@ -141,13 +136,10 @@ if __name__ == "__main__": ...@@ -141,13 +136,10 @@ if __name__ == "__main__":
help="How to group verbose output. One of {file|errorcode|message}.", help="How to group verbose output. One of {file|errorcode|message}.",
) )
args, _ = parser.parse_known_args() args, _ = parser.parse_known_args()
missing = list() missing = [path for path in FAILING if not path.exists()]
for path in FAILING:
if not os.path.exists(path):
missing.append(path)
if missing: if missing:
print("These files are missing but still kept in FAILING") print("These files are missing but still kept in FAILING")
print("\n".join(missing)) print(*missing, sep="\n")
sys.exit(1) sys.exit(1)
cp = subprocess.run( cp = subprocess.run(
["mypy", "--show-error-codes", "pytensor"], ["mypy", "--show-error-codes", "pytensor"],
......
import os
import pickle import pickle
import re import re
import shutil import shutil
import tempfile import tempfile
from pathlib import Path
import numpy as np import numpy as np
import pytest import pytest
...@@ -31,10 +31,10 @@ def test_function_dump(): ...@@ -31,10 +31,10 @@ def test_function_dump():
fct1 = function([v], v + 1) fct1 = function([v], v + 1)
try: try:
tmpdir = tempfile.mkdtemp() tmpdir = Path(tempfile.mkdtemp())
fname = os.path.join(tmpdir, "test_function_dump.pkl") fname = tmpdir / "test_function_dump.pkl"
function_dump(fname, [v], v + 1) function_dump(fname, [v], v + 1)
with open(fname, "rb") as f: with fname.open("rb") as f:
l = pickle.load(f) l = pickle.load(f)
finally: finally:
if tmpdir is not None: if tmpdir is not None:
...@@ -49,7 +49,7 @@ def test_function_name(): ...@@ -49,7 +49,7 @@ def test_function_name():
x = vector("x") x = vector("x")
func = function([x], x + 1.0) func = function([x], x + 1.0)
regex = re.compile(os.path.basename(".*test_function.pyc?")) regex = re.compile(f".*{__file__}c?")
assert regex.match(func.name) is not None assert regex.match(func.name) is not None
......
import filecmp import filecmp
import os.path as pt
import tempfile import tempfile
from pathlib import Path
import numpy as np import numpy as np
import pytest import pytest
...@@ -20,15 +20,15 @@ if not pydot_imported: ...@@ -20,15 +20,15 @@ if not pydot_imported:
class TestD3Viz: class TestD3Viz:
def setup_method(self): def setup_method(self):
self.rng = np.random.default_rng(0) self.rng = np.random.default_rng(0)
self.data_dir = pt.join("data", "test_d3viz") self.data_dir = Path("data") / "test_d3viz"
def check(self, f, reference=None, verbose=False): def check(self, f, reference=None, verbose=False):
tmp_dir = tempfile.mkdtemp() tmp_dir = Path(tempfile.mkdtemp())
html_file = pt.join(tmp_dir, "index.html") html_file = tmp_dir / "index.html"
if verbose: if verbose:
print(html_file) print(html_file)
d3v.d3viz(f, html_file) d3v.d3viz(f, html_file)
assert pt.getsize(html_file) > 0 assert html_file.stat().st_size > 0
if reference: if reference:
assert filecmp.cmp(html_file, reference) assert filecmp.cmp(html_file, reference)
......
...@@ -6,10 +6,10 @@ deterministic based on the input type and the op. ...@@ -6,10 +6,10 @@ deterministic based on the input type and the op.
""" """
import multiprocessing import multiprocessing
import os
import re import re
import sys import sys
import tempfile import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import numpy as np import numpy as np
...@@ -194,9 +194,8 @@ def cxx_search_dirs(blas_libs, mock_system): ...@@ -194,9 +194,8 @@ def cxx_search_dirs(blas_libs, mock_system):
with tempfile.TemporaryDirectory() as d: with tempfile.TemporaryDirectory() as d:
flags = None flags = None
for lib in blas_libs: for lib in blas_libs:
lib_path = os.path.join(d, libtemplate.format(lib=lib)) lib_path = Path(d) / libtemplate.format(lib=lib)
with open(lib_path, "wb") as f: lib_path.write_bytes(b"1")
f.write(b"1")
libraries.append(lib_path) libraries.append(lib_path)
if flags is None: if flags is None:
flags = f"-l{lib}" flags = f"-l{lib}"
...@@ -266,13 +265,12 @@ def windows_conda_libs(blas_libs): ...@@ -266,13 +265,12 @@ def windows_conda_libs(blas_libs):
libtemplate = "{lib}.dll" libtemplate = "{lib}.dll"
libraries = [] libraries = []
with tempfile.TemporaryDirectory() as d: with tempfile.TemporaryDirectory() as d:
subdir = os.path.join(d, "Library", "bin") subdir = Path(d) / "Library" / "bin"
os.makedirs(subdir, exist_ok=True) subdir.mkdir(exist_ok=True, parents=True)
flags = f'-L"{subdir}"' flags = f'-L"{subdir}"'
for lib in blas_libs: for lib in blas_libs:
lib_path = os.path.join(subdir, libtemplate.format(lib=lib)) lib_path = subdir / libtemplate.format(lib=lib)
with open(lib_path, "wb") as f: lib_path.write_bytes(b"1")
f.write(b"1")
libraries.append(lib_path) libraries.append(lib_path)
flags += f" -l{lib}" flags += f" -l{lib}"
if "gomp" in blas_libs and "mkl_gnu_thread" not in blas_libs: if "gomp" in blas_libs and "mkl_gnu_thread" not in blas_libs:
...@@ -311,14 +309,14 @@ def test_default_blas_ldflags_conda_windows( ...@@ -311,14 +309,14 @@ def test_default_blas_ldflags_conda_windows(
) )
@patch("sys.platform", "win32") @patch("sys.platform", "win32")
def test_patch_ldflags(listdir_mock): def test_patch_ldflags(listdir_mock):
mkl_path = "some_path" mkl_path = Path("some_path")
flag_list = ["-lm", "-lopenblas", f"-L {mkl_path}", "-l mkl_core", "-lmkl_rt"] flag_list = ["-lm", "-lopenblas", f"-L {mkl_path}", "-l mkl_core", "-lmkl_rt"]
assert GCC_compiler.patch_ldflags(flag_list) == [ assert GCC_compiler.patch_ldflags(flag_list) == [
"-lm", "-lm",
"-lopenblas", "-lopenblas",
f"-L {mkl_path}", f"-L {mkl_path}",
'"' + os.path.join(mkl_path, "mkl_core.1.dll") + '"', '"' + str(mkl_path / "mkl_core.1.dll") + '"',
'"' + os.path.join(mkl_path, "mkl_rt.1.0.dll") + '"', '"' + str(mkl_path / "mkl_rt.1.0.dll") + '"',
] ]
...@@ -341,8 +339,8 @@ def test_linking_patch(listdir_mock, platform): ...@@ -341,8 +339,8 @@ def test_linking_patch(listdir_mock, platform):
assert GCC_compiler.linking_patch(lib_dirs, libs) == [ assert GCC_compiler.linking_patch(lib_dirs, libs) == [
"-lopenblas", "-lopenblas",
"-lm", "-lm",
'"' + os.path.join(lib_dirs[0].strip('"'), "mkl_core.1.dll") + '"', '"' + str(Path(lib_dirs[0].strip('"')) / "mkl_core.1.dll") + '"',
'"' + os.path.join(lib_dirs[0].strip('"'), "mkl_rt.1.1.dll") + '"', '"' + str(Path(lib_dirs[0].strip('"')) / "mkl_rt.1.1.dll") + '"',
] ]
else: else:
GCC_compiler.linking_patch(lib_dirs, libs) == [ GCC_compiler.linking_patch(lib_dirs, libs) == [
......
...@@ -218,7 +218,6 @@ def test_ExternalCOp_c_code_cache_version(): ...@@ -218,7 +218,6 @@ def test_ExternalCOp_c_code_cache_version():
with tempfile.NamedTemporaryFile(dir=".", suffix=".py") as tmp: with tempfile.NamedTemporaryFile(dir=".", suffix=".py") as tmp:
tmp.write(externalcop_test_code.encode()) tmp.write(externalcop_test_code.encode())
tmp.seek(0) tmp.seek(0)
# modname = os.path.splitext(tmp.name)[0]
modname = tmp.name modname = tmp.name
out_1, err1, returncode1 = get_hash(modname, seed=428) out_1, err1, returncode1 = get_hash(modname, seed=428)
out_2, err2, returncode2 = get_hash(modname, seed=3849) out_2, err2, returncode2 = get_hash(modname, seed=3849)
......
import os from pathlib import Path
import numpy as np import numpy as np
import pytest import pytest
...@@ -175,7 +175,7 @@ class MyOpCEnumType(COp): ...@@ -175,7 +175,7 @@ class MyOpCEnumType(COp):
) )
def c_header_dirs(self, **kwargs): def c_header_dirs(self, **kwargs):
return [os.path.join(os.path.dirname(__file__), "c_code")] return [Path(__file__).parent / "c_code"]
def c_headers(self, **kwargs): def c_headers(self, **kwargs):
return ["test_cenum.h"] return ["test_cenum.h"]
......
import logging import logging
import os from pathlib import Path
import pytensor import pytensor
from pytensor.configdefaults import config from pytensor.configdefaults import config
...@@ -19,6 +19,9 @@ from pytensor.tensor.type import TensorType ...@@ -19,6 +19,9 @@ from pytensor.tensor.type import TensorType
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
C_CODE_PATH = Path(__file__).parent / "c_code"
class BaseCorr3dMM(OpenMPOp, _NoPythonOp): class BaseCorr3dMM(OpenMPOp, _NoPythonOp):
""" """
Base class for `Corr3dMM`, `Corr3dMM_gradWeights` and Base class for `Corr3dMM`, `Corr3dMM_gradWeights` and
...@@ -245,14 +248,7 @@ class BaseCorr3dMM(OpenMPOp, _NoPythonOp): ...@@ -245,14 +248,7 @@ class BaseCorr3dMM(OpenMPOp, _NoPythonOp):
sub["blas_set_num_threads"] = "" sub["blas_set_num_threads"] = ""
sub["blas_get_num_threads"] = "0" sub["blas_get_num_threads"] = "0"
final_code = "" final_code = Path(C_CODE_PATH / "corr3d_gemm.c").read_text("utf-8")
with open(
os.path.join(
os.path.split(__file__)[0], os.path.join("c_code", "corr3d_gemm.c")
)
) as f:
code = f.read()
final_code += code
return final_code % sub return final_code % sub
def c_code_helper( def c_code_helper(
......
import logging import logging
import os from pathlib import Path
import pytensor import pytensor
from pytensor.configdefaults import config from pytensor.configdefaults import config
...@@ -18,6 +18,8 @@ from pytensor.tensor.type import TensorType ...@@ -18,6 +18,8 @@ from pytensor.tensor.type import TensorType
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
C_CODE_PATH = Path(__file__).parent / "c_code"
class BaseCorrMM(OpenMPOp, _NoPythonOp): class BaseCorrMM(OpenMPOp, _NoPythonOp):
""" """
...@@ -263,14 +265,7 @@ class BaseCorrMM(OpenMPOp, _NoPythonOp): ...@@ -263,14 +265,7 @@ class BaseCorrMM(OpenMPOp, _NoPythonOp):
sub["blas_set_num_threads"] = "" sub["blas_set_num_threads"] = ""
sub["blas_get_num_threads"] = "0" sub["blas_get_num_threads"] = "0"
final_code = "" final_code = (C_CODE_PATH / "corr_gemm.c").read_text("utf-8")
with open(
os.path.join(
os.path.split(__file__)[0], os.path.join("c_code", "corr_gemm.c")
)
) as f:
code = f.read()
final_code += code
return final_code % sub return final_code % sub
def c_code_helper(self, bottom, weights, top, sub, height=None, width=None): def c_code_helper(self, bottom, weights, top, sub, height=None, width=None):
......
import os
import numpy as np import numpy as np
import pytest import pytest
...@@ -13,7 +11,7 @@ from pytensor.tensor.io import load ...@@ -13,7 +11,7 @@ from pytensor.tensor.io import load
class TestLoadTensor: class TestLoadTensor:
def setup_method(self): def setup_method(self):
self.data = np.arange(5, dtype=np.int32) self.data = np.arange(5, dtype=np.int32)
self.filename = os.path.join(pytensor.config.compiledir, "_test.npy") self.filename = pytensor.config.compiledir / "_test.npy"
np.save(self.filename, self.data) np.save(self.filename, self.data)
def test_basic(self): def test_basic(self):
...@@ -54,4 +52,4 @@ class TestLoadTensor: ...@@ -54,4 +52,4 @@ class TestLoadTensor:
assert type(fn(self.filename)) == np.core.memmap assert type(fn(self.filename)) == np.core.memmap
def teardown_method(self): def teardown_method(self):
os.remove(os.path.join(pytensor.config.compiledir, "_test.npy")) (pytensor.config.compiledir / "_test.npy").unlink()
import os.path as path from pathlib import Path
from tempfile import mkdtemp from tempfile import mkdtemp
import numpy as np import numpy as np
...@@ -187,7 +187,7 @@ def test_filter_memmap(): ...@@ -187,7 +187,7 @@ def test_filter_memmap():
r"""Make sure `TensorType.filter` can handle NumPy `memmap`\s subclasses.""" r"""Make sure `TensorType.filter` can handle NumPy `memmap`\s subclasses."""
data = np.arange(12, dtype=config.floatX) data = np.arange(12, dtype=config.floatX)
data.resize((3, 4)) data.resize((3, 4))
filename = path.join(mkdtemp(), "newfile.dat") filename = Path(mkdtemp()) / "newfile.dat"
fp = np.memmap(filename, dtype=config.floatX, mode="w+", shape=(3, 4)) fp = np.memmap(filename, dtype=config.floatX, mode="w+", shape=(3, 4))
test_type = TensorType(config.floatX, shape=(None, None)) test_type = TensorType(config.floatX, shape=(None, None))
......
...@@ -138,7 +138,7 @@ class TestConfigTypes: ...@@ -138,7 +138,7 @@ class TestConfigTypes:
cp._apply("gpu123") cp._apply("gpu123")
with pytest.raises(ValueError, match='Valid options start with one of "cpu".'): with pytest.raises(ValueError, match='Valid options start with one of "cpu".'):
cp._apply("notadevice") cp._apply("notadevice")
assert str(cp) == "None (cpu)" assert str(cp) == "unnamed (cpu)"
def test_config_context(): def test_config_context():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论