提交 57b344e7 作者: Ricardo Vieira 提交者: Ricardo Vieira

Numba Scan: make codegen more readable

上级 dca7f5d9
...@@ -199,13 +199,10 @@ def create_tuple_creator(f, n): ...@@ -199,13 +199,10 @@ def create_tuple_creator(f, n):
def create_tuple_string(x): def create_tuple_string(x):
args = ", ".join(x + ([""] if len(x) == 1 else [])) if len(x) == 1:
return f"({args})" return f"({x[0]},)"
else:
return f"({', '.join(x)})"
def create_arg_string(x):
args = ", ".join(x)
return args
@numba.extending.intrinsic @numba.extending.intrinsic
......
...@@ -151,7 +151,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): ...@@ -151,7 +151,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
# Inner-inputs are ordered as follows: # Inner-inputs are ordered as follows:
# sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs + # sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs +
# untraced-sit-sot-inputs + non-sequences. # untraced-sit-sot-inputs + non-sequences.
temp_scalar_storage_alloc_stmts: list[str] = [] temp_0d_storage_alloc_stmts: list[str] = []
inner_in_exprs_scalar: list[str] = [] inner_in_exprs_scalar: list[str] = []
inner_in_exprs: list[str] = [] inner_in_exprs: list[str] = []
...@@ -169,7 +169,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): ...@@ -169,7 +169,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
) )
temp_storage = f"{storage_name}_temp_scalar_{tap_offset}" temp_storage = f"{storage_name}_temp_scalar_{tap_offset}"
storage_dtype = outer_in_var.type.numpy_dtype.name storage_dtype = outer_in_var.type.numpy_dtype.name
temp_scalar_storage_alloc_stmts.append( temp_0d_storage_alloc_stmts.append(
f"{temp_storage} = np.empty((), dtype=np.{storage_dtype})" f"{temp_storage} = np.empty((), dtype=np.{storage_dtype})"
) )
inner_in_exprs_scalar.append( inner_in_exprs_scalar.append(
...@@ -181,7 +181,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): ...@@ -181,7 +181,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
storage_name storage_name
if tap_offset is None if tap_offset is None
else idx_to_str( else idx_to_str(
storage_name, tap_offset, size=storage_size_var, allow_scalar=False storage_name, tap_offset, size=storage_size_var, allow_scalar=True
) )
) )
inner_in_exprs.append(indexed_inner_in_str) inner_in_exprs.append(indexed_inner_in_str)
...@@ -366,23 +366,27 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): ...@@ -366,23 +366,27 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
curr_nit_sot_position = outer_in_nit_sot_names.index(outer_in_name) curr_nit_sot_position = outer_in_nit_sot_names.index(outer_in_name)
curr_nit_sot = op.inner_nitsot_outs(op.inner_outputs)[curr_nit_sot_position] curr_nit_sot = op.inner_nitsot_outs(op.inner_outputs)[curr_nit_sot_position]
known_static_shape = all(dim is not None for dim in curr_nit_sot.type.shape)
if known_static_shape:
storage_shape = create_tuple_string(
(storage_size_name, *(map(str, curr_nit_sot.type.shape)))
)
else:
storage_shape = create_tuple_string( storage_shape = create_tuple_string(
[storage_size_name] + ["0"] * curr_nit_sot.ndim (storage_size_name, *(["0"] * curr_nit_sot.ndim))
) )
storage_dtype = curr_nit_sot.type.numpy_dtype.name storage_dtype = curr_nit_sot.type.numpy_dtype.name
storage_alloc_stmts.append( storage_alloc_stmts.append(
dedent( dedent(
f""" f"""
{storage_size_name} = ({outer_in_name}).item() {storage_size_name} = {outer_in_name}.item()
{storage_name} = np.empty({storage_shape}, dtype=np.{storage_dtype}) {storage_name} = np.empty({storage_shape}, dtype=np.{storage_dtype})
""" """
).strip() ).strip()
) )
if curr_nit_sot.type.ndim > 0: if not known_static_shape:
storage_alloc_stmts.append(f"{outer_in_name}_ready = False")
# In this case, we don't know the shape of the output storage # In this case, we don't know the shape of the output storage
# array until we get some output from the inner-function. # array until we get some output from the inner-function.
# With the following we add delayed output storage initialization: # With the following we add delayed output storage initialization:
...@@ -392,9 +396,8 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): ...@@ -392,9 +396,8 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
inner_out_post_processing_stmts.append( inner_out_post_processing_stmts.append(
dedent( dedent(
f""" f"""
if not {outer_in_name}_ready: if i == 0:
{storage_name} = np.empty(({storage_size_name},) + np.shape({inner_out_name}), dtype=np.{storage_dtype}) {storage_name} = np.empty(({storage_size_name},) + np.shape({inner_out_name}), dtype=np.{storage_dtype})
{outer_in_name}_ready = True
""" """
).strip() ).strip()
) )
...@@ -409,10 +412,11 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): ...@@ -409,10 +412,11 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
assert len(inner_in_exprs) == len(op.fgraph.inputs) assert len(inner_in_exprs) == len(op.fgraph.inputs)
inner_scalar_in_args_to_temp_storage = "\n".join(inner_in_exprs_scalar) inner_scalar_in_args_to_temp_storage = "\n".join(inner_in_exprs_scalar)
inner_in_args = create_arg_string(inner_in_exprs) # Break inputs in new lines, just for readability of the source code
inner_in_args = f",\n{' ' * 12}".join(inner_in_exprs)
inner_outputs = create_tuple_string(inner_output_names) inner_outputs = create_tuple_string(inner_output_names)
input_storage_block = "\n".join(storage_alloc_stmts) input_storage_block = "\n".join(storage_alloc_stmts)
input_temp_scalar_storage_block = "\n".join(temp_scalar_storage_alloc_stmts) input_temp_0d_storage_block = "\n".join(temp_0d_storage_alloc_stmts)
output_storage_post_processing_block = "\n".join(output_storage_post_proc_stmts) output_storage_post_processing_block = "\n".join(output_storage_post_proc_stmts)
inner_out_post_processing_block = "\n".join(inner_out_post_processing_stmts) inner_out_post_processing_block = "\n".join(inner_out_post_processing_stmts)
...@@ -426,32 +430,29 @@ def scan({", ".join(outer_in_names)}): ...@@ -426,32 +430,29 @@ def scan({", ".join(outer_in_names)}):
{indent(input_storage_block, " " * 4)} {indent(input_storage_block, " " * 4)}
{indent(input_temp_scalar_storage_block, " " * 4)} {indent(input_temp_0d_storage_block, " " * 4)}
i = 0 i = 0
cond = np.array(False) cond = np.array(False)
while i < n_steps and not cond.item(): while i < n_steps and not cond.item():
{indent(inner_scalar_in_args_to_temp_storage, " " * 8)} {indent(inner_scalar_in_args_to_temp_storage, " " * 8)}
{inner_outputs} = scan_inner_func({inner_in_args}) {inner_outputs} = scan_inner_func(
{inner_in_args}
)
{indent(inner_out_post_processing_block, " " * 8)} {indent(inner_out_post_processing_block, " " * 8)}
{indent(inner_out_to_outer_out_stmts, " " * 8)} {indent(inner_out_to_outer_out_stmts, " " * 8)}
i += 1 i += 1
{indent(output_storage_post_processing_block, " " * 4)} {indent(output_storage_post_processing_block, " " * 4)}
return {create_arg_string(outer_output_names)} return {", ".join(outer_output_names)}
""" """
global_env = {
"np": np,
"scan_inner_func": scan_inner_func,
}
scan_op_fn = compile_numba_function_src( scan_op_fn = compile_numba_function_src(
scan_op_src, scan_op_src,
"scan", "scan",
{**globals(), **global_env}, globals() | {"np": np, "scan_inner_func": scan_inner_func},
) )
if inner_func_cache_key is None: if inner_func_cache_key is None:
......
...@@ -4,10 +4,7 @@ import numpy as np ...@@ -4,10 +4,7 @@ import numpy as np
from numba.np.unsafe import ndarray as numba_ndarray from numba.np.unsafe import ndarray as numba_ndarray
from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import ( from pytensor.link.numba.dispatch.basic import register_funcify_default_op_cache_key
create_arg_string,
register_funcify_default_op_cache_key,
)
from pytensor.link.utils import compile_function_src from pytensor.link.utils import compile_function_src
from pytensor.tensor import NoneConst from pytensor.tensor import NoneConst
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
...@@ -48,7 +45,7 @@ def numba_funcify_SpecifyShape(op, node, **kwargs): ...@@ -48,7 +45,7 @@ def numba_funcify_SpecifyShape(op, node, **kwargs):
func = dedent( func = dedent(
f""" f"""
def specify_shape(x, {create_arg_string(shape_input_names)}): def specify_shape(x, {", ".join(shape_input_names)}):
{"; ".join(func_conditions)} {"; ".join(func_conditions)}
return x return x
""" """
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论