提交 0d698099 authored 作者: Rémi Louf's avatar Rémi Louf 提交者: Brandon T. Willard

Use unique input names in numba_funcify_Scan

上级 cd3a3ce2
...@@ -54,7 +54,7 @@ def numba_funcify_Scan(op, node, **kwargs): ...@@ -54,7 +54,7 @@ def numba_funcify_Scan(op, node, **kwargs):
p_outer_in_nit_sot = p_outer_in_shared + n_shared_outs p_outer_in_nit_sot = p_outer_in_shared + n_shared_outs
p_outer_in_non_seqs = p_outer_in_nit_sot + n_nit_sot p_outer_in_non_seqs = p_outer_in_nit_sot + n_nit_sot
input_names = [n.auto_name for n in node.inputs[1:]] input_names = [f"{n.auto_name}_{i}" for i, n in enumerate(node.inputs[1:])]
outer_in_seqs_names = input_names[:n_seqs] outer_in_seqs_names = input_names[:n_seqs]
outer_in_mit_mot_names = input_names[p_in_mit_mot : p_in_mit_mot + n_mit_mot] outer_in_mit_mot_names = input_names[p_in_mit_mot : p_in_mit_mot + n_mit_mot]
outer_in_mit_sot_names = input_names[p_in_mit_sot : p_in_mit_sot + n_mit_sot] outer_in_mit_sot_names = input_names[p_in_mit_sot : p_in_mit_sot + n_mit_sot]
......
...@@ -179,3 +179,22 @@ def test_scan_while(): ...@@ -179,3 +179,22 @@ def test_scan_while():
np.array(45).astype(config.floatX), np.array(45).astype(config.floatX),
] ]
compare_numba_and_py(out_fg, test_input_vals) compare_numba_and_py(out_fg, test_input_vals)
def test_scan_multiple_none_output():
A = at.dvector("A")
def power_step(prior_result, x):
return prior_result * x, prior_result * x * x, prior_result * x * x * x
result, _ = scan(
power_step,
non_sequences=[A],
outputs_info=[at.ones_like(A), None, None],
n_steps=3,
)
out_fg = FunctionGraph([A], result)
test_input_vals = (np.array([1.0, 2.0]),)
compare_numba_and_py(out_fg, test_input_vals)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论