提交 6a295b9f authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Virgile Andreani

Remove unused stuff and type pytensor/utils.py

上级 b083fb91
......@@ -426,7 +426,7 @@ def is_same_entry(entry_1, entry_2):
return False
def get_module_hash(src_code, key):
def get_module_hash(src_code: str, key) -> str:
"""
Return a SHA256 hash that uniquely identifies a module.
......@@ -466,13 +466,13 @@ def get_module_hash(src_code, key):
if isinstance(key_element, tuple):
# This should be the C++ compilation command line parameters or the
# libraries to link against.
to_hash += list(key_element)
to_hash += [str(e) for e in key_element]
elif isinstance(key_element, str):
if key_element.startswith("md5:") or key_element.startswith("hash:"):
# This is actually a sha256 hash of the config options.
# Currently, we still keep md5 to don't break old PyTensor.
# We add 'hash:' so that when we change it in
# the futur, it won't break this version of PyTensor.
# the future, it won't break this version of PyTensor.
break
elif key_element.startswith("NPY_ABI_VERSION=0x") or key_element.startswith(
"c_compiler_str="
......
......@@ -75,6 +75,7 @@ Optimizations associated with these BLAS Ops are in tensor.rewriting.blas
"""
import functools
import logging
import os
import time
......@@ -104,7 +105,6 @@ from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import add, mul, neg, sub
from pytensor.tensor.shape import shape_padright, specify_broadcastable
from pytensor.tensor.type import DenseTensorType, TensorType, integer_dtypes, tensor
from pytensor.utils import memoize
_logger = logging.getLogger("pytensor.tensor.blas")
......@@ -365,8 +365,10 @@ def ldflags(libs=True, flags=False, libs_dir=False, include_dir=False):
)
@memoize
def _ldflags(ldflags_str, libs, flags, libs_dir, include_dir):
@functools.cache
def _ldflags(
ldflags_str: str, libs: bool, flags: bool, libs_dir: bool, include_dir: bool
) -> list[str]:
"""Extract list of compilation flags from a string.
Depending on the options, different type of flags will be kept.
......@@ -422,7 +424,7 @@ def _ldflags(ldflags_str, libs, flags, libs_dir, include_dir):
t = t[1:-1]
try:
t0, t1, t2 = t[0:3]
t0, t1 = t[0], t[1]
assert t0 == "-"
except Exception:
raise ValueError(f'invalid token "{t}" in ldflags_str: "{ldflags_str}"')
......@@ -435,7 +437,6 @@ def _ldflags(ldflags_str, libs, flags, libs_dir, include_dir):
" is not wanted.",
t,
)
rval.append(t[2:])
elif libs and t1 == "l": # example -lmkl
rval.append(t[2:])
elif flags and t1 not in ("L", "I", "l"): # example -openmp
......
......@@ -6,7 +6,9 @@ import os
import struct
import subprocess
import sys
from collections.abc import Iterable, Sequence
from functools import partial
from pathlib import Path
__all__ = [
......@@ -85,18 +87,6 @@ def add_excepthook(hook):
sys.excepthook = __call_excepthooks
def exc_message(e):
"""
In python 3.x, when an exception is reraised it saves original
exception in its args, therefore in order to find the actual
message, we need to unpack arguments recursively.
"""
msg = e.args[0]
if isinstance(msg, Exception):
return exc_message(msg)
return msg
def get_unbound_function(unbound):
# Op.make_thunk isn't bound, so don't have a __func__ attr.
# But bound method, have a __func__ method that point to the
......@@ -106,8 +96,9 @@ def get_unbound_function(unbound):
return unbound
def maybe_add_to_os_environ_pathlist(var, newpath):
"""Unfortunately, Conda offers to make itself the default Python
def maybe_add_to_os_environ_pathlist(var: str, newpath: Path | str) -> None:
"""
Unfortunately, Conda offers to make itself the default Python
and those who use it that way will probably not activate envs
correctly meaning e.g. mingw-w64 g++ may not be on their PATH.
......@@ -118,18 +109,18 @@ def maybe_add_to_os_environ_pathlist(var, newpath):
The reason we check first is because Windows environment vars
are limited to 8191 characters and it is easy to hit that.
`var` will typically be 'PATH'."""
import os
`var` will typically be 'PATH'.
"""
if not Path(newpath).is_absolute():
return
if os.path.isabs(newpath):
try:
oldpaths = os.environ[var].split(os.pathsep)
if newpath not in oldpaths:
newpaths = os.pathsep.join([newpath, *oldpaths])
os.environ[var] = newpaths
except Exception:
pass
try:
oldpaths = os.environ[var].split(os.pathsep)
if str(newpath) not in oldpaths:
newpaths = os.pathsep.join([str(newpath), *oldpaths])
os.environ[var] = newpaths
except Exception:
pass
def subprocess_Popen(command, **params):
......@@ -210,7 +201,7 @@ def output_subprocess_Popen(command, **params):
return (*out, p.returncode)
def hash_from_code(msg):
def hash_from_code(msg: str | bytes) -> str:
"""Return the SHA256 hash of a string or bytes."""
# hashlib.sha256() requires an object that supports buffer interface,
# but Python 3 (unicode) strings don't.
......@@ -221,27 +212,7 @@ def hash_from_code(msg):
return "m" + hashlib.sha256(msg).hexdigest()
def memoize(f):
"""
Cache the return value for each tuple of arguments (which must be hashable).
"""
cache = {}
def rval(*args, **kwargs):
kwtup = tuple(kwargs.items())
key = (args, kwtup)
if key not in cache:
val = f(*args, **kwargs)
cache[key] = val
else:
val = cache[key]
return val
return rval
def uniq(seq):
def uniq(seq: Sequence) -> list:
"""
Do not use set, this must always return the same value at the same index.
If we just exchange other values, but keep the same pattern of duplication,
......@@ -253,11 +224,12 @@ def uniq(seq):
return [x for i, x in enumerate(seq) if seq.index(x) == i]
def difference(seq1, seq2):
def difference(seq1: Iterable, seq2: Iterable):
r"""
Returns all elements in seq1 which are not in seq2: i.e ``seq1\seq2``.
"""
seq2 = list(seq2)
try:
# try to use O(const * len(seq1)) algo
if len(seq2) < 4: # I'm guessing this threshold -JB
......@@ -285,7 +257,7 @@ def from_return_values(values):
return [values]
def flatten(a):
def flatten(a) -> list:
"""
Recursively flatten tuple, list and set in a list.
......
......@@ -31,7 +31,6 @@ from pytensor.tensor.type import (
scalars,
vector,
)
from pytensor.utils import exc_message
def PatternOptimizer(p1, p2, ign=True):
......@@ -1182,6 +1181,17 @@ class TestPicklefunction:
def pers_load(id):
return saves[id]
def exc_message(e):
"""
In Python 3, when an exception is reraised it saves the original
exception in its args, therefore in order to find the actual
message, we need to unpack arguments recursively.
"""
msg = e.args[0]
if isinstance(msg, Exception):
return exc_message(msg)
return msg
b = np.random.random((5, 4))
x = matrix()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论