提交 b543e6e7 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add a merge test for Scan nodes

上级 04c4d86d
......@@ -29,7 +29,9 @@ from aesara.compile.sharedvalue import shared
from aesara.configdefaults import config
from aesara.gradient import NullTypeGradError, Rop, disconnected_grad, grad, hessian
from aesara.graph.basic import Apply, ancestors
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.opt import MergeOptimizer
from aesara.graph.utils import MissingInputError
from aesara.misc.safe_asarray import _asarray
from aesara.raise_op import assert_op
......@@ -793,6 +795,27 @@ class TestScan:
assert scan1.owner.op == scan2.owner.op
assert hash(scan1.owner.op) == hash(scan2.owner.op)
def test_can_merge(self):
"""Make sure that equivalent `Scan` nodes can be merged."""
x = vector("x")
y = vector("y")
c = scalar("c")
scan_a, _ = scan(lambda x, y, c: x + y + c, sequences=[x, y], non_sequences=[c])
scan_b, _ = scan(lambda x, y, c: x + y + c, sequences=[x, y], non_sequences=[c])
scan_c, _ = scan(lambda x, y, c: x + y + c, sequences=[y, x], non_sequences=[c])
assert scan_b is not scan_a
assert scan_c is not scan_a
g = FunctionGraph([x, y, c], [2 * scan_a, 2 * scan_b, 2 * scan_c], clone=False)
MergeOptimizer().optimize(g)
scan_a_out, scan_b_out, scan_c_out = g.outputs
assert scan_a_out is scan_b_out
assert scan_c_out is not scan_a_out
def test_using_negative_taps_sequence(self):
# This test refers to a bug reported on github on May 22 2015 by
# user june-qijun
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论