提交 96366f99 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add a Scan as a Scan sequence input test

上级 b543e6e7
......@@ -1110,6 +1110,23 @@ class TestScan:
utt.assert_allclose(out, vR)
@pytest.mark.parametrize(
"mode", [Mode(linker="cvm", optimizer=None), Mode(linker="cvm")]
)
def test_sequence_is_scan(self, mode):
"""Make sure that a `Scan` can be used as a sequence input to another `Scan`."""
x0 = scalar("x0")
scan_1, _ = scan(lambda x: x + 1, outputs_info={"initial": x0}, n_steps=10)
scan_2, _ = scan(lambda x: x + 1, sequences=[scan_1])
with config.change_flags(mode=mode):
scan_2_fn = function([x0], scan_2)
scan_2_val = scan_2_fn(0.0)
exp_res = np.arange(1, 11) + 1.0
assert np.array_equal(scan_2_val, exp_res)
def test_grad_sitsot(self):
def get_sum_of_grad(inp):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论