提交 57d0f256 authored 作者: Frederic's avatar Frederic

Allow to include/exclude/require optimizers for the MonitorMode.

上级 c9990cc4
...@@ -62,3 +62,21 @@ class MonitorMode(Mode): ...@@ -62,3 +62,21 @@ class MonitorMode(Mode):
fn() fn()
if self.post_func is not None: if self.post_func is not None:
self.post_func(i, node, fn) self.post_func(i, node, fn)
def including(self, *tags):
ret = super(MonitorMode, self).including(*tags)
ret.pre_func = self.pre_func
ret.post_func = self.post_func
return ret
def excluding(self, *tags):
ret = super(MonitorMode, self).excluding(*tags)
ret.pre_func = self.pre_func
ret.post_func = self.post_func
return ret
def requiring(self, *tags):
ret = super(MonitorMode, self).requiring(*tags)
ret.pre_func = self.pre_func
ret.post_func = self.post_func
return ret
...@@ -25,3 +25,32 @@ def test_detect_nan(): ...@@ -25,3 +25,32 @@ def test_detect_nan():
post_func=detect_nan)) post_func=detect_nan))
f(0) # log(0) * 0 = -inf * 0 = NaN f(0) # log(0) * 0 = -inf * 0 = NaN
assert nan_detected[0] assert nan_detected[0]
def test_optimizers():
"""
Test that we can remove optimizers
"""
nan_detected = [False]
def detect_nan(i, node, fn):
for output in fn.outputs:
if numpy.isnan(output[0]).any():
print '*** NaN detected ***'
theano.printing.debugprint(node)
print 'Inputs : %s' % [input[0] for input in fn.inputs]
print 'Outputs: %s' % [output[0] for output in fn.outputs]
nan_detected[0] = True
break
x = theano.tensor.dscalar('x')
mode = theano.compile.MonitorMode(post_func=detect_nan)
mode = mode.excluding('fusion')
f = theano.function([x], [theano.tensor.log(x) * x],
mode=mode)
# Test that the fusion wasn't done
assert len(f.maker.fgraph.nodes) == 2
f(0) # log(0) * 0 = -inf * 0 = NaN
# Test that we still detect the nan
assert nan_detected[0]
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论