提交 1fa22d8f 作者: Ricardo Vieira 提交者: Ricardo Vieira

Minor numba Scan tweaks

上级 85235623
@@ -222,14 +222,16 @@ def numba_funcify_Scan(op, node, **kwargs):
             # the storage array.
             # This is needed when the output storage array does not have a length
             # equal to the number of taps plus `n_steps`.
+            # If the storage size only allows one entry, there's nothing to rotate
             output_storage_post_proc_stmts.append(
                 dedent(
                     f"""
-                    if (i + {tap_size}) > {storage_size}:
+                    if 1 < {storage_size} < (i + {tap_size}):
                         {outer_in_name}_shift = (i + {tap_size}) % ({storage_size})
-                        {outer_in_name}_left = {outer_in_name}[:{outer_in_name}_shift]
-                        {outer_in_name}_right = {outer_in_name}[{outer_in_name}_shift:]
-                        {outer_in_name} = np.concatenate(({outer_in_name}_right, {outer_in_name}_left))
+                        if {outer_in_name}_shift > 0:
+                            {outer_in_name}_left = {outer_in_name}[:{outer_in_name}_shift]
+                            {outer_in_name}_right = {outer_in_name}[{outer_in_name}_shift:]
+                            {outer_in_name} = np.concatenate(({outer_in_name}_right, {outer_in_name}_left))
                     """
                 ).strip()
             )
@@ -417,4 +419,4 @@ def scan({", ".join(outer_in_names)}):
     scan_op_fn = compile_function_src(scan_op_src, "scan", {**globals(), **global_env})
-    return numba_basic.numba_njit(scan_op_fn)
+    return numba_basic.numba_njit(scan_op_fn, boundscheck=False)
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论