提交 175e7843 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Update scan's Cython version and restrict it's compilation lock scope

When `Scan`'s Cython code is compiled, it will now only hold a lock for its compilation directory and not unnecessarily block other Theano C compilation processes.
上级 114e1237
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -5,27 +5,23 @@ To update the `Scan` Cython code you must ...@@ -5,27 +5,23 @@ To update the `Scan` Cython code you must
- run `cython scan_perform.pyx; mv scan_perform.c c_code` - run `cython scan_perform.pyx; mv scan_perform.c c_code`
""" """
import errno
import logging import logging
import os import os
import sys import sys
import warnings
from importlib import reload from importlib import reload
import numpy as np
import theano import theano
from theano.compile.compilelock import lock_ctx from theano.compile.compilelock import lock_ctx
from theano.configdefaults import config from theano.configdefaults import config
from theano.link.c import cmodule from theano.link.c import cmodule
_logger = logging.getLogger("theano.scan.scan_perform") if not config.cxx:
raise ImportError("No C compiler; cannot compile Cython-generated code")
_logger = logging.getLogger("theano.scan.scan_perform")
version = 0.297 # must match constant returned in function get_version() version = 0.298 # must match constant returned in function get_version()
need_reload = False need_reload = False
...@@ -48,9 +44,15 @@ try: ...@@ -48,9 +44,15 @@ try:
try_import() try_import()
need_reload = True need_reload = True
if version != getattr(scan_perform, "_version", None): if version != getattr(scan_perform, "_version", None):
raise ImportError() raise ImportError("Scan code version mismatch")
except ImportError: except ImportError:
with lock_ctx():
dirname = "scan_perform"
loc = os.path.join(config.compiledir, dirname)
os.makedirs(loc, exist_ok=True)
with lock_ctx(loc):
# Maybe someone else already finished compiling it while we were # Maybe someone else already finished compiling it while we were
# waiting for the lock? # waiting for the lock?
try: try:
...@@ -61,87 +63,55 @@ except ImportError: ...@@ -61,87 +63,55 @@ except ImportError:
else: else:
try_import() try_import()
need_reload = True need_reload = True
if version != getattr(scan_perform, "_version", None): if version != getattr(scan_perform, "_version", None):
raise ImportError() raise ImportError()
except ImportError: except ImportError:
if not config.cxx:
raise ImportError("no c compiler, can't compile cython code")
_logger.info("Compiling C code for scan") _logger.info("Compiling C code for scan")
dirname = "scan_perform"
cfile = os.path.join(theano.__path__[0], "scan", "c_code", "scan_perform.c") cfile = os.path.join(theano.__path__[0], "scan", "c_code", "scan_perform.c")
if not os.path.exists(cfile): if not os.path.exists(cfile):
# This can happen in not normal case. We just raise ImportError(
# disable the cython code. If we are here the user "The file scan_perform.c is not available, so scan "
# didn't disable the compiler, so print a warning. "will not use its Cython implementation."
warnings.warn(
"The file scan_perform.c is not available. This do"
"not happen normally. You are probably in a strange"
"setup. This mean Theano can not use the cython code for "
"scan. If you"
"want to remove this warning, use the Theano flag"
"'cxx=' (set to an empty string) to disable all c"
"code generation."
) )
raise ImportError("The file lazylinker_c.c is not available.")
with open(cfile) as f:
code = f.read()
loc = os.path.join(config.compiledir, dirname)
if not os.path.exists(loc):
try:
os.mkdir(loc)
except OSError as e:
assert e.errno == errno.EEXIST
assert os.path.exists(loc)
preargs = ["-fwrapv", "-O2", "-fno-strict-aliasing"] preargs = ["-fwrapv", "-O2", "-fno-strict-aliasing"]
preargs += cmodule.GCC_compiler.compile_args() preargs += cmodule.GCC_compiler.compile_args()
# Cython 19.1 always use the old NumPy interface. So we
# need to manually modify the .c file to get it compiled with open(cfile) as f:
# by Theano. As by default, we tell NumPy to don't import code = f.read()
# the old interface.
if False:
# During scan cython development, it is helpful to keep the old interface, to don't manually edit the c file each time.
preargs.remove("-DNPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION")
else:
numpy_ver = [int(n) for n in np.__version__.split(".")[:2]]
# Add add some macro to lower the number of edit
# needed to the c file.
if bool(numpy_ver >= [1, 7]):
# Needed when we disable the old API, as cython
# use the old interface
preargs.append("-DNPY_ENSUREARRAY=NPY_ARRAY_ENSUREARRAY")
preargs.append("-DNPY_ENSURECOPY=NPY_ARRAY_ENSURECOPY")
preargs.append("-DNPY_ALIGNED=NPY_ARRAY_ALIGNED")
preargs.append("-DNPY_WRITEABLE=NPY_ARRAY_WRITEABLE")
preargs.append("-DNPY_UPDATE_ALL=NPY_ARRAY_UPDATE_ALL")
preargs.append("-DNPY_C_CONTIGUOUS=NPY_ARRAY_C_CONTIGUOUS")
preargs.append("-DNPY_F_CONTIGUOUS=NPY_ARRAY_F_CONTIGUOUS")
cmodule.GCC_compiler.compile_str( cmodule.GCC_compiler.compile_str(
dirname, code, location=loc, preargs=preargs, hide_symbols=False dirname, code, location=loc, preargs=preargs, hide_symbols=False
) )
# Save version into the __init__.py file. # Save version into the __init__.py file.
init_py = os.path.join(loc, "__init__.py") init_py = os.path.join(loc, "__init__.py")
with open(init_py, "w") as f: with open(init_py, "w") as f:
f.write(f"_version = {version}\n") f.write(f"_version = {version}\n")
# If we just compiled the module for the first time, then it was # If we just compiled the module for the first time, then it was
# imported at the same time: we need to make sure we do not # imported at the same time. We need to make sure we do not reload
# reload the now outdated __init__.pyc below. # the now outdated __init__.pyc below.
init_pyc = os.path.join(loc, "__init__.pyc") init_pyc = os.path.join(loc, "__init__.pyc")
if os.path.isfile(init_pyc): if os.path.isfile(init_pyc):
os.remove(init_pyc) os.remove(init_pyc)
try_import() try_import()
try_reload() try_reload()
from scan_perform import scan_perform as scan_c from scan_perform import scan_perform as scan_c
assert scan_perform._version == scan_c.get_version() assert scan_perform._version == scan_c.get_version()
_logger.info(f"New version {scan_perform._version}") _logger.info(f"New version {scan_perform._version}")
# This is caused as cython use the old NumPy C-API but we use the new one. from scan_perform.scan_perform import get_version, perform # noqa: F401, E402
# To fix it completely, we would need to modify Cython to use the new API.
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="numpy.ndarray size changed")
from scan_perform.scan_perform import get_version, perform # noqa: F401
assert version == get_version() assert version == get_version()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论