提交 e9080d89 authored 作者: Maxim Kochurov's avatar Maxim Kochurov 提交者: Maxim Kochurov

scan as Cython extension

上级 824294dd
...@@ -4,6 +4,7 @@ __pycache__ ...@@ -4,6 +4,7 @@ __pycache__
.coverage .coverage
*.linkinfo *.linkinfo
*.o *.o
*.c
*.orig *.orig
*.pyc *.pyc
*.pyo *.pyo
......
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -3425,7 +3425,7 @@ def profile_printer( ...@@ -3425,7 +3425,7 @@ def profile_printer(
) )
@op_debug_information.register(Scan) @op_debug_information.register(Scan) # noqa
def _op_debug_information_Scan(op: Scan, node: Apply): def _op_debug_information_Scan(op: Scan, node: Apply):
from typing import Sequence from typing import Sequence
......
""" """
To update the `Scan` Cython code you must To update the `Scan` Cython code you must
- update the version value in this file and `scan_perform.py`, and - Update `scan_perform.pyx`
- run `cython scan_perform.pyx; mv scan_perform.c c_code` - update the version value in this file and in `scan_perform.pyx`
""" """
import logging from pytensor.scan.scan_perform import get_version, perform # noqa: F401, E402
import os
import sys
from importlib import reload
from types import ModuleType
from typing import Optional
import pytensor
from pytensor.compile.compilelock import lock_ctx
from pytensor.configdefaults import config
from pytensor.link.c import cmodule
if not config.cxx:
raise ImportError("No C compiler; cannot compile Cython-generated code")
_logger = logging.getLogger("pytensor.scan.scan_perform")
version = 0.326 # must match constant returned in function get_version() version = 0.326 # must match constant returned in function get_version()
assert version == get_version(), (
need_reload = False "Invalid extension, check the installation process, "
scan_perform: Optional[ModuleType] = None "could be problem with .pyx file or Cython ext build process."
)
del get_version
def try_import():
global scan_perform
sys.path[0:0] = [config.compiledir]
import scan_perform
del sys.path[0]
def try_reload():
sys.path[0:0] = [config.compiledir]
reload(scan_perform)
del sys.path[0]
try:
try_import()
need_reload = True
if version != getattr(scan_perform, "_version", None):
raise ImportError("Scan code version mismatch")
except ImportError:
dirname = "scan_perform"
loc = os.path.join(config.compiledir, dirname)
os.makedirs(loc, exist_ok=True)
with lock_ctx(loc):
# Maybe someone else already finished compiling it while we were
# waiting for the lock?
try:
if need_reload:
# The module was successfully imported earlier: we need to
# reload it to check if the version was updated.
try_reload()
else:
try_import()
need_reload = True
if version != getattr(scan_perform, "_version", None):
raise ImportError()
except ImportError:
_logger.info("Compiling C code for scan")
cfile = os.path.join(
pytensor.__path__[0], "scan", "c_code", "scan_perform.c"
)
if not os.path.exists(cfile):
raise ImportError(
"The file scan_perform.c is not available, so scan "
"will not use its Cython implementation."
)
preargs = ["-fwrapv", "-O2", "-fno-strict-aliasing"]
preargs += cmodule.GCC_compiler.compile_args()
with open(cfile) as f:
code = f.read()
cmodule.GCC_compiler.compile_str(
dirname, code, location=loc, preargs=preargs, hide_symbols=False
)
# Save version into the __init__.py file.
init_py = os.path.join(loc, "__init__.py")
with open(init_py, "w") as f:
f.write(f"_version = {version}\n")
# If we just compiled the module for the first time, then it was
# imported at the same time. We need to make sure we do not reload
# the now outdated __init__.pyc below.
init_pyc = os.path.join(loc, "__init__.pyc")
if os.path.isfile(init_pyc):
os.remove(init_pyc)
try_import()
try_reload()
from scan_perform import scan_perform as scan_c
assert (
scan_perform is not None
and scan_perform._version == scan_c.get_version()
)
_logger.info(f"New version {scan_perform._version}")
from scan_perform.scan_perform import get_version, perform # noqa: F401, E402
assert version == get_version()
#!/usr/bin/env python #!/usr/bin/env python
import os import os
from setuptools import setup import numpy
from setuptools import Extension, setup
from setuptools.dist import Distribution from setuptools.dist import Distribution
import versioneer import versioneer
...@@ -35,4 +36,11 @@ if __name__ == "__main__": ...@@ -35,4 +36,11 @@ if __name__ == "__main__":
name=NAME, name=NAME,
version=versioneer.get_version(), version=versioneer.get_version(),
cmdclass=versioneer.get_cmdclass(), cmdclass=versioneer.get_cmdclass(),
ext_modules=[
Extension(
name="pytensor.scan.scan_perform",
sources=["pytensor/scan/scan_perform.pyx"],
include_dirs=[numpy.get_include()],
),
],
) )
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论