提交 57b344e7 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba Scan: make codegen more readable

上级 dca7f5d9
......@@ -199,13 +199,10 @@ def create_tuple_creator(f, n):
def create_tuple_string(x):
args = ", ".join(x + ([""] if len(x) == 1 else []))
return f"({args})"
def create_arg_string(x):
args = ", ".join(x)
return args
if len(x) == 1:
return f"({x[0]},)"
else:
return f"({', '.join(x)})"
@numba.extending.intrinsic
......
......@@ -151,7 +151,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
# Inner-inputs are ordered as follows:
# sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs +
# untraced-sit-sot-inputs + non-sequences.
temp_scalar_storage_alloc_stmts: list[str] = []
temp_0d_storage_alloc_stmts: list[str] = []
inner_in_exprs_scalar: list[str] = []
inner_in_exprs: list[str] = []
......@@ -169,7 +169,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
)
temp_storage = f"{storage_name}_temp_scalar_{tap_offset}"
storage_dtype = outer_in_var.type.numpy_dtype.name
temp_scalar_storage_alloc_stmts.append(
temp_0d_storage_alloc_stmts.append(
f"{temp_storage} = np.empty((), dtype=np.{storage_dtype})"
)
inner_in_exprs_scalar.append(
......@@ -181,7 +181,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
storage_name
if tap_offset is None
else idx_to_str(
storage_name, tap_offset, size=storage_size_var, allow_scalar=False
storage_name, tap_offset, size=storage_size_var, allow_scalar=True
)
)
inner_in_exprs.append(indexed_inner_in_str)
......@@ -366,23 +366,27 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
curr_nit_sot_position = outer_in_nit_sot_names.index(outer_in_name)
curr_nit_sot = op.inner_nitsot_outs(op.inner_outputs)[curr_nit_sot_position]
storage_shape = create_tuple_string(
[storage_size_name] + ["0"] * curr_nit_sot.ndim
)
known_static_shape = all(dim is not None for dim in curr_nit_sot.type.shape)
if known_static_shape:
storage_shape = create_tuple_string(
(storage_size_name, *(map(str, curr_nit_sot.type.shape)))
)
else:
storage_shape = create_tuple_string(
(storage_size_name, *(["0"] * curr_nit_sot.ndim))
)
storage_dtype = curr_nit_sot.type.numpy_dtype.name
storage_alloc_stmts.append(
dedent(
f"""
{storage_size_name} = ({outer_in_name}).item()
{storage_size_name} = {outer_in_name}.item()
{storage_name} = np.empty({storage_shape}, dtype=np.{storage_dtype})
"""
).strip()
)
if curr_nit_sot.type.ndim > 0:
storage_alloc_stmts.append(f"{outer_in_name}_ready = False")
if not known_static_shape:
# In this case, we don't know the shape of the output storage
# array until we get some output from the inner-function.
# With the following we add delayed output storage initialization:
......@@ -392,9 +396,8 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
inner_out_post_processing_stmts.append(
dedent(
f"""
if not {outer_in_name}_ready:
if i == 0:
{storage_name} = np.empty(({storage_size_name},) + np.shape({inner_out_name}), dtype=np.{storage_dtype})
{outer_in_name}_ready = True
"""
).strip()
)
......@@ -409,10 +412,11 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
assert len(inner_in_exprs) == len(op.fgraph.inputs)
inner_scalar_in_args_to_temp_storage = "\n".join(inner_in_exprs_scalar)
inner_in_args = create_arg_string(inner_in_exprs)
# Break inputs in new lines, just for readability of the source code
inner_in_args = f",\n{' ' * 12}".join(inner_in_exprs)
inner_outputs = create_tuple_string(inner_output_names)
input_storage_block = "\n".join(storage_alloc_stmts)
input_temp_scalar_storage_block = "\n".join(temp_scalar_storage_alloc_stmts)
input_temp_0d_storage_block = "\n".join(temp_0d_storage_alloc_stmts)
output_storage_post_processing_block = "\n".join(output_storage_post_proc_stmts)
inner_out_post_processing_block = "\n".join(inner_out_post_processing_stmts)
......@@ -426,32 +430,29 @@ def scan({", ".join(outer_in_names)}):
{indent(input_storage_block, " " * 4)}
{indent(input_temp_scalar_storage_block, " " * 4)}
{indent(input_temp_0d_storage_block, " " * 4)}
i = 0
cond = np.array(False)
while i < n_steps and not cond.item():
{indent(inner_scalar_in_args_to_temp_storage, " " * 8)}
{inner_outputs} = scan_inner_func({inner_in_args})
{inner_outputs} = scan_inner_func(
{inner_in_args}
)
{indent(inner_out_post_processing_block, " " * 8)}
{indent(inner_out_to_outer_out_stmts, " " * 8)}
i += 1
{indent(output_storage_post_processing_block, " " * 4)}
return {create_arg_string(outer_output_names)}
return {", ".join(outer_output_names)}
"""
global_env = {
"np": np,
"scan_inner_func": scan_inner_func,
}
scan_op_fn = compile_numba_function_src(
scan_op_src,
"scan",
{**globals(), **global_env},
globals() | {"np": np, "scan_inner_func": scan_inner_func},
)
if inner_func_cache_key is None:
......
......@@ -4,10 +4,7 @@ import numpy as np
from numba.np.unsafe import ndarray as numba_ndarray
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
create_arg_string,
register_funcify_default_op_cache_key,
)
from pytensor.link.numba.dispatch.basic import register_funcify_default_op_cache_key
from pytensor.link.utils import compile_function_src
from pytensor.tensor import NoneConst
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
......@@ -48,7 +45,7 @@ def numba_funcify_SpecifyShape(op, node, **kwargs):
func = dedent(
f"""
def specify_shape(x, {create_arg_string(shape_input_names)}):
def specify_shape(x, {", ".join(shape_input_names)}):
{"; ".join(func_conditions)}
return x
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论