提交 038fb754 authored 作者: Frederic's avatar Frederic

Better NanGuardMode output and add config.NanGuardMode.action={raise,pdb,warn}

上级 777caa57
...@@ -4,7 +4,7 @@ import logging ...@@ -4,7 +4,7 @@ import logging
import numpy as np import numpy as np
import theano import theano
from theano.configparser import config, AddConfigVar, BoolParam from theano.configparser import config, AddConfigVar, BoolParam, EnumStr
import theano.tensor as T import theano.tensor as T
import theano.sandbox.cuda as cuda import theano.sandbox.cuda as cuda
from theano.compile import Mode from theano.compile import Mode
...@@ -24,6 +24,11 @@ AddConfigVar('NanGuardMode.big_is_error', ...@@ -24,6 +24,11 @@ AddConfigVar('NanGuardMode.big_is_error',
BoolParam(True), BoolParam(True),
in_c_key=False) in_c_key=False)
AddConfigVar('NanGuardMode.action',
"What NanGuardMode do when it find a problem",
EnumStr('raise', 'warn', 'pdb'),
in_c_key=False)
logger = logging.getLogger("theano.compile.nanguardmode") logger = logging.getLogger("theano.compile.nanguardmode")
...@@ -266,20 +271,23 @@ class NanGuardMode(Mode): ...@@ -266,20 +271,23 @@ class NanGuardMode(Mode):
logger.error('Big value detected') logger.error('Big value detected')
error = True error = True
if error: if error:
if is_input: if not is_input:
logger.error('In an input') logger.error("NanGuardMode found an error in the"
" output of a node in this variable:")
logger.error(theano.printing.debugprint(nd, file='str'))
else: else:
logger.error('In an output') logger.error("NanGuardMode found an error in the"
logger.error('Inputs: ') " input %d of this node.")
for ivar, ival in zip(nd.inputs, f.inputs): logger.error('Node:')
logger.error('var') logger.error(nd)
logger.error(ivar) logger.error("The input variable that cause problem:")
logger.error(theano.printing.min_informative_str(ivar)) logger.error(theano.printing.debugprint(nd, file='str'))
logger.error('val') if config.NanGuardMode.action == 'raise':
logger.error(ival) assert False
logger.error('Node:') elif config.NanGuardMode.action == 'pdb':
logger.error(nd) import pdb;pdb.set_trace()
assert False elif config.NanGuardMode.action == 'warn':
pass # already printed
def nan_check(i, node, fn): def nan_check(i, node, fn):
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论