提交 3220e103 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #4808 from nouiz/nanguardmode

Add the user stack trace printing in nanguardmode
...@@ -215,44 +215,47 @@ class NanGuardMode(Mode): ...@@ -215,44 +215,47 @@ class NanGuardMode(Mode):
assert nan_is_error or inf_is_error or big_is_error assert nan_is_error or inf_is_error or big_is_error
compile_gpu_func(nan_is_error, inf_is_error, big_is_error) compile_gpu_func(nan_is_error, inf_is_error, big_is_error)
def do_check_on(var, nd): def do_check_on(value, nd, var=None):
""" """
Checks `var` for NaNs / Infs. If detected, raises an exception Checks `value` for NaNs / Infs. If detected, raises an exception
and / or prints information about `nd`, `f`, and `is_input` to and / or prints information about `nd`, `f`, and `is_input` to
help the user determine the cause of the invalid values. help the user determine the cause of the invalid values.
Parameters Parameters
---------- ----------
var : numpy.ndarray value : numpy.ndarray
The value to be checked. The value to be checked.
nd : theano.gof.Apply nd : theano.gof.Apply
The Apply node being executed. The Apply node being executed.
var : theano.gof.Variable
Not used if nd is there. Otherwise, used to print the stack
trace for inputs of the graph.
""" """
error = False error = False
sio = StringIO() sio = StringIO()
if nan_is_error: if nan_is_error:
if contains_nan(var, nd): if contains_nan(value, nd):
print('NaN detected', file=sio) print('NaN detected', file=sio)
error = True error = True
if inf_is_error: if inf_is_error:
if contains_inf(var, nd): if contains_inf(value, nd):
print('Inf detected', file=sio) print('Inf detected', file=sio)
error = True error = True
if big_is_error: if big_is_error:
err = False err = False
if isinstance(var, theano.gof.type.CDataType._cdata_type): if isinstance(value, theano.gof.type.CDataType._cdata_type):
err = False err = False
elif isinstance(var, np.random.mtrand.RandomState): elif isinstance(value, np.random.mtrand.RandomState):
err = False err = False
elif isinstance(var, slice): elif isinstance(value, slice):
err = False err = False
elif var.size == 0: elif value.size == 0:
err = False err = False
elif cuda.cuda_available and isinstance(var, cuda.CudaNdarray): elif cuda.cuda_available and isinstance(value, cuda.CudaNdarray):
err = (f_gpuabsmax(var.reshape(var.size)) > 1e10) err = (f_gpuabsmax(value.reshape(value.size)) > 1e10)
else: else:
err = (np.abs(var).max() > 1e10) err = (np.abs(value).max() > 1e10)
if err: if err:
print('Big value detected', file=sio) print('Big value detected', file=sio)
error = True error = True
...@@ -264,6 +267,11 @@ class NanGuardMode(Mode): ...@@ -264,6 +267,11 @@ class NanGuardMode(Mode):
else: else:
print("NanGuardMode found an error in an input of the " print("NanGuardMode found an error in an input of the "
"graph.", file=sio) "graph.", file=sio)
# Add the stack trace
if nd:
var = nd.outputs[0]
print(theano.gof.utils.get_variable_trace_string(var),
file=sio)
msg = sio.getvalue() msg = sio.getvalue()
if config.NanGuardMode.action == 'raise': if config.NanGuardMode.action == 'raise':
raise AssertionError(msg) raise AssertionError(msg)
...@@ -281,7 +289,7 @@ class NanGuardMode(Mode): ...@@ -281,7 +289,7 @@ class NanGuardMode(Mode):
def nan_check_input(var, value): def nan_check_input(var, value):
if getattr(var.tag, 'nan_guard_mode_check', True): if getattr(var.tag, 'nan_guard_mode_check', True):
do_check_on(value, None) do_check_on(value, None, var=var)
wrap_linker = theano.gof.vm.VM_Linker(callback=nan_check, wrap_linker = theano.gof.vm.VM_Linker(callback=nan_check,
callback_input=nan_check_input) callback_input=nan_check_input)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论