提交 5e7b095c authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add a push_out_non_seq_scan test for an OpFromGraph with a shared variable

上级 e65b0c51
......@@ -4,6 +4,7 @@ import pytest
import aesara
import aesara.tensor.basic as at
from aesara import function, scan, shared
from aesara.compile.builders import OpFromGraph
from aesara.compile.io import In
from aesara.compile.mode import get_default_mode
from aesara.configdefaults import config
......@@ -550,6 +551,28 @@ class TestPushOutNonSeqScan:
utt.assert_allclose(output_opt[0], output_no_opt[0])
utt.assert_allclose(output_opt[1], output_no_opt[1])
def test_OpFromGraph_shared(self):
"""Make sure that a simple `OpFromGraph` with a shared variable can be pushed out."""
y = shared(1.0, name="y")
test_ofg = OpFromGraph([], [1 + y])
def inner_func():
return test_ofg()
out, out_updates = aesara.scan(inner_func, n_steps=10)
out_fn = function([], out, updates=out_updates)
res = out_fn()
assert np.array_equal(res, np.repeat(2.0, 10))
y.set_value(2.0)
res = out_fn()
assert np.array_equal(res, np.repeat(3.0, 10))
class TestPushOutAddScan:
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论