提交 62cee002 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Test that JAX scan can handle simple dynamic sequences lengths

上级 82823dec
...@@ -427,3 +427,13 @@ def test_default_mode_excludes_incompatible_rewrites(): ...@@ -427,3 +427,13 @@ def test_default_mode_excludes_incompatible_rewrites():
out, _ = scan(lambda a, b: a @ b, outputs_info=[A], non_sequences=[B], n_steps=2) out, _ = scan(lambda a, b: a @ b, outputs_info=[A], non_sequences=[B], n_steps=2)
fg = FunctionGraph([A, B], [out]) fg = FunctionGraph([A, B], [out])
compare_jax_and_py(fg, [np.eye(3), np.eye(3)]) compare_jax_and_py(fg, [np.eye(3), np.eye(3)])
def test_dynamic_sequence_length():
x = pt.tensor("x", shape=(None,))
out, _ = scan(lambda x: x + 1, sequences=[x])
f = function([x], out, mode=get_mode("JAX").excluding("scan"))
assert sum(isinstance(node.op, Scan) for node in f.maker.fgraph.apply_nodes) == 1
np.testing.assert_allclose(f([]), [])
np.testing.assert_allclose(f([1, 2, 3]), np.array([2, 3, 4]))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论