提交 a2f9752f — 作者: Michael Osthege，提交者: Brandon T. Willard

Switch to lock_ctx everywhere

上级提交: e692e389
......@@ -16,10 +16,7 @@ from theano.configdefaults import config
__all__ = [
"force_unlock",
"get_lock",
"lock",
"lock_ctx",
"release_lock",
]
......
......@@ -6,6 +6,7 @@ Driver of graph construction, optimization, and linking.
import copy
import copyreg
import logging
import os
import pickle
import time
import warnings
......@@ -16,6 +17,7 @@ import numpy as np
import theano
import theano.compile.profiling
from theano import gof
from theano.compile.compilelock import lock_ctx
from theano.compile.io import In, SymbolicInput, SymbolicOutput
from theano.compile.ops import deep_copy_op, view_op
from theano.configdefaults import config
......@@ -1359,17 +1361,13 @@ class FunctionMaker:
def optimize_graph_with_cache(self, optimizer, inputs, outputs):
# This function is not finished
import os.path
from theano.compile.compilelock import get_lock, release_lock
graph_db_file = os.path.join(config.compiledir, "optimized_graphs.pkl")
# the inputs, outputs, and size of the graph to be optimized
inputs_new = [inp.variable for inp in inputs]
outputs_new = [out.variable for out in outputs]
size_new = len(self.fgraph.apply_nodes)
get_lock()
# Beginning of cache optimizations.
# Could be refactored in different functions.
......@@ -1480,29 +1478,30 @@ class FunctionMaker:
break
return found_graph_in_db
graph_db = load_graph_db()
print(f"loaded graph_db from {graph_db_file}, size={len(graph_db)}")
found_graph = find_same_graph_in_db(graph_db)
if found_graph:
self.fgraph = found_graph
optimizer_profile = None
else:
# this is a brand new graph, optimize it, save it to graph_db
print("graph not found in graph_db, optimizing the graph")
self.fgraph.variables = set(
gof.graph.variables(self.fgraph.inputs, self.fgraph.outputs)
)
# check_integrity parameters was added to ignore
# "excess cached variables" errors. Works that way
# but once again the error couldbe worth
# investigating.
before_opt = self.fgraph.clone(check_integrity=False)
optimizer_profile = optimizer(self.fgraph)
graph_db.update({before_opt: self.fgraph})
with open(graph_db_file, "wb") as f:
pickle.dump(graph_db, f, -1)
print("new graph saved into graph_db")
release_lock()
with lock_ctx():
graph_db = load_graph_db()
print(f"loaded graph_db from {graph_db_file}, size={len(graph_db)}")
found_graph = find_same_graph_in_db(graph_db)
if found_graph:
self.fgraph = found_graph
optimizer_profile = None
else:
# this is a brand new graph, optimize it, save it to graph_db
print("graph not found in graph_db, optimizing the graph")
self.fgraph.variables = set(
gof.graph.variables(self.fgraph.inputs, self.fgraph.outputs)
)
# check_integrity parameters was added to ignore
# "excess cached variables" errors. Works that way
# but once again the error couldbe worth
# investigating.
before_opt = self.fgraph.clone(check_integrity=False)
optimizer_profile = optimizer(self.fgraph)
graph_db.update({before_opt: self.fgraph})
with open(graph_db_file, "wb") as f:
pickle.dump(graph_db, f, -1)
print("new graph saved into graph_db")
return optimizer_profile
def __init__(
......
......@@ -11,7 +11,7 @@ from io import StringIO
import numpy as np
from theano.compile.compilelock import get_lock, release_lock
from theano.compile.compilelock import lock_ctx
from theano.configdefaults import config
from theano.gof.callcache import CallCache
from theano.gof.graph import Constant, NoParams, io_toposort
......@@ -1614,23 +1614,21 @@ class CLinker(Linker):
preargs = self.compile_args()
# We want to compute the code without the lock
src_code = mod.code()
get_lock()
try:
_logger.debug(f"LOCATION {location}")
module = c_compiler.compile_str(
module_name=mod.code_hash,
src_code=src_code,
location=location,
include_dirs=self.header_dirs(),
lib_dirs=self.lib_dirs(),
libs=libs,
preargs=preargs,
)
except Exception as e:
e.args += (str(self.fgraph),)
raise
finally:
release_lock()
with lock_ctx():
try:
_logger.debug(f"LOCATION {location}")
module = c_compiler.compile_str(
module_name=mod.code_hash,
src_code=src_code,
location=location,
include_dirs=self.header_dirs(),
lib_dirs=self.lib_dirs(),
libs=libs,
preargs=preargs,
)
except Exception as e:
e.args += (str(self.fgraph),)
raise
return module
def get_dynamic_module(self):
......@@ -1908,78 +1906,63 @@ class OpWiseCLinker(LocalLinker):
self, profiler=None, input_storage=None, output_storage=None, storage_map=None
):
# The lock will be acquired when we compile the first
# C code. We will keep the lock until all the function
# compilation will be finished. This allow to don't
# require the lock when all c code are already compiled!
orig_n_lock = getattr(get_lock, "n_lock", 0)
try:
fgraph = self.fgraph
order = self.schedule(fgraph)
no_recycling = self.no_recycling
fgraph = self.fgraph
order = self.schedule(fgraph)
no_recycling = self.no_recycling
input_storage, output_storage, storage_map = map_storage(
fgraph, order, input_storage, output_storage, storage_map
)
if self.allow_gc:
computed, last_user = gc_helper(order)
post_thunk_old_storage = []
else:
post_thunk_old_storage = None
input_storage, output_storage, storage_map = map_storage(
fgraph, order, input_storage, output_storage, storage_map
)
if self.allow_gc:
computed, last_user = gc_helper(order)
post_thunk_old_storage = []
else:
post_thunk_old_storage = None
compute_map = {}
for k in storage_map:
compute_map[k] = [k.owner is None]
thunks = []
for node in order:
# make_thunk will try by default C code, otherwise
# it fall back to python.
thunks += [
node.op.make_thunk(node, storage_map, compute_map, no_recycling)
]
thunks[-1].inputs = [storage_map[v] for v in node.inputs]
thunks[-1].outputs = [storage_map[v] for v in node.outputs]
for node in order:
if self.allow_gc:
post_thunk_old_storage.append(
[
storage_map[input]
for input in node.inputs
if (
(input in computed)
and (input not in fgraph.outputs)
and node == last_user[input]
)
]
)
compute_map = {}
for k in storage_map:
compute_map[k] = [k.owner is None]
if no_recycling is True:
no_recycling = list(storage_map.values())
no_recycling = difference(no_recycling, input_storage)
else:
no_recycling = [
storage_map[r] for r in no_recycling if r not in fgraph.inputs
]
thunks = []
for node in order:
# make_thunk will try by default C code, otherwise
# it fall back to python.
thunks += [node.op.make_thunk(node, storage_map, compute_map, no_recycling)]
thunks[-1].inputs = [storage_map[v] for v in node.inputs]
thunks[-1].outputs = [storage_map[v] for v in node.outputs]
f = streamline(
fgraph,
thunks,
order,
post_thunk_old_storage,
no_recycling=no_recycling,
nice_errors=self.nice_errors,
)
for node in order:
if self.allow_gc:
post_thunk_old_storage.append(
[
storage_map[input]
for input in node.inputs
if (
(input in computed)
and (input not in fgraph.outputs)
and node == last_user[input]
)
]
)
f.allow_gc = self.allow_gc
if no_recycling is True:
no_recycling = list(storage_map.values())
no_recycling = difference(no_recycling, input_storage)
else:
no_recycling = [
storage_map[r] for r in no_recycling if r not in fgraph.inputs
]
f = streamline(
fgraph,
thunks,
order,
post_thunk_old_storage,
no_recycling=no_recycling,
nice_errors=self.nice_errors,
)
finally:
# Release lock on compilation directory.
if getattr(get_lock, "n_lock", 0) > orig_n_lock:
release_lock()
assert get_lock.n_lock == orig_n_lock
f.allow_gc = self.allow_gc
return (
f,
......
......@@ -2,7 +2,7 @@ import errno
import os
import sys
from theano.compile.compilelock import get_lock, release_lock
from theano.compile.compilelock import lock_ctx
from theano.configdefaults import config
from theano.link.c import cmodule
......@@ -101,11 +101,10 @@ try:
try:
from cutils_ext.cutils_ext import * # noqa
except ImportError:
get_lock()
# Ensure no-one else is currently modifying the content of the compilation
# directory. This is important to prevent multiple processes from trying to
# compile the cutils_ext module simultaneously.
try:
with lock_ctx():
# Ensure no-one else is currently modifying the content of the compilation
# directory. This is important to prevent multiple processes from trying to
# compile the cutils_ext module simultaneously.
try:
# We must retry to import it as some other process could
# have been compiling it between the first failed import
......@@ -115,10 +114,6 @@ try:
compile_cutils()
from cutils_ext.cutils_ext import * # noqa
finally:
# Release lock on compilation directory.
release_lock()
finally:
if sys.path[0] == config.compiledir:
del sys.path[0]
......@@ -6,7 +6,7 @@ import warnings
from importlib import reload
import theano
from theano.compile.compilelock import get_lock, release_lock
from theano.compile.compilelock import lock_ctx
from theano.configdefaults import config
from theano.link.c.cmodule import GCC_compiler
......@@ -79,8 +79,7 @@ try:
f"Extra debug information: force_compile={force_compile}, _need_reload={_need_reload}"
)
except ImportError:
get_lock()
try:
with lock_ctx():
# Maybe someone else already finished compiling it while we were
# waiting for the lock?
try:
......@@ -152,9 +151,6 @@ except ImportError:
assert lazylinker_ext._version == lazy_c.get_version()
_logger.info(f"New version {lazylinker_ext._version}")
finally:
# Release lock on compilation directory.
release_lock()
from lazylinker_ext.lazylinker_ext import * # noqa
......
......@@ -17,7 +17,7 @@ from importlib import reload
import numpy as np
import theano
from theano.compile.compilelock import get_lock, release_lock
from theano.compile.compilelock import lock_ctx
from theano.configdefaults import config
from theano.link.c import cmodule
......@@ -50,8 +50,7 @@ try:
if version != getattr(scan_perform, "_version", None):
raise ImportError()
except ImportError:
get_lock()
try:
with lock_ctx():
# Maybe someone else already finished compiling it while we were
# waiting for the lock?
try:
......@@ -139,9 +138,6 @@ except ImportError:
assert scan_perform._version == scan_c.get_version()
_logger.info(f"New version {scan_perform._version}")
finally:
# Release lock on compilation directory.
release_lock()
# This is caused as cython use the old NumPy C-API but we use the new one.
# To fix it completely, we would need to modify Cython to use the new API.
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论