提交 97f9ad48 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename tests.scan.test_basic.test_pushout_dot and add a test description

上级 d330a6fc
...@@ -3014,18 +3014,29 @@ class TestScan: ...@@ -3014,18 +3014,29 @@ class TestScan:
utt.assert_allclose(vnh0, tnh0, atol=1e-6) utt.assert_allclose(vnh0, tnh0, atol=1e-6)
utt.assert_allclose(vnW, tnW, atol=1e-6) utt.assert_allclose(vnW, tnW, atol=1e-6)
def test_pushout_dot(self): def test_inner_replace_dot(self):
"""
This tests that rewrites are applied to the inner-graph.
In particular, BLAS-based rewrites that remove the original dot product.
This was previously a test with a name that implied it was testing the
`Scan` push-out rewrites, but it wasn't testing that at all, because the
rewrites were never being applied.
"""
W = matrix("W") W = matrix("W")
h = matrix("h") h = matrix("h")
mode = mode_with_opt # .excluding("BlasOpt")
o, _ = scan( o, _ = scan(
lambda hi, him1, W: (hi, dot(hi + him1, W)), lambda hi, him1, W: (hi, dot(hi + him1, W)),
outputs_info=[at.zeros([h.shape[1]]), None], outputs_info=[at.zeros([h.shape[1]]), None],
sequences=[h], sequences=[h],
non_sequences=[W], non_sequences=[W],
mode=mode,
) )
f = function([W, h], o, mode=mode_with_opt) f = function([W, h], o, mode=mode)
scan_nodes = [x for x in f.maker.fgraph.toposort() if isinstance(x.op, Scan)] scan_nodes = [x for x in f.maker.fgraph.toposort() if isinstance(x.op, Scan)]
assert len(scan_nodes) == 1 assert len(scan_nodes) == 1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论