提交 8eaa12f2 authored 作者: AdeB's avatar AdeB

log_sum_exp: add FAST_COMPILE case to test1 where a MaxAndArgmax op is…

log_sum_exp: add FAST_COMPILE case to test1 where a MaxAndArgmax op is introduced by the optimisation.
上级 39a5efab
@@ -6663,9 +6663,17 @@ def test_local_log_sum_exp1():
     MODE = theano.compile.get_default_mode().including('local_log_sum_exp')
     f = function([x], y, mode=MODE)
-    assert (theano.scalar.basic.maximum
-            in [node.op.scalar_op for node in f.maker.fgraph.toposort()
-                if (node.op and hasattr(node.op, 'scalar_op'))])
+    for node in f.maker.fgraph.toposort():
+        if (hasattr(node.op, 'scalar_op') and
+                node.op.scalar_op == theano.scalar.basic.maximum):
+            return
+        # in mode FAST_COMPILE, the optimisations don't replace the
+        # MaxAndArgmax op.
+        if isinstance(node.op, theano.tensor.MaxAndArgmax):
+            return
+    raise Exception('No maximum detected after log_sum_exp optimisation')


 def test_local_log_sum_exp2():
 ...
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论