提交 604709ed authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add checking for inputs in new NanGuardMode.

上级 530f442a
......@@ -71,7 +71,7 @@ def contains_nan(arr, node=None):
elif arr.size == 0:
return False
elif cuda.cuda_available and isinstance(arr, cuda.CudaNdarray):
if (hasattr(theano.sandbox, 'rng_mrg') and
if (node and hasattr(theano.sandbox, 'rng_mrg') and
isinstance(
node.op,
# It store ints in float container
......@@ -115,7 +115,7 @@ def contains_inf(arr, node=None):
elif arr.size == 0:
return False
elif cuda.cuda_available and isinstance(arr, cuda.CudaNdarray):
if (hasattr(theano.sandbox, 'rng_mrg') and
if (node and hasattr(theano.sandbox, 'rng_mrg') and
isinstance(
node.op,
# It store ints in float container
......@@ -251,10 +251,14 @@ class NanGuardMode(Mode):
print('Big value detected', file=sio)
error = True
if error:
print("NanGuardMode found an error in the"
" output of a node in this variable:", file=sio)
print(theano.printing.debugprint(nd, file='str'), file=sio)
msg = sio.getvalue()
if nd:
print("NanGuardMode found an error in the "
"output of a node in this variable:", file=sio)
print(theano.printing.debugprint(nd, file='str'), file=sio)
msg = sio.getvalue()
else:
print("NanGuardMode found an error in an input of the "
"graph.", file=sio)
if config.NanGuardMode.action == 'raise':
raise AssertionError(msg)
elif config.NanGuardMode.action == 'pdb':
......@@ -265,25 +269,15 @@ class NanGuardMode(Mode):
logger.error(msg)
def nan_check(node, thunk, storage_map, compute_map):
"""
Runs `fn` while checking its inputs and outputs for NaNs / Infs.
Parameters
----------
node : theano.gof.Apply
The Apply node currently being executed.
thunk : callable
The thunk to execute for this Apply node.
storage_map : dict
The mapping of variables to storage cells.
compute_map : dict
The mapping to get the computed state of the nodes. (ignored)
"""
for var in node.outputs:
if getattr(var.tag, 'nan_guard_mode_check', True):
do_check_on(storage_map[var][0], node)
wrap_linker = theano.gof.vm.VM_Linker(callback=nan_check)
def nan_check_input(var, value):
if getattr(var.tag, 'nan_guard_mode_check', True):
do_check_on(value, None)
wrap_linker = theano.gof.vm.VM_Linker(callback=nan_check,
callback_input=nan_check_input)
super(NanGuardMode, self).__init__(wrap_linker,
optimizer=self.provided_optimizer)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论