提交 5e7b095c authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add a push_out_non_seq_scan test for an OpFromGraph with a shared variable

上级 e65b0c51
...@@ -4,6 +4,7 @@ import pytest ...@@ -4,6 +4,7 @@ import pytest
import aesara import aesara
import aesara.tensor.basic as at import aesara.tensor.basic as at
from aesara import function, scan, shared from aesara import function, scan, shared
from aesara.compile.builders import OpFromGraph
from aesara.compile.io import In from aesara.compile.io import In
from aesara.compile.mode import get_default_mode from aesara.compile.mode import get_default_mode
from aesara.configdefaults import config from aesara.configdefaults import config
...@@ -550,6 +551,28 @@ class TestPushOutNonSeqScan: ...@@ -550,6 +551,28 @@ class TestPushOutNonSeqScan:
utt.assert_allclose(output_opt[0], output_no_opt[0]) utt.assert_allclose(output_opt[0], output_no_opt[0])
utt.assert_allclose(output_opt[1], output_no_opt[1]) utt.assert_allclose(output_opt[1], output_no_opt[1])
def test_OpFromGraph_shared(self):
"""Make sure that a simple `OpFromGraph` with a shared variable can be pushed out."""
y = shared(1.0, name="y")
test_ofg = OpFromGraph([], [1 + y])
def inner_func():
return test_ofg()
out, out_updates = aesara.scan(inner_func, n_steps=10)
out_fn = function([], out, updates=out_updates)
res = out_fn()
assert np.array_equal(res, np.repeat(2.0, 10))
y.set_value(2.0)
res = out_fn()
assert np.array_equal(res, np.repeat(3.0, 10))
class TestPushOutAddScan: class TestPushOutAddScan:
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论