提交 1d5236be authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Brandon T. Willard

Outsource exceptions to reduce dependency on cmodule

上级 0e3a8128
......@@ -28,6 +28,7 @@ from theano.configdefaults import gcc_version_str, local_bitwidth
# we will abuse the lockfile mechanism when reading and writing the registry
from theano.gof import compilelock
from theano.gof.utils import flatten, hash_from_code
from theano.link.c.exceptions import MissingGXX
from theano.utils import output_subprocess_Popen, subprocess_Popen
......@@ -45,14 +46,6 @@ METH_NOARGS = "METH_NOARGS"
import_time = 0
class MissingGXX(Exception):
"""
This error is raised when we try to generate c code,
but g++ is not available.
"""
def debug_counter(name, every=1):
"""
Debug counter to know how often we go through some piece of code.
......
class MissingGXX(Exception):
"""
This error is raised when we try to generate c code,
but g++ is not available.
"""
......@@ -12,8 +12,11 @@ import time
import warnings
from collections import defaultdict
import theano.link.c.cmodule
from theano import config, link
from theano import config
from theano.gof import Constant, Variable
from theano.link.basic import Container, LocalLinker
from theano.link.c.exceptions import MissingGXX
from theano.link.utils import gc_helper, map_storage, raise_with_op
logger = logging.getLogger(__name__)
......@@ -55,7 +58,7 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
)
ins = node.inputs[idx_v[0]]
if ins is not None:
assert isinstance(ins, theano.Variable)
assert isinstance(ins, Variable)
origin = view_of.get(ins, ins)
view_of[out] = origin
viewed_by[origin].append(out)
......@@ -96,7 +99,7 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
if (
not viewed_by[origin]
and origin not in fgraph.inputs
and not isinstance(origin, theano.Constant)
and not isinstance(origin, Constant)
):
# where gc
for i in range(idx + 1, len(order)):
......@@ -255,7 +258,7 @@ class Loop(VM):
self.call_counts[i] += 1
self.call_times[i] += t1 - t0
except Exception:
link.raise_with_op(self.fgraph, node, thunk)
raise_with_op(self.fgraph, node, thunk)
else:
for cont in self.pre_call_clear:
cont[0] = None
......@@ -263,7 +266,7 @@ class Loop(VM):
for thunk, node in zip(self.thunks, self.nodes):
thunk()
except Exception:
link.raise_with_op(self.fgraph, node, thunk)
raise_with_op(self.fgraph, node, thunk)
class LoopGC(VM):
......@@ -299,7 +302,7 @@ class LoopGC(VM):
old_s[0] = None
i += 1
except Exception:
link.raise_with_op(self.fgraph, node, thunk)
raise_with_op(self.fgraph, node, thunk)
else:
for cont in self.pre_call_clear:
cont[0] = None
......@@ -311,7 +314,7 @@ class LoopGC(VM):
for old_s in old_storage:
old_s[0] = None
except Exception:
link.raise_with_op(self.fgraph, node, thunk)
raise_with_op(self.fgraph, node, thunk)
class Stack(VM):
......@@ -539,7 +542,7 @@ class Stack(VM):
off = getattr(o[0], "offset", "")
self.variable_offset[var] = off
except Exception:
link.raise_with_op(
raise_with_op(
self.fgraph,
current_apply,
self.thunks[self.node_idx[current_apply]],
......@@ -611,7 +614,7 @@ class Stack(VM):
self.call_times[current_idx] += dt
except Exception:
link.raise_with_op(
raise_with_op(
self.fgraph,
current_apply,
self.thunks[self.node_idx[current_apply]],
......@@ -693,8 +696,8 @@ class Stack(VM):
try:
# If cxx is explicitely set to an empty string, we do not want to import neither lazylinker C code
# nor lazylinker compiled C code from cache.
if not theano.config.cxx:
raise theano.link.c.cmodule.MissingGXX(
if not config.cxx:
raise MissingGXX(
"lazylinker will not be imported if theano.config.cxx is not set."
)
from theano.link.c import lazylinker_c
......@@ -708,7 +711,7 @@ try:
except ImportError:
pass
except (OSError, theano.link.c.cmodule.MissingGXX) as e:
except (OSError, MissingGXX) as e:
# OSError happens when g++ is not installed. In that case, we
# already changed the default linker to something else then CVM.
# Currently this is the py linker.
......@@ -716,7 +719,7 @@ except (OSError, theano.link.c.cmodule.MissingGXX) as e:
assert not config._config_var_dict["linker"].default.startswith("cvm"), e
class VM_Linker(link.LocalLinker):
class VM_Linker(LocalLinker):
"""
Class that satisfies the Linker interface by acting as a VM factory.
......@@ -774,7 +777,7 @@ class VM_Linker(link.LocalLinker):
self.callback_input = callback_input
self.lazy = lazy
if c_thunks is None:
c_thunks = bool(theano.config.cxx)
c_thunks = bool(config.cxx)
self.c_thunks = c_thunks
self.allow_partial_eval = allow_partial_eval
self.updated_vars = {}
......@@ -1111,7 +1114,7 @@ class VM_Linker(link.LocalLinker):
fgraph = self.fgraph
order = self.schedule(fgraph)
input_storage, output_storage, storage_map = link.map_storage(
input_storage, output_storage, storage_map = map_storage(
fgraph, order, input_storage, output_storage, storage_map
)
compute_map = {}
......@@ -1185,7 +1188,7 @@ class VM_Linker(link.LocalLinker):
for pair in reallocated_info.values():
storage_map[pair[1]] = storage_map[pair[0]]
computed, last_user = link.gc_helper(order)
computed, last_user = gc_helper(order)
if self.allow_gc:
post_thunk_clear = []
for node in order:
......@@ -1220,11 +1223,11 @@ class VM_Linker(link.LocalLinker):
return (
vm,
[
link.Container(input, storage)
Container(input, storage)
for input, storage in zip(fgraph.inputs, input_storage)
],
[
link.Container(output, storage, readonly=True)
Container(output, storage, readonly=True)
for output, storage in zip(fgraph.outputs, output_storage)
],
thunks,
......
......@@ -64,6 +64,7 @@ from theano.gof.graph import equal_computations, io_connection_pattern
from theano.gof.toolbox import NoOutputFromInplace
from theano.gradient import DisconnectedType, NullType, grad_undefined
from theano.link.c.cc import CLinker
from theano.link.c.exceptions import MissingGXX
from theano.link.utils import raise_with_op
from theano.scan.utils import Validator, forced_replace, hash_listsDictsTuples, safe_new
from theano.tensor import TensorType, as_tensor_variable
......@@ -974,7 +975,7 @@ class Scan(PureOp):
try:
if impl == "py":
raise theano.link.c.cmodule.MissingGXX
raise MissingGXX
cython_mintaps = np.asarray(self.mintaps, dtype="int32")
cython_tap_array_len = np.asarray(
[len(x) for x in self.tap_array], dtype="int32"
......@@ -1051,7 +1052,7 @@ class Scan(PureOp):
node,
)
except (ImportError, theano.link.c.cmodule.MissingGXX):
except (ImportError, MissingGXX):
p = self.execute
# default arguments are stored in the closure of `rval`
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论