提交 c6bfe4d4 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add test to make sure dots are pushed out.

上级 0ce29584
......@@ -3120,6 +3120,25 @@ class T_Scan(unittest.TestCase):
utt.assert_allclose(vnh0, tnh0, atol=1e-6)
utt.assert_allclose(vnW, tnW, atol=1e-6)
def test_pushout_dot(self):
W = tensor.matrix('W')
h = tensor.matrix('h')
o, _ = theano.scan(lambda hi, him1, W: (hi, tensor.dot(hi+him1, W)),
outputs_info=[tensor.zeros([h.shape[1]]), None],
sequences=[h],
non_sequences=[W])
f = theano.function([W, h], o, mode=mode_with_opt)
scan_nodes = [x for x in f.maker.fgraph.toposort()
if isinstance(x.op,
theano.scan_module.scan_op.Scan)]
assert len(scan_nodes) == 1
scan_op = scan_nodes[0].op
assert not any(isinstance(n.op, tensor.Dot) for n in
scan_op.fn.maker.fgraph.apply_nodes)
def test_pushout_all(self):
W1 = tensor.matrix('W1')
W2 = tensor.matrix('W2')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论