提交 39f29083 authored 作者: kc611's avatar kc611 提交者: Brandon T. Willard

Fixed Numba Scan failures on multidimensional nit sot outputs

上级 ac14c1dd
...@@ -106,8 +106,11 @@ def numba_funcify_Scan(op, node, **kwargs): ...@@ -106,8 +106,11 @@ def numba_funcify_Scan(op, node, **kwargs):
# In case of nit-sots we are provided shape of the array # In case of nit-sots we are provided shape of the array
# instead of actual arrays like other cases, hence we # instead of actual arrays like other cases, hence we
# allocate space for the results accordingly. # allocate space for the results accordingly.
curr_nit_sot_position = input_names.index(_name) - n_seqs
curr_nit_sot = inner_fg.outputs[curr_nit_sot_position]
mem_shape = ["1"] * curr_nit_sot.ndim
allocate_mem_to_nit_sot += f""" allocate_mem_to_nit_sot += f"""
{_name} = np.zeros({_name}.item()) {_name} = [np.zeros(({create_arg_string(mem_shape)}))]*{_name}.item()
""" """
# The non_seqs are passed to inner function as-is # The non_seqs are passed to inner function as-is
inner_in_indexed += outer_in_non_seqs_names inner_in_indexed += outer_in_non_seqs_names
......
...@@ -3048,7 +3048,7 @@ def test_scan_tap_output(): ...@@ -3048,7 +3048,7 @@ def test_scan_tap_output():
x_t = x_tm1 + 1 x_t = x_tm1 + 1
x_t.name = "x_t" x_t.name = "x_t"
y_t.name = "y_t" y_t.name = "y_t"
return x_t, y_t, z_t return x_t, y_t, aet.fill((10,), z_t)
scan_res, _ = scan( scan_res, _ = scan(
fn=input_step_fn, fn=input_step_fn,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论