提交 2865e546 authored 作者: Frederic's avatar Frederic

Update from code review

上级 cd442e56
from __future__ import print_function
import collections import collections
import logging import logging
from six.moves import StringIO
import numpy as np import numpy as np
import theano import theano
...@@ -25,7 +27,7 @@ AddConfigVar('NanGuardMode.big_is_error', ...@@ -25,7 +27,7 @@ AddConfigVar('NanGuardMode.big_is_error',
in_c_key=False) in_c_key=False)
AddConfigVar('NanGuardMode.action', AddConfigVar('NanGuardMode.action',
"What NanGuardMode do when it find a problem", "What NanGuardMode does when it finds a problem",
EnumStr('raise', 'warn', 'pdb'), EnumStr('raise', 'warn', 'pdb'),
in_c_key=False) in_c_key=False)
...@@ -60,14 +62,15 @@ def flatten(l): ...@@ -60,14 +62,15 @@ def flatten(l):
return rval return rval
def contains_nan(arr, nd=None): def contains_nan(arr, node=None):
""" """
Test whether a numpy.ndarray contains any `np.nan` values. Test whether a numpy.ndarray contains any `np.nan` values.
Parameters Parameters
---------- ----------
arr : np.ndarray or output of any Theano op arr : np.ndarray or output of any Theano op
nd : if the output of an Theano op, the node associated to it node : None or an Apply instance.
If arr is the output of a Theano op, the node associated to it.
Returns Returns
------- -------
...@@ -91,7 +94,7 @@ def contains_nan(arr, nd=None): ...@@ -91,7 +94,7 @@ def contains_nan(arr, nd=None):
elif cuda.cuda_available and isinstance(arr, cuda.CudaNdarray): elif cuda.cuda_available and isinstance(arr, cuda.CudaNdarray):
if (hasattr(theano.sandbox, 'rng_mrg') and if (hasattr(theano.sandbox, 'rng_mrg') and
isinstance( isinstance(
nd.op, node.op,
# It store ints in float container # It store ints in float container
theano.sandbox.rng_mrg.GPU_mrg_uniform)): theano.sandbox.rng_mrg.GPU_mrg_uniform)):
return False return False
...@@ -102,14 +105,16 @@ def contains_nan(arr, nd=None): ...@@ -102,14 +105,16 @@ def contains_nan(arr, nd=None):
return np.isnan(np.min(arr)) return np.isnan(np.min(arr))
def contains_inf(arr, nd=None): def contains_inf(arr, node=None):
""" """
Test whether a numpy.ndarray contains any `np.inf` values. Test whether a numpy.ndarray contains any `np.inf` values.
Parameters Parameters
---------- ----------
arr : np.ndarray or output of any Theano op arr : np.ndarray or output of any Theano op
nd : if the output of an Theano op, the node associated to it node : None or an Apply instance.
If the output of a Theano op, the node associated to it.
Returns Returns
------- -------
contains_inf : bool contains_inf : bool
...@@ -133,7 +138,7 @@ def contains_inf(arr, nd=None): ...@@ -133,7 +138,7 @@ def contains_inf(arr, nd=None):
elif cuda.cuda_available and isinstance(arr, cuda.CudaNdarray): elif cuda.cuda_available and isinstance(arr, cuda.CudaNdarray):
if (hasattr(theano.sandbox, 'rng_mrg') and if (hasattr(theano.sandbox, 'rng_mrg') and
isinstance( isinstance(
nd.op, node.op,
# It store ints in float container # It store ints in float container
theano.sandbox.rng_mrg.GPU_mrg_uniform)): theano.sandbox.rng_mrg.GPU_mrg_uniform)):
return False return False
...@@ -247,13 +252,14 @@ class NanGuardMode(Mode): ...@@ -247,13 +252,14 @@ class NanGuardMode(Mode):
""" """
error = False error = False
sio = StringIO()
if nan_is_error: if nan_is_error:
if contains_nan(var, nd): if contains_nan(var, nd):
logger.error('NaN detected') print('NaN detected', file=sio)
error = True error = True
if inf_is_error: if inf_is_error:
if contains_inf(var, nd): if contains_inf(var, nd):
logger.error('Inf detected') print('Inf detected', file=sio)
error = True error = True
if big_is_error: if big_is_error:
err = False err = False
...@@ -268,27 +274,29 @@ class NanGuardMode(Mode): ...@@ -268,27 +274,29 @@ class NanGuardMode(Mode):
else: else:
err = (np.abs(var).max() > 1e10) err = (np.abs(var).max() > 1e10)
if err: if err:
logger.error('Big value detected') print('Big value detected', file=sio)
error = True error = True
if error: if error:
if not is_input: if not is_input:
logger.error("NanGuardMode found an error in the" print("NanGuardMode found an error in the"
" output of a node in this variable:") " output of a node in this variable:", file=sio)
logger.error(theano.printing.debugprint(nd, file='str')) print(theano.printing.debugprint(nd, file='str'), file=sio)
else: else:
logger.error("NanGuardMode found an error in the" print("NanGuardMode found an error in the"
" input %d of this node.") " input %d of this node.", file=sio)
logger.error('Node:') print('Node:', file=sio)
logger.error(nd) print(nd, file=sio)
logger.error("The input variable that cause problem:") print("The input variable that cause problem:", file=sio)
logger.error(theano.printing.debugprint(nd, file='str')) print(theano.printing.debugprint(nd, file='str'), file=sio)
msg = sio.getvalue()
if config.NanGuardMode.action == 'raise': if config.NanGuardMode.action == 'raise':
assert False raise AssertionError(msg)
elif config.NanGuardMode.action == 'pdb': elif config.NanGuardMode.action == 'pdb':
print(msg)
import pdb import pdb
pdb.set_trace() pdb.set_trace()
elif config.NanGuardMode.action == 'warn': elif config.NanGuardMode.action == 'warn':
pass # already printed logger.error(msg)
def nan_check(i, node, fn): def nan_check(i, node, fn):
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论