提交 0a2d575e authored 作者: nouiz's avatar nouiz

Merge pull request #1120 from pascanur/bug_seqOptimizer

fix Scan bug reported by abalkin
...@@ -837,9 +837,17 @@ class ScanSaveMem(gof.Optimizer): ...@@ -837,9 +837,17 @@ class ScanSaveMem(gof.Optimizer):
if not (idx in replaced_outs) and not idx in not_required: if not (idx in replaced_outs) and not idx in not_required:
nw_pos = compress_map[idx] nw_pos = compress_map[idx]
old_new += [(o, new_outs[nw_pos])] old_new += [(o, new_outs[nw_pos])]
# Check if the new outputs depend on the old scan node
old_scan_is_used = [scan_utils.find_up(new.owner, node)
for old, new in old_new]
if any(old_scan_is_used):
return False
remove = [old.owner for (old, new) in old_new]
# As Fred suggested assert that also the old node is not in
# the Graph as that will make things suboptimal
remove.append(node)
fgraph.replace_all_validate_remove(old_new, fgraph.replace_all_validate_remove(old_new,
remove=[node], remove,
reason='scan_save_mem') reason='scan_save_mem')
def apply(self, fgraph): def apply(self, fgraph):
......
...@@ -3292,6 +3292,18 @@ class T_Scan(unittest.TestCase): ...@@ -3292,6 +3292,18 @@ class T_Scan(unittest.TestCase):
cost = x.sum() cost = x.sum()
self.assertRaises(ValueError, tensor.grad, cost, y0) self.assertRaises(ValueError, tensor.grad, cost, y0)
def test_savemem_does_not_duplicate_number_of_scan_nodes(self):
    """Regression test for the scan_save_mem optimization.

    After compiling a scan with an ``until`` condition, the optimized
    graph must contain exactly one Scan apply node — the optimization
    must not leave a duplicate of the old scan node behind.
    """
    init = tensor.ones(())

    def step(x):
        # One recurrent output, no updates, and stop as soon as `x`
        # evaluates truthy (the `until` termination condition).
        return [x], (), theano.scan_module.until(x)

    outputs, _ = theano.scan(step, outputs_info=[init], n_steps=2)
    fn = theano.function([init], outputs)

    # Count Scan nodes in the compiled (optimized) graph.
    n_scans = sum(1 for node in fn.maker.fgraph.toposort()
                  if isinstance(node.op,
                                theano.scan_module.scan_op.Scan))
    assert n_scans == 1
def test_speed(): def test_speed():
# #
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论