提交 7b3abe25 authored 作者: Frederic's avatar Frederic

Make monitormode detect_nan example not crash with random numbers

上级 9636481a
...@@ -343,9 +343,16 @@ can be achieved as follows: ...@@ -343,9 +343,16 @@ can be achieved as follows:
import theano import theano
# This is the current suggested detect_nan implementation to
# show you how it work. That way, you can modify it for your
# need. If you want exactly this method, you can use
# ``theano.compile.monitormode.detect_nan`` that will always
# contain the current suggested version.
def detect_nan(i, node, fn): def detect_nan(i, node, fn):
for output in fn.outputs: for output in fn.outputs:
if numpy.isnan(output[0]).any(): if (not isinstance(numpy.random.RandomState, output[0]) and
numpy.isnan(output[0]).any()):
print '*** NaN detected ***' print '*** NaN detected ***'
theano.printing.debugprint(node) theano.printing.debugprint(node)
print 'Inputs : %s' % [input[0] for input in fn.inputs] print 'Inputs : %s' % [input[0] for input in fn.inputs]
......
...@@ -80,3 +80,14 @@ class MonitorMode(Mode): ...@@ -80,3 +80,14 @@ class MonitorMode(Mode):
ret.pre_func = self.pre_func ret.pre_func = self.pre_func
ret.post_func = self.post_func ret.post_func = self.post_func
return ret return ret
def detect_nan(i, node, fn):
for output in fn.outputs:
if (not isinstance(numpy.random.RandomState, output[0]) and
numpy.isnan(output[0]).any()):
print '*** NaN detected ***'
theano.printing.debugprint(node)
print 'Inputs : %s' % [input[0] for input in fn.inputs]
print 'Outputs: %s' % [output[0] for output in fn.outputs]
break
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论