提交 b543e6e7 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add a merge test for Scan nodes

上级 04c4d86d
...@@ -29,7 +29,9 @@ from aesara.compile.sharedvalue import shared ...@@ -29,7 +29,9 @@ from aesara.compile.sharedvalue import shared
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.gradient import NullTypeGradError, Rop, disconnected_grad, grad, hessian from aesara.gradient import NullTypeGradError, Rop, disconnected_grad, grad, hessian
from aesara.graph.basic import Apply, ancestors from aesara.graph.basic import Apply, ancestors
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.opt import MergeOptimizer
from aesara.graph.utils import MissingInputError from aesara.graph.utils import MissingInputError
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.raise_op import assert_op from aesara.raise_op import assert_op
...@@ -793,6 +795,27 @@ class TestScan: ...@@ -793,6 +795,27 @@ class TestScan:
assert scan1.owner.op == scan2.owner.op assert scan1.owner.op == scan2.owner.op
assert hash(scan1.owner.op) == hash(scan2.owner.op) assert hash(scan1.owner.op) == hash(scan2.owner.op)
def test_can_merge(self):
"""Make sure that equivalent `Scan` nodes can be merged."""
x = vector("x")
y = vector("y")
c = scalar("c")
scan_a, _ = scan(lambda x, y, c: x + y + c, sequences=[x, y], non_sequences=[c])
scan_b, _ = scan(lambda x, y, c: x + y + c, sequences=[x, y], non_sequences=[c])
scan_c, _ = scan(lambda x, y, c: x + y + c, sequences=[y, x], non_sequences=[c])
assert scan_b is not scan_a
assert scan_c is not scan_a
g = FunctionGraph([x, y, c], [2 * scan_a, 2 * scan_b, 2 * scan_c], clone=False)
MergeOptimizer().optimize(g)
scan_a_out, scan_b_out, scan_c_out = g.outputs
assert scan_a_out is scan_b_out
assert scan_c_out is not scan_a_out
def test_using_negative_taps_sequence(self): def test_using_negative_taps_sequence(self):
# This test refers to a bug reported on github on May 22 2015 by # This test refers to a bug reported on github on May 22 2015 by
# user june-qijun # user june-qijun
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论