提交 a2f9752f authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Brandon T. Willard

Switch to lock_ctx everywhere

上级 e692e389
...@@ -16,10 +16,7 @@ from theano.configdefaults import config ...@@ -16,10 +16,7 @@ from theano.configdefaults import config
__all__ = [ __all__ = [
"force_unlock", "force_unlock",
"get_lock",
"lock",
"lock_ctx", "lock_ctx",
"release_lock",
] ]
......
...@@ -6,6 +6,7 @@ Driver of graph construction, optimization, and linking. ...@@ -6,6 +6,7 @@ Driver of graph construction, optimization, and linking.
import copy import copy
import copyreg import copyreg
import logging import logging
import os
import pickle import pickle
import time import time
import warnings import warnings
...@@ -16,6 +17,7 @@ import numpy as np ...@@ -16,6 +17,7 @@ import numpy as np
import theano import theano
import theano.compile.profiling import theano.compile.profiling
from theano import gof from theano import gof
from theano.compile.compilelock import lock_ctx
from theano.compile.io import In, SymbolicInput, SymbolicOutput from theano.compile.io import In, SymbolicInput, SymbolicOutput
from theano.compile.ops import deep_copy_op, view_op from theano.compile.ops import deep_copy_op, view_op
from theano.configdefaults import config from theano.configdefaults import config
...@@ -1359,17 +1361,13 @@ class FunctionMaker: ...@@ -1359,17 +1361,13 @@ class FunctionMaker:
def optimize_graph_with_cache(self, optimizer, inputs, outputs): def optimize_graph_with_cache(self, optimizer, inputs, outputs):
# This function is not finished # This function is not finished
import os.path
from theano.compile.compilelock import get_lock, release_lock
graph_db_file = os.path.join(config.compiledir, "optimized_graphs.pkl") graph_db_file = os.path.join(config.compiledir, "optimized_graphs.pkl")
# the inputs, outputs, and size of the graph to be optimized # the inputs, outputs, and size of the graph to be optimized
inputs_new = [inp.variable for inp in inputs] inputs_new = [inp.variable for inp in inputs]
outputs_new = [out.variable for out in outputs] outputs_new = [out.variable for out in outputs]
size_new = len(self.fgraph.apply_nodes) size_new = len(self.fgraph.apply_nodes)
get_lock()
# Beginning of cache optimizations. # Beginning of cache optimizations.
# Could be refactored in different functions. # Could be refactored in different functions.
...@@ -1480,6 +1478,7 @@ class FunctionMaker: ...@@ -1480,6 +1478,7 @@ class FunctionMaker:
break break
return found_graph_in_db return found_graph_in_db
with lock_ctx():
graph_db = load_graph_db() graph_db = load_graph_db()
print(f"loaded graph_db from {graph_db_file}, size={len(graph_db)}") print(f"loaded graph_db from {graph_db_file}, size={len(graph_db)}")
found_graph = find_same_graph_in_db(graph_db) found_graph = find_same_graph_in_db(graph_db)
...@@ -1502,7 +1501,7 @@ class FunctionMaker: ...@@ -1502,7 +1501,7 @@ class FunctionMaker:
with open(graph_db_file, "wb") as f: with open(graph_db_file, "wb") as f:
pickle.dump(graph_db, f, -1) pickle.dump(graph_db, f, -1)
print("new graph saved into graph_db") print("new graph saved into graph_db")
release_lock()
return optimizer_profile return optimizer_profile
def __init__( def __init__(
......
...@@ -11,7 +11,7 @@ from io import StringIO ...@@ -11,7 +11,7 @@ from io import StringIO
import numpy as np import numpy as np
from theano.compile.compilelock import get_lock, release_lock from theano.compile.compilelock import lock_ctx
from theano.configdefaults import config from theano.configdefaults import config
from theano.gof.callcache import CallCache from theano.gof.callcache import CallCache
from theano.gof.graph import Constant, NoParams, io_toposort from theano.gof.graph import Constant, NoParams, io_toposort
...@@ -1614,7 +1614,7 @@ class CLinker(Linker): ...@@ -1614,7 +1614,7 @@ class CLinker(Linker):
preargs = self.compile_args() preargs = self.compile_args()
# We want to compute the code without the lock # We want to compute the code without the lock
src_code = mod.code() src_code = mod.code()
get_lock() with lock_ctx():
try: try:
_logger.debug(f"LOCATION {location}") _logger.debug(f"LOCATION {location}")
module = c_compiler.compile_str( module = c_compiler.compile_str(
...@@ -1629,8 +1629,6 @@ class CLinker(Linker): ...@@ -1629,8 +1629,6 @@ class CLinker(Linker):
except Exception as e: except Exception as e:
e.args += (str(self.fgraph),) e.args += (str(self.fgraph),)
raise raise
finally:
release_lock()
return module return module
def get_dynamic_module(self): def get_dynamic_module(self):
...@@ -1908,13 +1906,6 @@ class OpWiseCLinker(LocalLinker): ...@@ -1908,13 +1906,6 @@ class OpWiseCLinker(LocalLinker):
self, profiler=None, input_storage=None, output_storage=None, storage_map=None self, profiler=None, input_storage=None, output_storage=None, storage_map=None
): ):
# The lock will be acquired when we compile the first
# C code. We will keep the lock until all the function
# compilation will be finished. This allow to don't
# require the lock when all c code are already compiled!
orig_n_lock = getattr(get_lock, "n_lock", 0)
try:
fgraph = self.fgraph fgraph = self.fgraph
order = self.schedule(fgraph) order = self.schedule(fgraph)
no_recycling = self.no_recycling no_recycling = self.no_recycling
...@@ -1936,9 +1927,7 @@ class OpWiseCLinker(LocalLinker): ...@@ -1936,9 +1927,7 @@ class OpWiseCLinker(LocalLinker):
for node in order: for node in order:
# make_thunk will try by default C code, otherwise # make_thunk will try by default C code, otherwise
# it fall back to python. # it fall back to python.
thunks += [ thunks += [node.op.make_thunk(node, storage_map, compute_map, no_recycling)]
node.op.make_thunk(node, storage_map, compute_map, no_recycling)
]
thunks[-1].inputs = [storage_map[v] for v in node.inputs] thunks[-1].inputs = [storage_map[v] for v in node.inputs]
thunks[-1].outputs = [storage_map[v] for v in node.outputs] thunks[-1].outputs = [storage_map[v] for v in node.outputs]
...@@ -1975,12 +1964,6 @@ class OpWiseCLinker(LocalLinker): ...@@ -1975,12 +1964,6 @@ class OpWiseCLinker(LocalLinker):
f.allow_gc = self.allow_gc f.allow_gc = self.allow_gc
finally:
# Release lock on compilation directory.
if getattr(get_lock, "n_lock", 0) > orig_n_lock:
release_lock()
assert get_lock.n_lock == orig_n_lock
return ( return (
f, f,
[ [
......
...@@ -2,7 +2,7 @@ import errno ...@@ -2,7 +2,7 @@ import errno
import os import os
import sys import sys
from theano.compile.compilelock import get_lock, release_lock from theano.compile.compilelock import lock_ctx
from theano.configdefaults import config from theano.configdefaults import config
from theano.link.c import cmodule from theano.link.c import cmodule
...@@ -101,11 +101,10 @@ try: ...@@ -101,11 +101,10 @@ try:
try: try:
from cutils_ext.cutils_ext import * # noqa from cutils_ext.cutils_ext import * # noqa
except ImportError: except ImportError:
get_lock() with lock_ctx():
# Ensure no-one else is currently modifying the content of the compilation # Ensure no-one else is currently modifying the content of the compilation
# directory. This is important to prevent multiple processes from trying to # directory. This is important to prevent multiple processes from trying to
# compile the cutils_ext module simultaneously. # compile the cutils_ext module simultaneously.
try:
try: try:
# We must retry to import it as some other process could # We must retry to import it as some other process could
# have been compiling it between the first failed import # have been compiling it between the first failed import
...@@ -115,10 +114,6 @@ try: ...@@ -115,10 +114,6 @@ try:
compile_cutils() compile_cutils()
from cutils_ext.cutils_ext import * # noqa from cutils_ext.cutils_ext import * # noqa
finally:
# Release lock on compilation directory.
release_lock()
finally: finally:
if sys.path[0] == config.compiledir: if sys.path[0] == config.compiledir:
del sys.path[0] del sys.path[0]
...@@ -6,7 +6,7 @@ import warnings ...@@ -6,7 +6,7 @@ import warnings
from importlib import reload from importlib import reload
import theano import theano
from theano.compile.compilelock import get_lock, release_lock from theano.compile.compilelock import lock_ctx
from theano.configdefaults import config from theano.configdefaults import config
from theano.link.c.cmodule import GCC_compiler from theano.link.c.cmodule import GCC_compiler
...@@ -79,8 +79,7 @@ try: ...@@ -79,8 +79,7 @@ try:
f"Extra debug information: force_compile={force_compile}, _need_reload={_need_reload}" f"Extra debug information: force_compile={force_compile}, _need_reload={_need_reload}"
) )
except ImportError: except ImportError:
get_lock() with lock_ctx():
try:
# Maybe someone else already finished compiling it while we were # Maybe someone else already finished compiling it while we were
# waiting for the lock? # waiting for the lock?
try: try:
...@@ -152,9 +151,6 @@ except ImportError: ...@@ -152,9 +151,6 @@ except ImportError:
assert lazylinker_ext._version == lazy_c.get_version() assert lazylinker_ext._version == lazy_c.get_version()
_logger.info(f"New version {lazylinker_ext._version}") _logger.info(f"New version {lazylinker_ext._version}")
finally:
# Release lock on compilation directory.
release_lock()
from lazylinker_ext.lazylinker_ext import * # noqa from lazylinker_ext.lazylinker_ext import * # noqa
......
...@@ -17,7 +17,7 @@ from importlib import reload ...@@ -17,7 +17,7 @@ from importlib import reload
import numpy as np import numpy as np
import theano import theano
from theano.compile.compilelock import get_lock, release_lock from theano.compile.compilelock import lock_ctx
from theano.configdefaults import config from theano.configdefaults import config
from theano.link.c import cmodule from theano.link.c import cmodule
...@@ -50,8 +50,7 @@ try: ...@@ -50,8 +50,7 @@ try:
if version != getattr(scan_perform, "_version", None): if version != getattr(scan_perform, "_version", None):
raise ImportError() raise ImportError()
except ImportError: except ImportError:
get_lock() with lock_ctx():
try:
# Maybe someone else already finished compiling it while we were # Maybe someone else already finished compiling it while we were
# waiting for the lock? # waiting for the lock?
try: try:
...@@ -139,9 +138,6 @@ except ImportError: ...@@ -139,9 +138,6 @@ except ImportError:
assert scan_perform._version == scan_c.get_version() assert scan_perform._version == scan_c.get_version()
_logger.info(f"New version {scan_perform._version}") _logger.info(f"New version {scan_perform._version}")
finally:
# Release lock on compilation directory.
release_lock()
# This is caused as cython use the old NumPy C-API but we use the new one. # This is caused as cython use the old NumPy C-API but we use the new one.
# To fix it completely, we would need to modify Cython to use the new API. # To fix it completely, we would need to modify Cython to use the new API.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论