提交 2865e546 authored 作者: Frederic's avatar Frederic

Update from code review

上级 cd442e56
from __future__ import print_function
import collections
import logging
from six.moves import StringIO
import numpy as np
import theano
......@@ -25,7 +27,7 @@ AddConfigVar('NanGuardMode.big_is_error',
in_c_key=False)
AddConfigVar('NanGuardMode.action',
"What NanGuardMode do when it find a problem",
"What NanGuardMode does when it finds a problem",
EnumStr('raise', 'warn', 'pdb'),
in_c_key=False)
......@@ -60,14 +62,15 @@ def flatten(l):
return rval
def contains_nan(arr, nd=None):
def contains_nan(arr, node=None):
"""
Test whether a numpy.ndarray contains any `np.nan` values.
Parameters
----------
arr : np.ndarray or output of any Theano op
nd : if the output of an Theano op, the node associated to it
node : None or an Apply instance.
If arr is the output of a Theano op, the node associated to it.
Returns
-------
......@@ -91,7 +94,7 @@ def contains_nan(arr, nd=None):
elif cuda.cuda_available and isinstance(arr, cuda.CudaNdarray):
if (hasattr(theano.sandbox, 'rng_mrg') and
isinstance(
nd.op,
node.op,
# It store ints in float container
theano.sandbox.rng_mrg.GPU_mrg_uniform)):
return False
......@@ -102,14 +105,16 @@ def contains_nan(arr, nd=None):
return np.isnan(np.min(arr))
def contains_inf(arr, nd=None):
def contains_inf(arr, node=None):
"""
Test whether a numpy.ndarray contains any `np.inf` values.
Parameters
----------
arr : np.ndarray or output of any Theano op
nd : if the output of an Theano op, the node associated to it
node : None or an Apply instance.
If the output of a Theano op, the node associated to it.
Returns
-------
contains_inf : bool
......@@ -133,7 +138,7 @@ def contains_inf(arr, nd=None):
elif cuda.cuda_available and isinstance(arr, cuda.CudaNdarray):
if (hasattr(theano.sandbox, 'rng_mrg') and
isinstance(
nd.op,
node.op,
# It store ints in float container
theano.sandbox.rng_mrg.GPU_mrg_uniform)):
return False
......@@ -247,13 +252,14 @@ class NanGuardMode(Mode):
"""
error = False
sio = StringIO()
if nan_is_error:
if contains_nan(var, nd):
logger.error('NaN detected')
print('NaN detected', file=sio)
error = True
if inf_is_error:
if contains_inf(var, nd):
logger.error('Inf detected')
print('Inf detected', file=sio)
error = True
if big_is_error:
err = False
......@@ -268,27 +274,29 @@ class NanGuardMode(Mode):
else:
err = (np.abs(var).max() > 1e10)
if err:
logger.error('Big value detected')
print('Big value detected', file=sio)
error = True
if error:
if not is_input:
logger.error("NanGuardMode found an error in the"
" output of a node in this variable:")
logger.error(theano.printing.debugprint(nd, file='str'))
print("NanGuardMode found an error in the"
" output of a node in this variable:", file=sio)
print(theano.printing.debugprint(nd, file='str'), file=sio)
else:
logger.error("NanGuardMode found an error in the"
" input %d of this node.")
logger.error('Node:')
logger.error(nd)
logger.error("The input variable that cause problem:")
logger.error(theano.printing.debugprint(nd, file='str'))
print("NanGuardMode found an error in the"
" input %d of this node.", file=sio)
print('Node:', file=sio)
print(nd, file=sio)
print("The input variable that cause problem:", file=sio)
print(theano.printing.debugprint(nd, file='str'), file=sio)
msg = sio.getvalue()
if config.NanGuardMode.action == 'raise':
assert False
raise AssertionError(msg)
elif config.NanGuardMode.action == 'pdb':
print(msg)
import pdb
pdb.set_trace()
elif config.NanGuardMode.action == 'warn':
pass # already printed
logger.error(msg)
def nan_check(i, node, fn):
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论