提交 39f29083 authored 作者: kc611's avatar kc611 提交者: Brandon T. Willard

Fixed Numba Scan failures on multidimensional nit sot outputs

上级 ac14c1dd
......@@ -106,8 +106,11 @@ def numba_funcify_Scan(op, node, **kwargs):
# In case of nit-sots we are provided shape of the array
# instead of actual arrays like other cases, hence we
# allocate space for the results accordingly.
curr_nit_sot_position = input_names.index(_name) - n_seqs
curr_nit_sot = inner_fg.outputs[curr_nit_sot_position]
mem_shape = ["1"] * curr_nit_sot.ndim
allocate_mem_to_nit_sot += f"""
{_name} = np.zeros({_name}.item())
{_name} = [np.zeros(({create_arg_string(mem_shape)}))]*{_name}.item()
"""
# The non_seqs are passed to inner function as-is
inner_in_indexed += outer_in_non_seqs_names
......
......@@ -3048,7 +3048,7 @@ def test_scan_tap_output():
x_t = x_tm1 + 1
x_t.name = "x_t"
y_t.name = "y_t"
return x_t, y_t, z_t
return x_t, y_t, aet.fill((10,), z_t)
scan_res, _ = scan(
fn=input_step_fn,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论