提交 34ecd3c8 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

test the fix for the merge optimization

上级 48d21c36
......@@ -3463,6 +3463,27 @@ class T_Scan(unittest.TestCase):
)
tensor.grad(out[-1], w)
def test_scan_merge_nodes(self):
inps = tensor.vector()
state = tensor.scalar()
y1, _ = theano.scan(lambda x,y: x*y,
sequences = inps,
outputs_info = state,
n_steps = 5)
y2, _ = theano.scan(lambda x,y : (x+y, theano.scan_module.until(x>0)),
sequences = inps,
outputs_info = state,
n_steps = 5)
scan_node1 = y1.owner.inputs[0].owner
assert isinstance(scan_node1.op, theano.scan_module.scan_op.Scan)
scan_node2 = y2.owner.inputs[0].owner
assert isinstance(scan_node2.op, theano.scan_module.scan_op.Scan)
opt_obj = theano.scan_module.scan_opt.ScanMerge()
# Test the method belongs_to of this class. Specifically see if it
# detects the two scan_nodes as not being similar
assert not opt_obj.belongs_to_set(scan_node1, [scan_node2])
assert not opt_obj.belongs_to_set(scan_node2, [scan_node1])
def test_speed():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论