提交 4b214c13 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

test that scan_savemem is not applied unless it removes the old scan

上级 55a8452a
......@@ -3292,6 +3292,18 @@ class T_Scan(unittest.TestCase):
cost = x.sum()
self.assertRaises(ValueError, tensor.grad, cost, y0)
def test_savemem_does_not_duplicate_number_of_scan_nodes(self):
var = tensor.ones(())
values, _ = theano.scan(lambda x: ([x], (), theano.scan_module.until(x)),
outputs_info=[var], n_steps=2)
tmp_fn = theano.function([var], values)
scan_nodes = [x for x in tmp_fn.maker.fgraph.toposort()
if isinstance(x.op,
theano.scan_module.scan_op.Scan)]
assert len(scan_nodes) == 1
def test_speed():
#
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论