提交 777caa57 authored 作者: Frederic's avatar Frederic

Small fix and refactoring

上级 e4404a6f
......@@ -55,13 +55,14 @@ def flatten(l):
return rval
def contains_nan(arr):
def contains_nan(arr, nd=None):
"""
Test whether a numpy.ndarray contains any `np.nan` values.
Parameters
----------
arr : np.ndarray
arr : np.ndarray or output of any Theano op
nd : if the output of an Theano op, the node associated to it
Returns
-------
......@@ -83,25 +84,27 @@ def contains_nan(arr):
elif arr.size == 0:
return False
elif cuda.cuda_available and isinstance(arr, cuda.CudaNdarray):
compile_gpu_func(True, False, False)
if (not hasattr(theano.sandbox, 'rng_mrg') or
not isinstance(
if (hasattr(theano.sandbox, 'rng_mrg') and
isinstance(
nd.op,
# It store ints in float container
theano.sandbox.rng_mrg.GPU_mrg_uniform)):
return False
else:
compile_gpu_func(True, False, False)
return np.isnan(f_gpumin(arr.reshape(arr.size)))
return np.isnan(np.min(arr))
def contains_inf(arr):
def contains_inf(arr, nd=None):
"""
Test whether a numpy.ndarray contains any `np.inf` values.
Parameters
----------
arr : np.ndarray
arr : np.ndarray or output of any Theano op
nd : if the output of an Theano op, the node associated to it
Returns
-------
contains_inf : bool
......@@ -123,9 +126,16 @@ def contains_inf(arr):
elif arr.size == 0:
return False
elif cuda.cuda_available and isinstance(arr, cuda.CudaNdarray):
compile_gpu_func(False, True, False)
return (np.isinf(f_gpumin(arr.reshape(arr.size))) or
np.isinf(f_gpumax(arr.reshape(arr.size))))
if (hasattr(theano.sandbox, 'rng_mrg') and
isinstance(
nd.op,
# It store ints in float container
theano.sandbox.rng_mrg.GPU_mrg_uniform)):
return False
else:
compile_gpu_func(False, True, False)
return (np.isinf(f_gpumin(arr.reshape(arr.size))) or
np.isinf(f_gpumax(arr.reshape(arr.size))))
return np.isinf(np.nanmax(arr)) or np.isinf(np.nanmin(arr))
......@@ -233,11 +243,11 @@ class NanGuardMode(Mode):
"""
error = False
if nan_is_error:
if contains_nan(var):
if contains_nan(var, nd):
logger.error('NaN detected')
error = True
if inf_is_error:
if contains_inf(var):
if contains_inf(var, nd):
logger.error('Inf detected')
error = True
if big_is_error:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论