提交 49590e13 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Convert NanGuardMode to use the VM linker instead of WrapLinker.

上级 69338f33
...@@ -211,7 +211,7 @@ class NanGuardMode(Mode): ...@@ -211,7 +211,7 @@ class NanGuardMode(Mode):
assert nan_is_error or inf_is_error or big_is_error assert nan_is_error or inf_is_error or big_is_error
compile_gpu_func(nan_is_error, inf_is_error, big_is_error) compile_gpu_func(nan_is_error, inf_is_error, big_is_error)
def do_check_on(var, nd, f, is_input): def do_check_on(var, nd):
""" """
Checks `var` for NaNs / Infs. If detected, raises an exception Checks `var` for NaNs / Infs. If detected, raises an exception
and / or prints information about `nd`, `f`, and `is_input` to and / or prints information about `nd`, `f`, and `is_input` to
...@@ -223,11 +223,6 @@ class NanGuardMode(Mode): ...@@ -223,11 +223,6 @@ class NanGuardMode(Mode):
The value to be checked. The value to be checked.
nd : theano.gof.Apply nd : theano.gof.Apply
The Apply node being executed. The Apply node being executed.
f : callable
The thunk for the apply node.
is_input : bool
If True, `var` is an input to `nd`.
If False, it is an output.
""" """
error = False error = False
...@@ -256,17 +251,9 @@ class NanGuardMode(Mode): ...@@ -256,17 +251,9 @@ class NanGuardMode(Mode):
print('Big value detected', file=sio) print('Big value detected', file=sio)
error = True error = True
if error: if error:
if not is_input: print("NanGuardMode found an error in the"
print("NanGuardMode found an error in the" " output of a node in this variable:", file=sio)
" output of a node in this variable:", file=sio) print(theano.printing.debugprint(nd, file='str'), file=sio)
print(theano.printing.debugprint(nd, file='str'), file=sio)
else:
print("NanGuardMode found an error in an"
" input of this node.", file=sio)
print('Node:', file=sio)
print(nd, file=sio)
print("The input variable that cause problem:", file=sio)
print(theano.printing.debugprint(nd, file='str'), file=sio)
msg = sio.getvalue() msg = sio.getvalue()
if config.NanGuardMode.action == 'raise': if config.NanGuardMode.action == 'raise':
raise AssertionError(msg) raise AssertionError(msg)
...@@ -277,36 +264,26 @@ class NanGuardMode(Mode): ...@@ -277,36 +264,26 @@ class NanGuardMode(Mode):
elif config.NanGuardMode.action == 'warn': elif config.NanGuardMode.action == 'warn':
logger.error(msg) logger.error(msg)
def nan_check(i, node, fn): def nan_check(node, thunk, storage_map, compute_map):
""" """
Runs `fn` while checking its inputs and outputs for NaNs / Infs. Runs `fn` while checking its inputs and outputs for NaNs / Infs.
Parameters Parameters
---------- ----------
i :
Currently ignored.
TODO: determine why it is here or remove).
node : theano.gof.Apply node : theano.gof.Apply
The Apply node currently being executed. The Apply node currently being executed.
fn : callable thunk : callable
The thunk to execute for this Apply node. The thunk to execute for this Apply node.
storage_map : dict
The mapping of variables to storage cells.
compute_map : dict
The mapping to get the computed state of the nodes. (ignored)
""" """
inputs = fn.inputs for var in node.outputs:
for x, var in zip(inputs, node.inputs):
# If the input is the result of computation, then we
# don't need to check it. It is already done after the
# computation.
if (var.owner is None and
getattr(var.tag, 'nan_guard_mode_check', True)):
do_check_on(x[0], node, fn, True)
fn()
outputs = fn.outputs
for x, var in zip(outputs, node.outputs):
if getattr(var.tag, 'nan_guard_mode_check', True): if getattr(var.tag, 'nan_guard_mode_check', True):
do_check_on(x[0], node, fn, False) do_check_on(storage_map[var][0], node)
wrap_linker = theano.gof.WrapLinker([theano.gof.OpWiseCLinker()], wrap_linker = theano.gof.vm.VM_Linker(callback=nan_check)
nan_check)
super(NanGuardMode, self).__init__(wrap_linker, super(NanGuardMode, self).__init__(wrap_linker,
optimizer=self.provided_optimizer) optimizer=self.provided_optimizer)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论