提交 3409264d authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Lazy setuptools imports

上级 79eee675
...@@ -26,19 +26,12 @@ from pathlib import Path ...@@ -26,19 +26,12 @@ from pathlib import Path
from typing import TYPE_CHECKING, Protocol, cast from typing import TYPE_CHECKING, Protocol, cast
import numpy as np import numpy as np
from setuptools._distutils.sysconfig import (
get_config_h_filename,
get_config_var,
get_python_inc,
get_python_lib,
)
# we will abuse the lockfile mechanism when reading and writing the registry # we will abuse the lockfile mechanism when reading and writing the registry
from pytensor.compile.compilelock import lock_ctx from pytensor.compile.compilelock import lock_ctx
from pytensor.configdefaults import config, gcc_version_str from pytensor.configdefaults import config, gcc_version_str
from pytensor.configparser import BoolParam, StrParam from pytensor.configparser import BoolParam, StrParam
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.link.c.exceptions import CompileError, MissingGXX
from pytensor.utils import ( from pytensor.utils import (
LOCAL_BITWIDTH, LOCAL_BITWIDTH,
flatten, flatten,
...@@ -266,6 +259,8 @@ class DynamicModule: ...@@ -266,6 +259,8 @@ class DynamicModule:
def _get_ext_suffix(): def _get_ext_suffix():
"""Get the suffix for compiled extensions""" """Get the suffix for compiled extensions"""
from setuptools._distutils.sysconfig import get_config_var
dist_suffix = get_config_var("EXT_SUFFIX") dist_suffix = get_config_var("EXT_SUFFIX")
if dist_suffix is None: if dist_suffix is None:
dist_suffix = get_config_var("SO") dist_suffix = get_config_var("SO")
...@@ -1697,6 +1692,8 @@ def get_gcc_shared_library_arg(): ...@@ -1697,6 +1692,8 @@ def get_gcc_shared_library_arg():
def std_include_dirs(): def std_include_dirs():
from setuptools._distutils.sysconfig import get_python_inc
numpy_inc_dirs = [np.get_include()] numpy_inc_dirs = [np.get_include()]
py_inc = get_python_inc() py_inc = get_python_inc()
py_plat_spec_inc = get_python_inc(plat_specific=True) py_plat_spec_inc = get_python_inc(plat_specific=True)
...@@ -1709,6 +1706,12 @@ def std_include_dirs(): ...@@ -1709,6 +1706,12 @@ def std_include_dirs():
@is_StdLibDirsAndLibsType @is_StdLibDirsAndLibsType
def std_lib_dirs_and_libs() -> tuple[list[str], ...] | None: def std_lib_dirs_and_libs() -> tuple[list[str], ...] | None:
from setuptools._distutils.sysconfig import (
get_config_var,
get_python_inc,
get_python_lib,
)
# We cache the results as on Windows, this trigger file access and # We cache the results as on Windows, this trigger file access and
# this method is called many times. # this method is called many times.
if std_lib_dirs_and_libs.data is not None: if std_lib_dirs_and_libs.data is not None:
...@@ -2388,23 +2391,6 @@ class GCC_compiler(Compiler): ...@@ -2388,23 +2391,6 @@ class GCC_compiler(Compiler):
# xcode's version. # xcode's version.
cxxflags.append("-ld64") cxxflags.append("-ld64")
if sys.platform == "win32":
# Workaround for https://github.com/Theano/Theano/issues/4926.
# https://github.com/python/cpython/pull/11283/ removed the "hypot"
# redefinition for recent CPython versions (>=2.7.16 and >=3.7.3).
# The following nullifies that redefinition, if it is found.
python_version = sys.version_info[:3]
if (3,) <= python_version < (3, 7, 3):
config_h_filename = get_config_h_filename()
try:
with open(config_h_filename) as config_h:
if any(
line.startswith("#define hypot _hypot") for line in config_h
):
cxxflags.append("-D_hypot=hypot")
except OSError:
pass
return cxxflags return cxxflags
@classmethod @classmethod
...@@ -2555,8 +2541,9 @@ class GCC_compiler(Compiler): ...@@ -2555,8 +2541,9 @@ class GCC_compiler(Compiler):
""" """
# TODO: Do not do the dlimport in this function # TODO: Do not do the dlimport in this function
if not config.cxx: if not config.cxx:
from pytensor.link.c.exceptions import MissingGXX
raise MissingGXX("g++ not available! We can't compile c code.") raise MissingGXX("g++ not available! We can't compile c code.")
if include_dirs is None: if include_dirs is None:
...@@ -2586,6 +2573,8 @@ class GCC_compiler(Compiler): ...@@ -2586,6 +2573,8 @@ class GCC_compiler(Compiler):
cppfile.write("\n") cppfile.write("\n")
if platform.python_implementation() == "PyPy": if platform.python_implementation() == "PyPy":
from setuptools._distutils.sysconfig import get_config_var
suffix = "." + get_lib_extension() suffix = "." + get_lib_extension()
dist_suffix = get_config_var("SO") dist_suffix = get_config_var("SO")
...@@ -2642,6 +2631,8 @@ class GCC_compiler(Compiler): ...@@ -2642,6 +2631,8 @@ class GCC_compiler(Compiler):
status = p_out[2] status = p_out[2]
if status: if status:
from pytensor.link.c.exceptions import CompileError
tf = tempfile.NamedTemporaryFile( tf = tempfile.NamedTemporaryFile(
mode="w", prefix="pytensor_compilation_error_", delete=False mode="w", prefix="pytensor_compilation_error_", delete=False
) )
......
...@@ -19,7 +19,6 @@ from typing import TYPE_CHECKING, Any ...@@ -19,7 +19,6 @@ from typing import TYPE_CHECKING, Any
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.link.basic import Container, LocalLinker from pytensor.link.basic import Container, LocalLinker
from pytensor.link.c.exceptions import MissingGXX
from pytensor.link.utils import ( from pytensor.link.utils import (
gc_helper, gc_helper,
get_destroy_dependencies, get_destroy_dependencies,
...@@ -1006,6 +1005,8 @@ class VMLinker(LocalLinker): ...@@ -1006,6 +1005,8 @@ class VMLinker(LocalLinker):
compute_map, compute_map,
updated_vars, updated_vars,
): ):
from pytensor.link.c.exceptions import MissingGXX
pre_call_clear = [storage_map[v] for v in self.no_recycling] pre_call_clear = [storage_map[v] for v in self.no_recycling]
try: try:
......
...@@ -74,7 +74,6 @@ from pytensor.graph.op import HasInnerGraph, Op ...@@ -74,7 +74,6 @@ from pytensor.graph.op import HasInnerGraph, Op
from pytensor.graph.replace import clone_replace from pytensor.graph.replace import clone_replace
from pytensor.graph.utils import InconsistencyError, MissingInputError from pytensor.graph.utils import InconsistencyError, MissingInputError
from pytensor.link.c.basic import CLinker from pytensor.link.c.basic import CLinker
from pytensor.link.c.exceptions import MissingGXX
from pytensor.printing import op_debug_information from pytensor.printing import op_debug_information
from pytensor.scan.utils import ScanProfileStats, Validator, forced_replace, safe_new from pytensor.scan.utils import ScanProfileStats, Validator, forced_replace, safe_new
from pytensor.tensor.basic import as_tensor_variable from pytensor.tensor.basic import as_tensor_variable
...@@ -1499,6 +1498,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1499,6 +1498,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
then it must not do so for variables in the no_recycling list. then it must not do so for variables in the no_recycling list.
""" """
from pytensor.link.c.exceptions import MissingGXX
# Before building the thunk, validate that the inner graph is # Before building the thunk, validate that the inner graph is
# coherent # coherent
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论