提交 f97b10df authored 作者: Frederic Bastien's avatar Frederic Bastien

Fix test crash and now NanGuardMode print the stack trace on inputs of the graph.

上级 444fa022
......@@ -215,44 +215,47 @@ class NanGuardMode(Mode):
assert nan_is_error or inf_is_error or big_is_error
compile_gpu_func(nan_is_error, inf_is_error, big_is_error)
def do_check_on(var, nd):
def do_check_on(value, nd, var=None):
"""
Checks `var` for NaNs / Infs. If detected, raises an exception
Checks `value` for NaNs / Infs. If detected, raises an exception
and / or prints information about `nd`, `f`, and `is_input` to
help the user determine the cause of the invalid values.
Parameters
----------
var : numpy.ndarray
value : numpy.ndarray
The value to be checked.
nd : theano.gof.Apply
The Apply node being executed.
var : theano.gof.Variable
Not used if nd is there. Otherwise, used to print the stack
trace for inputs of the graph.
"""
error = False
sio = StringIO()
if nan_is_error:
if contains_nan(var, nd):
if contains_nan(value, nd):
print('NaN detected', file=sio)
error = True
if inf_is_error:
if contains_inf(var, nd):
if contains_inf(value, nd):
print('Inf detected', file=sio)
error = True
if big_is_error:
err = False
if isinstance(var, theano.gof.type.CDataType._cdata_type):
if isinstance(value, theano.gof.type.CDataType._cdata_type):
err = False
elif isinstance(var, np.random.mtrand.RandomState):
elif isinstance(value, np.random.mtrand.RandomState):
err = False
elif isinstance(var, slice):
elif isinstance(value, slice):
err = False
elif var.size == 0:
elif value.size == 0:
err = False
elif cuda.cuda_available and isinstance(var, cuda.CudaNdarray):
err = (f_gpuabsmax(var.reshape(var.size)) > 1e10)
elif cuda.cuda_available and isinstance(value, cuda.CudaNdarray):
err = (f_gpuabsmax(value.reshape(value.size)) > 1e10)
else:
err = (np.abs(var).max() > 1e10)
err = (np.abs(value).max() > 1e10)
if err:
print('Big value detected', file=sio)
error = True
......@@ -265,8 +268,10 @@ class NanGuardMode(Mode):
print("NanGuardMode found an error in an input of the "
"graph.", file=sio)
# Add the stack trace
print(theano.gof.utils.get_variable_trace_string(
nd.outputs[0]), file=sio)
if nd:
var = nd.outputs[0]
print(theano.gof.utils.get_variable_trace_string(var),
file=sio)
msg = sio.getvalue()
if config.NanGuardMode.action == 'raise':
raise AssertionError(msg)
......@@ -284,7 +289,7 @@ class NanGuardMode(Mode):
def nan_check_input(var, value):
if getattr(var.tag, 'nan_guard_mode_check', True):
do_check_on(value, None)
do_check_on(value, None, var=var)
wrap_linker = theano.gof.vm.VM_Linker(callback=nan_check,
callback_input=nan_check_input)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论