提交 a244ab1a authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Remove uses of unique_name_generator in numba dispatch

It's more readable and avoids potential bugs when force_unique is not set to True
上级 a3613d13
...@@ -14,7 +14,6 @@ from pytensor.link.numba.dispatch.basic import ( ...@@ -14,7 +14,6 @@ from pytensor.link.numba.dispatch.basic import (
from pytensor.link.numba.dispatch.cython_support import wrap_cython_function from pytensor.link.numba.dispatch.cython_support import wrap_cython_function
from pytensor.link.utils import ( from pytensor.link.utils import (
get_name_for_object, get_name_for_object,
unique_name_generator,
) )
from pytensor.scalar.basic import ( from pytensor.scalar.basic import (
Add, Add,
...@@ -81,23 +80,21 @@ def numba_funcify_ScalarOp(op, node, **kwargs): ...@@ -81,23 +80,21 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
scalar_func_numba = generate_fallback_impl(op, node, **kwargs) scalar_func_numba = generate_fallback_impl(op, node, **kwargs)
scalar_op_fn_name = get_name_for_object(scalar_func_numba) scalar_op_fn_name = get_name_for_object(scalar_func_numba)
prefix = "x" if scalar_func_name != "x" else "y"
input_names = [f"{prefix}{i}" for i in range(len(node.inputs))]
input_signature = ", ".join(input_names)
global_env = {"scalar_func_numba": scalar_func_numba} global_env = {"scalar_func_numba": scalar_func_numba}
if input_inner_dtypes is None and output_inner_dtype is None: if input_inner_dtypes is None and output_inner_dtype is None:
unique_names = unique_name_generator(
[scalar_op_fn_name, "scalar_func_numba"], suffix_sep="_"
)
input_names = ", ".join(unique_names(v, force_unique=True) for v in node.inputs)
if not has_pyx_skip_dispatch: if not has_pyx_skip_dispatch:
scalar_op_src = f""" scalar_op_src = f"""
def {scalar_op_fn_name}({input_names}): def {scalar_op_fn_name}({input_signature}):
return scalar_func_numba({input_names}) return scalar_func_numba({input_signature})
""" """
else: else:
scalar_op_src = f""" scalar_op_src = f"""
def {scalar_op_fn_name}({input_names}): def {scalar_op_fn_name}({input_signature}):
return scalar_func_numba({input_names}, np.intc(1)) return scalar_func_numba({input_signature}, np.intc(1))
""" """
else: else:
...@@ -108,13 +105,6 @@ def {scalar_op_fn_name}({input_names}): ...@@ -108,13 +105,6 @@ def {scalar_op_fn_name}({input_names}):
for i, i_dtype in enumerate(input_inner_dtypes) for i, i_dtype in enumerate(input_inner_dtypes)
} }
global_env.update(input_tmp_dtype_names) global_env.update(input_tmp_dtype_names)
unique_names = unique_name_generator(
[scalar_op_fn_name, "scalar_func_numba", *global_env.keys()],
suffix_sep="_",
)
input_names = [unique_names(v, force_unique=True) for v in node.inputs]
converted_call_args = ", ".join( converted_call_args = ", ".join(
f"direct_cast({i_name}, {i_tmp_dtype_name})" f"direct_cast({i_name}, {i_tmp_dtype_name})"
for i_name, i_tmp_dtype_name in zip( for i_name, i_tmp_dtype_name in zip(
...@@ -123,19 +113,19 @@ def {scalar_op_fn_name}({input_names}): ...@@ -123,19 +113,19 @@ def {scalar_op_fn_name}({input_names}):
) )
if not has_pyx_skip_dispatch: if not has_pyx_skip_dispatch:
scalar_op_src = f""" scalar_op_src = f"""
def {scalar_op_fn_name}({", ".join(input_names)}): def {scalar_op_fn_name}({input_signature}):
return direct_cast(scalar_func_numba({converted_call_args}), output_dtype) return direct_cast(scalar_func_numba({converted_call_args}), output_dtype)
""" """
else: else:
scalar_op_src = f""" scalar_op_src = f"""
def {scalar_op_fn_name}({", ".join(input_names)}): def {scalar_op_fn_name}({input_signature}):
return direct_cast(scalar_func_numba({converted_call_args}, np.intc(1)), output_dtype) return direct_cast(scalar_func_numba({converted_call_args}, np.intc(1)), output_dtype)
""" """
scalar_op_fn = compile_numba_function_src( scalar_op_fn = compile_numba_function_src(
scalar_op_src, scalar_op_src,
scalar_op_fn_name, scalar_op_fn_name,
{**globals(), **global_env}, globals() | global_env,
) )
# Functions that call a function pointer can't be cached # Functions that call a function pointer can't be cached
...@@ -157,8 +147,8 @@ def numba_funcify_Switch(op, node, **kwargs): ...@@ -157,8 +147,8 @@ def numba_funcify_Switch(op, node, **kwargs):
def binary_to_nary_func(inputs: list[Variable], binary_op_name: str, binary_op: str): def binary_to_nary_func(inputs: list[Variable], binary_op_name: str, binary_op: str):
"""Create a Numba-compatible N-ary function from a binary function.""" """Create a Numba-compatible N-ary function from a binary function."""
unique_names = unique_name_generator(["binary_op_name"], suffix_sep="_") var_prefix = "x" if binary_op_name != "x" else "y"
input_names = [unique_names(v, force_unique=True) for v in inputs] input_names = [f"{var_prefix}{i}" for i in range(len(inputs))]
input_signature = ", ".join(input_names) input_signature = ", ".join(input_names)
output_expr = binary_op.join(input_names) output_expr = binary_op.join(input_names)
......
...@@ -10,7 +10,6 @@ from pytensor.link.numba.dispatch.basic import ( ...@@ -10,7 +10,6 @@ from pytensor.link.numba.dispatch.basic import (
register_funcify_and_cache_key, register_funcify_and_cache_key,
register_funcify_default_op_cache_key, register_funcify_default_op_cache_key,
) )
from pytensor.link.utils import unique_name_generator
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
Alloc, Alloc,
AllocEmpty, AllocEmpty,
...@@ -28,15 +27,7 @@ from pytensor.tensor.basic import ( ...@@ -28,15 +27,7 @@ from pytensor.tensor.basic import (
@register_funcify_default_op_cache_key(AllocEmpty) @register_funcify_default_op_cache_key(AllocEmpty)
def numba_funcify_AllocEmpty(op, node, **kwargs): def numba_funcify_AllocEmpty(op, node, **kwargs):
global_env = { shape_var_names = [f"sh{i}" for i in range(len(node.inputs))]
"np": np,
"dtype": np.dtype(op.dtype),
}
unique_names = unique_name_generator(
["np", "dtype", "allocempty", "scalar_shape"], suffix_sep="_"
)
shape_var_names = [unique_names(v, force_unique=True) for v in node.inputs]
shape_var_item_names = [f"{name}_item" for name in shape_var_names] shape_var_item_names = [f"{name}_item" for name in shape_var_names]
shapes_to_items_src = indent( shapes_to_items_src = indent(
"\n".join( "\n".join(
...@@ -56,7 +47,7 @@ def allocempty({", ".join(shape_var_names)}): ...@@ -56,7 +47,7 @@ def allocempty({", ".join(shape_var_names)}):
""" """
alloc_fn = compile_numba_function_src( alloc_fn = compile_numba_function_src(
alloc_def_src, "allocempty", {**globals(), **global_env} alloc_def_src, "allocempty", globals() | {"np": np, "dtype": np.dtype(op.dtype)}
) )
return numba_basic.numba_njit(alloc_fn) return numba_basic.numba_njit(alloc_fn)
...@@ -64,13 +55,7 @@ def allocempty({", ".join(shape_var_names)}): ...@@ -64,13 +55,7 @@ def allocempty({", ".join(shape_var_names)}):
@register_funcify_and_cache_key(Alloc) @register_funcify_and_cache_key(Alloc)
def numba_funcify_Alloc(op, node, **kwargs): def numba_funcify_Alloc(op, node, **kwargs):
global_env = {"np": np} shape_var_names = [f"sh{i}" for i in range(len(node.inputs) - 1)]
unique_names = unique_name_generator(
["np", "alloc", "val_np", "val", "scalar_shape", "res"],
suffix_sep="_",
)
shape_var_names = [unique_names(v, force_unique=True) for v in node.inputs[1:]]
shape_var_item_names = [f"{name}_item" for name in shape_var_names] shape_var_item_names = [f"{name}_item" for name in shape_var_names]
shapes_to_items_src = indent( shapes_to_items_src = indent(
"\n".join( "\n".join(
...@@ -102,7 +87,7 @@ def alloc(val, {", ".join(shape_var_names)}): ...@@ -102,7 +87,7 @@ def alloc(val, {", ".join(shape_var_names)}):
alloc_fn = compile_numba_function_src( alloc_fn = compile_numba_function_src(
alloc_def_src, alloc_def_src,
"alloc", "alloc",
{**globals(), **global_env}, globals() | {"np": np},
) )
cache_key = sha256( cache_key = sha256(
...@@ -207,14 +192,7 @@ def numba_funcify_Eye(op, **kwargs): ...@@ -207,14 +192,7 @@ def numba_funcify_Eye(op, **kwargs):
@register_funcify_default_op_cache_key(MakeVector) @register_funcify_default_op_cache_key(MakeVector)
def numba_funcify_MakeVector(op, node, **kwargs): def numba_funcify_MakeVector(op, node, **kwargs):
dtype = np.dtype(op.dtype) dtype = np.dtype(op.dtype)
input_names = [f"x{i}" for i in range(len(node.inputs))]
global_env = {"np": np, "dtype": dtype}
unique_names = unique_name_generator(
["np"],
suffix_sep="_",
)
input_names = [unique_names(v, force_unique=True) for v in node.inputs]
def create_list_string(x): def create_list_string(x):
args = ", ".join([f"{i}.item()" for i in x] + ([""] if len(x) == 1 else [])) args = ", ".join([f"{i}.item()" for i in x] + ([""] if len(x) == 1 else []))
...@@ -228,7 +206,7 @@ def makevector({", ".join(input_names)}): ...@@ -228,7 +206,7 @@ def makevector({", ".join(input_names)}):
makevector_fn = compile_numba_function_src( makevector_fn = compile_numba_function_src(
makevector_def_src, makevector_def_src,
"makevector", "makevector",
{**globals(), **global_env}, globals() | {"np": np, "dtype": dtype},
) )
return numba_basic.numba_njit(makevector_fn) return numba_basic.numba_njit(makevector_fn)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论