提交 d0e35832 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix typing issues in aesara.link.c.cmodule

上级 41c10974
...@@ -19,9 +19,10 @@ import textwrap ...@@ -19,9 +19,10 @@ import textwrap
import time import time
import warnings import warnings
from io import BytesIO, StringIO from io import BytesIO, StringIO
from typing import Dict, List, Set from typing import Callable, Dict, List, Optional, Set, Tuple, cast
import numpy.distutils import numpy.distutils
from typing_extensions import Protocol
# we will abuse the lockfile mechanism when reading and writing the registry # we will abuse the lockfile mechanism when reading and writing the registry
from aesara.compile.compilelock import lock_ctx from aesara.compile.compilelock import lock_ctx
...@@ -39,6 +40,26 @@ from aesara.utils import ( ...@@ -39,6 +40,26 @@ from aesara.utils import (
) )
class StdLibDirsAndLibsType(Protocol):
data: Optional[Tuple[List[str], ...]]
__call__: Callable[[], Optional[Tuple[List[str], ...]]]
def is_StdLibDirsAndLibsType(
fn: Callable[[], Optional[Tuple[List[str], ...]]]
) -> StdLibDirsAndLibsType:
return cast(StdLibDirsAndLibsType, fn)
class GCCLLVMType(Protocol):
is_llvm: Optional[bool]
__call__: Callable[[], Optional[bool]]
def is_GCCLLVMType(fn: Callable[[], Optional[bool]]) -> GCCLLVMType:
return cast(GCCLLVMType, fn)
_logger = logging.getLogger("aesara.link.c.cmodule") _logger = logging.getLogger("aesara.link.c.cmodule")
METH_VARARGS = "METH_VARARGS" METH_VARARGS = "METH_VARARGS"
...@@ -1649,7 +1670,8 @@ def std_include_dirs(): ...@@ -1649,7 +1670,8 @@ def std_include_dirs():
return numpy_inc_dirs + python_inc_dirs + [gof_inc_dir] return numpy_inc_dirs + python_inc_dirs + [gof_inc_dir]
def std_lib_dirs_and_libs(): @is_StdLibDirsAndLibsType
def std_lib_dirs_and_libs() -> Optional[Tuple[List[str], ...]]:
# We cache the results as on Windows, this trigger file access and # We cache the results as on Windows, this trigger file access and
# this method is called many times. # this method is called many times.
if std_lib_dirs_and_libs.data is not None: if std_lib_dirs_and_libs.data is not None:
...@@ -1730,7 +1752,7 @@ def std_lib_dirs_and_libs(): ...@@ -1730,7 +1752,7 @@ def std_lib_dirs_and_libs():
# get the name of the python library (shared object) # get the name of the python library (shared object)
libname = distutils.sysconfig.get_config_var("LDLIBRARY") libname = str(distutils.sysconfig.get_config_var("LDLIBRARY"))
if libname.startswith("lib"): if libname.startswith("lib"):
libname = libname[3:] libname = libname[3:]
...@@ -1741,7 +1763,7 @@ def std_lib_dirs_and_libs(): ...@@ -1741,7 +1763,7 @@ def std_lib_dirs_and_libs():
elif libname.endswith(".a"): elif libname.endswith(".a"):
libname = libname[:-2] libname = libname[:-2]
libdir = distutils.sysconfig.get_config_var("LIBDIR") libdir = str(distutils.sysconfig.get_config_var("LIBDIR"))
std_lib_dirs_and_libs.data = [libname], [libdir] std_lib_dirs_and_libs.data = [libname], [libdir]
...@@ -1749,7 +1771,9 @@ def std_lib_dirs_and_libs(): ...@@ -1749,7 +1771,9 @@ def std_lib_dirs_and_libs():
# explicitly where it is located this returns # explicitly where it is located this returns
# somepath/lib/python2.x # somepath/lib/python2.x
python_lib = distutils.sysconfig.get_python_lib(plat_specific=1, standard_lib=1) python_lib = str(
distutils.sysconfig.get_python_lib(plat_specific=True, standard_lib=True)
)
python_lib = os.path.dirname(python_lib) python_lib = os.path.dirname(python_lib)
if python_lib not in std_lib_dirs_and_libs.data[1]: if python_lib not in std_lib_dirs_and_libs.data[1]:
std_lib_dirs_and_libs.data[1].append(python_lib) std_lib_dirs_and_libs.data[1].append(python_lib)
...@@ -1771,7 +1795,8 @@ def gcc_version(): ...@@ -1771,7 +1795,8 @@ def gcc_version():
return gcc_version_str return gcc_version_str
def gcc_llvm(): @is_GCCLLVMType
def gcc_llvm() -> Optional[bool]:
""" """
Detect if the g++ version used is the llvm one or not. Detect if the g++ version used is the llvm one or not.
......
...@@ -115,10 +115,6 @@ check_untyped_defs = False ...@@ -115,10 +115,6 @@ check_untyped_defs = False
ignore_errors = True ignore_errors = True
check_untyped_defs = False check_untyped_defs = False
[mypy-aesara.link.c.cmodule]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.link.c.cvm] [mypy-aesara.link.c.cvm]
ignore_errors = True ignore_errors = True
check_untyped_defs = False check_untyped_defs = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论