提交 6a295b9f authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Virgile Andreani

Remove unused stuff and type pytensor/utils.py

上级 b083fb91
...@@ -426,7 +426,7 @@ def is_same_entry(entry_1, entry_2): ...@@ -426,7 +426,7 @@ def is_same_entry(entry_1, entry_2):
return False return False
def get_module_hash(src_code, key): def get_module_hash(src_code: str, key) -> str:
""" """
Return a SHA256 hash that uniquely identifies a module. Return a SHA256 hash that uniquely identifies a module.
...@@ -466,13 +466,13 @@ def get_module_hash(src_code, key): ...@@ -466,13 +466,13 @@ def get_module_hash(src_code, key):
if isinstance(key_element, tuple): if isinstance(key_element, tuple):
# This should be the C++ compilation command line parameters or the # This should be the C++ compilation command line parameters or the
# libraries to link against. # libraries to link against.
to_hash += list(key_element) to_hash += [str(e) for e in key_element]
elif isinstance(key_element, str): elif isinstance(key_element, str):
if key_element.startswith("md5:") or key_element.startswith("hash:"): if key_element.startswith("md5:") or key_element.startswith("hash:"):
# This is actually a sha256 hash of the config options. # This is actually a sha256 hash of the config options.
# Currently, we still keep md5 to don't break old PyTensor. # Currently, we still keep md5 to don't break old PyTensor.
# We add 'hash:' so that when we change it in # We add 'hash:' so that when we change it in
# the futur, it won't break this version of PyTensor. # the future, it won't break this version of PyTensor.
break break
elif key_element.startswith("NPY_ABI_VERSION=0x") or key_element.startswith( elif key_element.startswith("NPY_ABI_VERSION=0x") or key_element.startswith(
"c_compiler_str=" "c_compiler_str="
......
...@@ -75,6 +75,7 @@ Optimizations associated with these BLAS Ops are in tensor.rewriting.blas ...@@ -75,6 +75,7 @@ Optimizations associated with these BLAS Ops are in tensor.rewriting.blas
""" """
import functools
import logging import logging
import os import os
import time import time
...@@ -104,7 +105,6 @@ from pytensor.tensor.elemwise import DimShuffle ...@@ -104,7 +105,6 @@ from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import add, mul, neg, sub from pytensor.tensor.math import add, mul, neg, sub
from pytensor.tensor.shape import shape_padright, specify_broadcastable from pytensor.tensor.shape import shape_padright, specify_broadcastable
from pytensor.tensor.type import DenseTensorType, TensorType, integer_dtypes, tensor from pytensor.tensor.type import DenseTensorType, TensorType, integer_dtypes, tensor
from pytensor.utils import memoize
_logger = logging.getLogger("pytensor.tensor.blas") _logger = logging.getLogger("pytensor.tensor.blas")
...@@ -365,8 +365,10 @@ def ldflags(libs=True, flags=False, libs_dir=False, include_dir=False): ...@@ -365,8 +365,10 @@ def ldflags(libs=True, flags=False, libs_dir=False, include_dir=False):
) )
@memoize @functools.cache
def _ldflags(ldflags_str, libs, flags, libs_dir, include_dir): def _ldflags(
ldflags_str: str, libs: bool, flags: bool, libs_dir: bool, include_dir: bool
) -> list[str]:
"""Extract list of compilation flags from a string. """Extract list of compilation flags from a string.
Depending on the options, different type of flags will be kept. Depending on the options, different type of flags will be kept.
...@@ -422,7 +424,7 @@ def _ldflags(ldflags_str, libs, flags, libs_dir, include_dir): ...@@ -422,7 +424,7 @@ def _ldflags(ldflags_str, libs, flags, libs_dir, include_dir):
t = t[1:-1] t = t[1:-1]
try: try:
t0, t1, t2 = t[0:3] t0, t1 = t[0], t[1]
assert t0 == "-" assert t0 == "-"
except Exception: except Exception:
raise ValueError(f'invalid token "{t}" in ldflags_str: "{ldflags_str}"') raise ValueError(f'invalid token "{t}" in ldflags_str: "{ldflags_str}"')
...@@ -435,7 +437,6 @@ def _ldflags(ldflags_str, libs, flags, libs_dir, include_dir): ...@@ -435,7 +437,6 @@ def _ldflags(ldflags_str, libs, flags, libs_dir, include_dir):
" is not wanted.", " is not wanted.",
t, t,
) )
rval.append(t[2:])
elif libs and t1 == "l": # example -lmkl elif libs and t1 == "l": # example -lmkl
rval.append(t[2:]) rval.append(t[2:])
elif flags and t1 not in ("L", "I", "l"): # example -openmp elif flags and t1 not in ("L", "I", "l"): # example -openmp
......
...@@ -6,7 +6,9 @@ import os ...@@ -6,7 +6,9 @@ import os
import struct import struct
import subprocess import subprocess
import sys import sys
from collections.abc import Iterable, Sequence
from functools import partial from functools import partial
from pathlib import Path
__all__ = [ __all__ = [
...@@ -85,18 +87,6 @@ def add_excepthook(hook): ...@@ -85,18 +87,6 @@ def add_excepthook(hook):
sys.excepthook = __call_excepthooks sys.excepthook = __call_excepthooks
def exc_message(e):
"""
In python 3.x, when an exception is reraised it saves original
exception in its args, therefore in order to find the actual
message, we need to unpack arguments recursively.
"""
msg = e.args[0]
if isinstance(msg, Exception):
return exc_message(msg)
return msg
def get_unbound_function(unbound): def get_unbound_function(unbound):
# Op.make_thunk isn't bound, so don't have a __func__ attr. # Op.make_thunk isn't bound, so don't have a __func__ attr.
# But bound method, have a __func__ method that point to the # But bound method, have a __func__ method that point to the
...@@ -106,8 +96,9 @@ def get_unbound_function(unbound): ...@@ -106,8 +96,9 @@ def get_unbound_function(unbound):
return unbound return unbound
def maybe_add_to_os_environ_pathlist(var, newpath): def maybe_add_to_os_environ_pathlist(var: str, newpath: Path | str) -> None:
"""Unfortunately, Conda offers to make itself the default Python """
Unfortunately, Conda offers to make itself the default Python
and those who use it that way will probably not activate envs and those who use it that way will probably not activate envs
correctly meaning e.g. mingw-w64 g++ may not be on their PATH. correctly meaning e.g. mingw-w64 g++ may not be on their PATH.
...@@ -118,18 +109,18 @@ def maybe_add_to_os_environ_pathlist(var, newpath): ...@@ -118,18 +109,18 @@ def maybe_add_to_os_environ_pathlist(var, newpath):
The reason we check first is because Windows environment vars The reason we check first is because Windows environment vars
are limited to 8191 characters and it is easy to hit that. are limited to 8191 characters and it is easy to hit that.
`var` will typically be 'PATH'.""" `var` will typically be 'PATH'.
"""
import os if not Path(newpath).is_absolute():
return
if os.path.isabs(newpath): try:
try: oldpaths = os.environ[var].split(os.pathsep)
oldpaths = os.environ[var].split(os.pathsep) if str(newpath) not in oldpaths:
if newpath not in oldpaths: newpaths = os.pathsep.join([str(newpath), *oldpaths])
newpaths = os.pathsep.join([newpath, *oldpaths]) os.environ[var] = newpaths
os.environ[var] = newpaths except Exception:
except Exception: pass
pass
def subprocess_Popen(command, **params): def subprocess_Popen(command, **params):
...@@ -210,7 +201,7 @@ def output_subprocess_Popen(command, **params): ...@@ -210,7 +201,7 @@ def output_subprocess_Popen(command, **params):
return (*out, p.returncode) return (*out, p.returncode)
def hash_from_code(msg): def hash_from_code(msg: str | bytes) -> str:
"""Return the SHA256 hash of a string or bytes.""" """Return the SHA256 hash of a string or bytes."""
# hashlib.sha256() requires an object that supports buffer interface, # hashlib.sha256() requires an object that supports buffer interface,
# but Python 3 (unicode) strings don't. # but Python 3 (unicode) strings don't.
...@@ -221,27 +212,7 @@ def hash_from_code(msg): ...@@ -221,27 +212,7 @@ def hash_from_code(msg):
return "m" + hashlib.sha256(msg).hexdigest() return "m" + hashlib.sha256(msg).hexdigest()
def memoize(f): def uniq(seq: Sequence) -> list:
"""
Cache the return value for each tuple of arguments (which must be hashable).
"""
cache = {}
def rval(*args, **kwargs):
kwtup = tuple(kwargs.items())
key = (args, kwtup)
if key not in cache:
val = f(*args, **kwargs)
cache[key] = val
else:
val = cache[key]
return val
return rval
def uniq(seq):
""" """
Do not use set, this must always return the same value at the same index. Do not use set, this must always return the same value at the same index.
If we just exchange other values, but keep the same pattern of duplication, If we just exchange other values, but keep the same pattern of duplication,
...@@ -253,11 +224,12 @@ def uniq(seq): ...@@ -253,11 +224,12 @@ def uniq(seq):
return [x for i, x in enumerate(seq) if seq.index(x) == i] return [x for i, x in enumerate(seq) if seq.index(x) == i]
def difference(seq1, seq2): def difference(seq1: Iterable, seq2: Iterable):
r""" r"""
Returns all elements in seq1 which are not in seq2: i.e ``seq1\seq2``. Returns all elements in seq1 which are not in seq2: i.e ``seq1\seq2``.
""" """
seq2 = list(seq2)
try: try:
# try to use O(const * len(seq1)) algo # try to use O(const * len(seq1)) algo
if len(seq2) < 4: # I'm guessing this threshold -JB if len(seq2) < 4: # I'm guessing this threshold -JB
...@@ -285,7 +257,7 @@ def from_return_values(values): ...@@ -285,7 +257,7 @@ def from_return_values(values):
return [values] return [values]
def flatten(a): def flatten(a) -> list:
""" """
Recursively flatten tuple, list and set in a list. Recursively flatten tuple, list and set in a list.
......
...@@ -31,7 +31,6 @@ from pytensor.tensor.type import ( ...@@ -31,7 +31,6 @@ from pytensor.tensor.type import (
scalars, scalars,
vector, vector,
) )
from pytensor.utils import exc_message
def PatternOptimizer(p1, p2, ign=True): def PatternOptimizer(p1, p2, ign=True):
...@@ -1182,6 +1181,17 @@ class TestPicklefunction: ...@@ -1182,6 +1181,17 @@ class TestPicklefunction:
def pers_load(id): def pers_load(id):
return saves[id] return saves[id]
def exc_message(e):
"""
In Python 3, when an exception is reraised it saves the original
exception in its args, therefore in order to find the actual
message, we need to unpack arguments recursively.
"""
msg = e.args[0]
if isinstance(msg, Exception):
return exc_message(msg)
return msg
b = np.random.random((5, 4)) b = np.random.random((5, 4))
x = matrix() x = matrix()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论