提交 41253900 authored 作者: kc611's avatar kc611 提交者: Brandon T. Willard

Added proper dtype declaration

上级 39f29083
......@@ -109,8 +109,9 @@ def numba_funcify_Scan(op, node, **kwargs):
curr_nit_sot_position = input_names.index(_name) - n_seqs
curr_nit_sot = inner_fg.outputs[curr_nit_sot_position]
mem_shape = ["1"] * curr_nit_sot.ndim
curr_dtype = curr_nit_sot.type.numpy_dtype.name
allocate_mem_to_nit_sot += f"""
{_name} = [np.zeros(({create_arg_string(mem_shape)}))]*{_name}.item()
{_name} = [np.zeros(({create_arg_string(mem_shape)}), dtype=np.{curr_dtype})]*{_name}.item()
"""
# The non_seqs are passed to inner function as-is
inner_in_indexed += outer_in_non_seqs_names
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论