提交 49590e13 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Convert NanGuardMode to use the VM linker instead of WrapLinker.

上级 69338f33
......@@ -211,7 +211,7 @@ class NanGuardMode(Mode):
assert nan_is_error or inf_is_error or big_is_error
compile_gpu_func(nan_is_error, inf_is_error, big_is_error)
def do_check_on(var, nd, f, is_input):
def do_check_on(var, nd):
"""
Checks `var` for NaNs / Infs. If detected, raises an exception
and / or prints information about `nd`, `f`, and `is_input` to
......@@ -223,11 +223,6 @@ class NanGuardMode(Mode):
The value to be checked.
nd : theano.gof.Apply
The Apply node being executed.
f : callable
The thunk for the apply node.
is_input : bool
If True, `var` is an input to `nd`.
If False, it is an output.
"""
error = False
......@@ -256,17 +251,9 @@ class NanGuardMode(Mode):
print('Big value detected', file=sio)
error = True
if error:
if not is_input:
print("NanGuardMode found an error in the"
" output of a node in this variable:", file=sio)
print(theano.printing.debugprint(nd, file='str'), file=sio)
else:
print("NanGuardMode found an error in an"
" input of this node.", file=sio)
print('Node:', file=sio)
print(nd, file=sio)
print("The input variable that cause problem:", file=sio)
print(theano.printing.debugprint(nd, file='str'), file=sio)
print("NanGuardMode found an error in the"
" output of a node in this variable:", file=sio)
print(theano.printing.debugprint(nd, file='str'), file=sio)
msg = sio.getvalue()
if config.NanGuardMode.action == 'raise':
raise AssertionError(msg)
......@@ -277,36 +264,26 @@ class NanGuardMode(Mode):
elif config.NanGuardMode.action == 'warn':
logger.error(msg)
def nan_check(i, node, fn):
def nan_check(node, thunk, storage_map, compute_map):
"""
Runs `fn` while checking its inputs and outputs for NaNs / Infs.
Parameters
----------
i :
Currently ignored.
TODO: determine why it is here or remove).
node : theano.gof.Apply
The Apply node currently being executed.
fn : callable
thunk : callable
The thunk to execute for this Apply node.
storage_map : dict
The mapping of variables to storage cells.
compute_map : dict
The mapping to get the computed state of the nodes. (ignored)
"""
inputs = fn.inputs
for x, var in zip(inputs, node.inputs):
# If the input is the result of computation, then we
# don't need to check it. It is already done after the
# computation.
if (var.owner is None and
getattr(var.tag, 'nan_guard_mode_check', True)):
do_check_on(x[0], node, fn, True)
fn()
outputs = fn.outputs
for x, var in zip(outputs, node.outputs):
for var in node.outputs:
if getattr(var.tag, 'nan_guard_mode_check', True):
do_check_on(x[0], node, fn, False)
do_check_on(storage_map[var][0], node)
wrap_linker = theano.gof.WrapLinker([theano.gof.OpWiseCLinker()],
nan_check)
wrap_linker = theano.gof.vm.VM_Linker(callback=nan_check)
super(NanGuardMode, self).__init__(wrap_linker,
optimizer=self.provided_optimizer)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论