提交 ff0ac3ca authored 作者: Frederic's avatar Frederic

Add the opt tag name inplace_elemwise_optimizer to inplace_opt as this

is the variable name used. I lost 1h before undersanding it isn't the name used in the optdb. Also test that we can disable inplace and fusion opt with MonitorMode.
上级 554b4b59
...@@ -27,9 +27,9 @@ def test_detect_nan(): ...@@ -27,9 +27,9 @@ def test_detect_nan():
assert nan_detected[0] assert nan_detected[0]
def test_optimizers(): def test_optimizer():
""" """
Test that we can remove optimizers Test that we can remove optimizer
""" """
nan_detected = [False] nan_detected = [False]
...@@ -54,3 +54,38 @@ def test_optimizers(): ...@@ -54,3 +54,38 @@ def test_optimizers():
# Test that we still detect the nan # Test that we still detect the nan
assert nan_detected[0] assert nan_detected[0]
def test_not_inplace():
"""
Test that we can remove optimizers including inplace optimizers
"""
nan_detected = [False]
def detect_nan(i, node, fn):
for output in fn.outputs:
if numpy.isnan(output[0]).any():
print '*** NaN detected ***'
theano.printing.debugprint(node)
print 'Inputs : %s' % [input[0] for input in fn.inputs]
print 'Outputs: %s' % [output[0] for output in fn.outputs]
nan_detected[0] = True
break
x = theano.tensor.vector('x')
mode = theano.compile.MonitorMode(post_func=detect_nan)
#mode = mode.excluding('fusion', 'inplace')
mode = mode.excluding('local_elemwise_fusion',
'inplace_elemwise_optimizer')
o = theano.tensor.outer(x, x)
out = theano.tensor.log(o) * o
f = theano.function([x], [out],
mode=mode)
# Test that the fusion wasn't done
assert len(f.maker.fgraph.nodes) == 5
assert not f.maker.fgraph.toposort()[-1].op.destroy_map
f([0, 0]) # log(0) * 0 = -inf * 0 = NaN
# Test that we still detect the nan
assert nan_detected[0]
...@@ -273,8 +273,8 @@ def inplace_elemwise_optimizer_op(OP): ...@@ -273,8 +273,8 @@ def inplace_elemwise_optimizer_op(OP):
return inplace_elemwise_optimizer return inplace_elemwise_optimizer
inplace_elemwise_optimizer = inplace_elemwise_optimizer_op(T.Elemwise) inplace_elemwise_optimizer = inplace_elemwise_optimizer_op(T.Elemwise)
compile.optdb.register('inplace_opt', inplace_elemwise_optimizer, 75, compile.optdb.register('inplace_opt', inplace_elemwise_optimizer, 75,
'inplace_elemwise_optimizer',
'fast_run', 'inplace') 'fast_run', 'inplace')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论