提交 6df026b4 authored 作者: Pascal Lamblin's avatar Pascal Lamblin 提交者: GitHub

Merge pull request #5550 from nouiz/davidbau-nanguardfix

Fix NaNGuardMode to permit None variables
...@@ -23,6 +23,37 @@ except ImportError: ...@@ -23,6 +23,37 @@ except ImportError:
logger = logging.getLogger("theano.compile.nanguardmode") logger = logging.getLogger("theano.compile.nanguardmode")
def _is_numeric_value(arr, var):
"""
Checks a variable against non-numeric types such as types, slices,
empty arrays, and None, that need not be checked for NaN and Inf values.
Parameters
----------
arr : the data of that correspond to any Theano Variable
var : The corresponding Theano variable
Returns
-------
is_non_numeric : bool
`True` the value is non-numeric.
"""
if isinstance(arr, theano.gof.type._cdata_type):
return False
elif isinstance(arr, np.random.mtrand.RandomState):
return False
elif var and getattr(var.tag, 'is_rng', False):
return False
elif isinstance(arr, slice):
return False
elif arr is None:
return False
elif arr.size == 0:
return False
return True
def flatten(l): def flatten(l):
""" """
Turns a nested graph of lists/tuples/other objects into a list of objects. Turns a nested graph of lists/tuples/other objects into a list of objects.
...@@ -74,16 +105,7 @@ def contains_nan(arr, node=None, var=None): ...@@ -74,16 +105,7 @@ def contains_nan(arr, node=None, var=None):
construction of a boolean array with the same shape as the input array. construction of a boolean array with the same shape as the input array.
""" """
# This should be a whitelist instead of a blacklist if not _is_numeric_value(arr, var):
if isinstance(arr, theano.gof.type._cdata_type):
return False
elif isinstance(arr, np.random.mtrand.RandomState):
return False
elif var and getattr(var.tag, 'is_rng', False):
return False
elif isinstance(arr, slice):
return False
elif arr.size == 0:
return False return False
elif cuda.cuda_available and isinstance(arr, cuda.CudaNdarray): elif cuda.cuda_available and isinstance(arr, cuda.CudaNdarray):
if (node and hasattr(theano.sandbox, 'rng_mrg') and if (node and hasattr(theano.sandbox, 'rng_mrg') and
...@@ -126,15 +148,7 @@ def contains_inf(arr, node=None, var=None): ...@@ -126,15 +148,7 @@ def contains_inf(arr, node=None, var=None):
boolean array with the same shape as the input array. boolean array with the same shape as the input array.
""" """
if isinstance(arr, theano.gof.type._cdata_type): if not _is_numeric_value(arr, var):
return False
elif isinstance(arr, np.random.mtrand.RandomState):
return False
elif var and getattr(var.tag, 'is_rng', False):
return False
elif isinstance(arr, slice):
return False
elif arr.size == 0:
return False return False
elif cuda.cuda_available and isinstance(arr, cuda.CudaNdarray): elif cuda.cuda_available and isinstance(arr, cuda.CudaNdarray):
if (node and hasattr(theano.sandbox, 'rng_mrg') and if (node and hasattr(theano.sandbox, 'rng_mrg') and
...@@ -288,13 +302,7 @@ class NanGuardMode(Mode): ...@@ -288,13 +302,7 @@ class NanGuardMode(Mode):
error = True error = True
if big_is_error: if big_is_error:
err = False err = False
if isinstance(value, theano.gof.type._cdata_type): if not _is_numeric_value(value, var):
err = False
elif isinstance(value, np.random.mtrand.RandomState):
err = False
elif isinstance(value, slice):
err = False
elif value.size == 0:
err = False err = False
elif cuda.cuda_available and isinstance(value, cuda.CudaNdarray): elif cuda.cuda_available and isinstance(value, cuda.CudaNdarray):
compile_gpu_func(False, False, True) compile_gpu_func(False, False, True)
......
...@@ -56,7 +56,7 @@ def test_NanGuardMode(): ...@@ -56,7 +56,7 @@ def test_NanGuardMode():
np.asarray(1e20).astype(theano.config.floatX), (3, 4, 5)) np.asarray(1e20).astype(theano.config.floatX), (3, 4, 5))
x = T.tensor3() x = T.tensor3()
y = x[:, T.arange(2), T.arange(2)] y = x[:, T.arange(2), T.arange(2), None]
fun = theano.function( fun = theano.function(
[x], y, [x], y,
mode=NanGuardMode(nan_is_error=True, inf_is_error=True) mode=NanGuardMode(nan_is_error=True, inf_is_error=True)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论