提交 9ba871c1 authored 作者: Frederic's avatar Frederic

Assert that the opt removed the sum from the graph.

上级 9d99d07f
......@@ -3146,23 +3146,27 @@ class T_local_sum(unittest.TestCase):
f = theano.function([a],t_like(a).sum(d),mode=mode)
assert numpy.allclose(f(input),n_like(input).sum(d))
assert len(f.maker.env.nodes)==nb_nodes[1]
assert f.maker.env.toposort()[-1].op==T.alloc
topo = f.maker.env.toposort()
assert topo[-1].op == T.alloc
assert not any([isinstance(node.op, T.Sum) for node in topo])
for i in range(3):
f = theano.function([a],t_like(a).sum(i),mode=mode)
assert numpy.allclose(f(input),n_like(input).sum(i))
assert len(f.maker.env.nodes)==nb_nodes[2]
assert f.maker.env.toposort()[-1].op==T.alloc
topo = f.maker.env.toposort()
assert topo[-1].op == T.alloc
assert not any([isinstance(node.op, T.Sum) for node in topo])
backup = config.warn.sum_sum_bug
config.warn.sum_sum_bug = False
try:
for d, dd in [(0,0),(1,0),(2,0),(0,1),(1,1),(2,1)]:
f = theano.function([a],t_like(a).sum(d).sum(dd),mode=mode)
print f.maker.env.toposort()
assert numpy.allclose(f(input),n_like(input).sum(d).sum(dd))
assert len(f.maker.env.nodes)==nb_nodes[3]
assert f.maker.env.toposort()[-1].op==T.alloc
topo = f.maker.env.toposort()
assert topo[-1].op == T.alloc
assert not any([isinstance(node.op, T.Sum) for node in topo])
finally:
config.warn.sum_sum_bug = backup
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论