提交 a244ab1a authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Remove uses of unique_name_generator in numba dispatch

Using simple positional parameter names (e.g. `x0`, `x1`, ...) is more readable than generated unique names, and avoids potential bugs when `force_unique` is not set to `True`.
上级 a3613d13
......@@ -14,7 +14,6 @@ from pytensor.link.numba.dispatch.basic import (
from pytensor.link.numba.dispatch.cython_support import wrap_cython_function
from pytensor.link.utils import (
get_name_for_object,
unique_name_generator,
)
from pytensor.scalar.basic import (
Add,
......@@ -81,23 +80,21 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
scalar_func_numba = generate_fallback_impl(op, node, **kwargs)
scalar_op_fn_name = get_name_for_object(scalar_func_numba)
prefix = "x" if scalar_func_name != "x" else "y"
input_names = [f"{prefix}{i}" for i in range(len(node.inputs))]
input_signature = ", ".join(input_names)
global_env = {"scalar_func_numba": scalar_func_numba}
if input_inner_dtypes is None and output_inner_dtype is None:
unique_names = unique_name_generator(
[scalar_op_fn_name, "scalar_func_numba"], suffix_sep="_"
)
input_names = ", ".join(unique_names(v, force_unique=True) for v in node.inputs)
if not has_pyx_skip_dispatch:
scalar_op_src = f"""
def {scalar_op_fn_name}({input_names}):
return scalar_func_numba({input_names})
def {scalar_op_fn_name}({input_signature}):
return scalar_func_numba({input_signature})
"""
else:
scalar_op_src = f"""
def {scalar_op_fn_name}({input_names}):
return scalar_func_numba({input_names}, np.intc(1))
def {scalar_op_fn_name}({input_signature}):
return scalar_func_numba({input_signature}, np.intc(1))
"""
else:
......@@ -108,13 +105,6 @@ def {scalar_op_fn_name}({input_names}):
for i, i_dtype in enumerate(input_inner_dtypes)
}
global_env.update(input_tmp_dtype_names)
unique_names = unique_name_generator(
[scalar_op_fn_name, "scalar_func_numba", *global_env.keys()],
suffix_sep="_",
)
input_names = [unique_names(v, force_unique=True) for v in node.inputs]
converted_call_args = ", ".join(
f"direct_cast({i_name}, {i_tmp_dtype_name})"
for i_name, i_tmp_dtype_name in zip(
......@@ -123,19 +113,19 @@ def {scalar_op_fn_name}({input_names}):
)
if not has_pyx_skip_dispatch:
scalar_op_src = f"""
def {scalar_op_fn_name}({", ".join(input_names)}):
def {scalar_op_fn_name}({input_signature}):
return direct_cast(scalar_func_numba({converted_call_args}), output_dtype)
"""
else:
scalar_op_src = f"""
def {scalar_op_fn_name}({", ".join(input_names)}):
def {scalar_op_fn_name}({input_signature}):
return direct_cast(scalar_func_numba({converted_call_args}, np.intc(1)), output_dtype)
"""
scalar_op_fn = compile_numba_function_src(
scalar_op_src,
scalar_op_fn_name,
{**globals(), **global_env},
globals() | global_env,
)
# Functions that call a function pointer can't be cached
......@@ -157,8 +147,8 @@ def numba_funcify_Switch(op, node, **kwargs):
def binary_to_nary_func(inputs: list[Variable], binary_op_name: str, binary_op: str):
"""Create a Numba-compatible N-ary function from a binary function."""
unique_names = unique_name_generator(["binary_op_name"], suffix_sep="_")
input_names = [unique_names(v, force_unique=True) for v in inputs]
var_prefix = "x" if binary_op_name != "x" else "y"
input_names = [f"{var_prefix}{i}" for i in range(len(inputs))]
input_signature = ", ".join(input_names)
output_expr = binary_op.join(input_names)
......
......@@ -10,7 +10,6 @@ from pytensor.link.numba.dispatch.basic import (
register_funcify_and_cache_key,
register_funcify_default_op_cache_key,
)
from pytensor.link.utils import unique_name_generator
from pytensor.tensor.basic import (
Alloc,
AllocEmpty,
......@@ -28,15 +27,7 @@ from pytensor.tensor.basic import (
@register_funcify_default_op_cache_key(AllocEmpty)
def numba_funcify_AllocEmpty(op, node, **kwargs):
global_env = {
"np": np,
"dtype": np.dtype(op.dtype),
}
unique_names = unique_name_generator(
["np", "dtype", "allocempty", "scalar_shape"], suffix_sep="_"
)
shape_var_names = [unique_names(v, force_unique=True) for v in node.inputs]
shape_var_names = [f"sh{i}" for i in range(len(node.inputs))]
shape_var_item_names = [f"{name}_item" for name in shape_var_names]
shapes_to_items_src = indent(
"\n".join(
......@@ -56,7 +47,7 @@ def allocempty({", ".join(shape_var_names)}):
"""
alloc_fn = compile_numba_function_src(
alloc_def_src, "allocempty", {**globals(), **global_env}
alloc_def_src, "allocempty", globals() | {"np": np, "dtype": np.dtype(op.dtype)}
)
return numba_basic.numba_njit(alloc_fn)
......@@ -64,13 +55,7 @@ def allocempty({", ".join(shape_var_names)}):
@register_funcify_and_cache_key(Alloc)
def numba_funcify_Alloc(op, node, **kwargs):
global_env = {"np": np}
unique_names = unique_name_generator(
["np", "alloc", "val_np", "val", "scalar_shape", "res"],
suffix_sep="_",
)
shape_var_names = [unique_names(v, force_unique=True) for v in node.inputs[1:]]
shape_var_names = [f"sh{i}" for i in range(len(node.inputs) - 1)]
shape_var_item_names = [f"{name}_item" for name in shape_var_names]
shapes_to_items_src = indent(
"\n".join(
......@@ -102,7 +87,7 @@ def alloc(val, {", ".join(shape_var_names)}):
alloc_fn = compile_numba_function_src(
alloc_def_src,
"alloc",
{**globals(), **global_env},
globals() | {"np": np},
)
cache_key = sha256(
......@@ -207,14 +192,7 @@ def numba_funcify_Eye(op, **kwargs):
@register_funcify_default_op_cache_key(MakeVector)
def numba_funcify_MakeVector(op, node, **kwargs):
dtype = np.dtype(op.dtype)
global_env = {"np": np, "dtype": dtype}
unique_names = unique_name_generator(
["np"],
suffix_sep="_",
)
input_names = [unique_names(v, force_unique=True) for v in node.inputs]
input_names = [f"x{i}" for i in range(len(node.inputs))]
def create_list_string(x):
args = ", ".join([f"{i}.item()" for i in x] + ([""] if len(x) == 1 else []))
......@@ -228,7 +206,7 @@ def makevector({", ".join(input_names)}):
makevector_fn = compile_numba_function_src(
makevector_def_src,
"makevector",
{**globals(), **global_env},
globals() | {"np": np, "dtype": dtype},
)
return numba_basic.numba_njit(makevector_fn)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论