提交 7a6c42e6 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Guarantee unique signature variable names during Numba conversion

上级 9af376f6
...@@ -25,6 +25,7 @@ from aesara.link.utils import ( ...@@ -25,6 +25,7 @@ from aesara.link.utils import (
compile_function_src, compile_function_src,
fgraph_to_python, fgraph_to_python,
get_name_for_object, get_name_for_object,
unique_name_generator,
) )
from aesara.scalar.basic import ( from aesara.scalar.basic import (
Cast, Cast,
...@@ -352,11 +353,15 @@ def numba_funcify_ScalarOp(op, node, **kwargs): ...@@ -352,11 +353,15 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
else: else:
scalar_func = getattr(func_package, scalar_func_name) scalar_func = getattr(func_package, scalar_func_name)
input_names = ", ".join([v.auto_name for v in node.inputs]) scalar_op_fn_name = get_name_for_object(scalar_func)
unique_names = unique_name_generator(
[scalar_op_fn_name, "scalar_func"], suffix_sep="_"
)
input_names = ", ".join([unique_names(v, force_unique=True) for v in node.inputs])
global_env = {"scalar_func": scalar_func} global_env = {"scalar_func": scalar_func}
scalar_op_fn_name = get_name_for_object(scalar_func)
scalar_op_src = f""" scalar_op_src = f"""
def {scalar_op_fn_name}({input_names}): def {scalar_op_fn_name}({input_names}):
return scalar_func({input_names}) return scalar_func({input_names})
...@@ -370,8 +375,6 @@ def {scalar_op_fn_name}({input_names}): ...@@ -370,8 +375,6 @@ def {scalar_op_fn_name}({input_names}):
def numba_funcify_Elemwise(op, node, use_signature=False, identity=None, **kwargs): def numba_funcify_Elemwise(op, node, use_signature=False, identity=None, **kwargs):
scalar_op_fn = numba_funcify(op.scalar_op, node, **kwargs) scalar_op_fn = numba_funcify(op.scalar_op, node, **kwargs)
input_names = ", ".join([v.auto_name for v in node.inputs])
if use_signature: if use_signature:
signature = [create_numba_signature(node, force_scalar=True)] signature = [create_numba_signature(node, force_scalar=True)]
else: else:
...@@ -381,6 +384,12 @@ def numba_funcify_Elemwise(op, node, use_signature=False, identity=None, **kwarg ...@@ -381,6 +384,12 @@ def numba_funcify_Elemwise(op, node, use_signature=False, identity=None, **kwarg
global_env = {"scalar_op": scalar_op_fn, "numba_vectorize": numba_vectorize} global_env = {"scalar_op": scalar_op_fn, "numba_vectorize": numba_vectorize}
elemwise_fn_name = f"elemwise_{get_name_for_object(scalar_op_fn)}" elemwise_fn_name = f"elemwise_{get_name_for_object(scalar_op_fn)}"
unique_names = unique_name_generator(
[elemwise_fn_name, "scalar_op", "scalar_op", "numba_vectorize"], suffix_sep="_"
)
input_names = ", ".join([unique_names(v, force_unique=True) for v in node.inputs])
elemwise_src = f""" elemwise_src = f"""
@numba_vectorize @numba_vectorize
def {elemwise_fn_name}({input_names}): def {elemwise_fn_name}({input_names}):
...@@ -574,7 +583,11 @@ def create_index_func(node, objmode=False): ...@@ -574,7 +583,11 @@ def create_index_func(node, objmode=False):
) )
index_start_idx = 1 + int(set_or_inc) index_start_idx = 1 + int(set_or_inc)
input_names = [v.auto_name for v in node.inputs] unique_names = unique_name_generator(
["subtensor", "incsubtensor", "z"], suffix_sep="_"
)
input_names = [unique_names(v, force_unique=True) for v in node.inputs]
op_indices = list(node.inputs[index_start_idx:]) op_indices = list(node.inputs[index_start_idx:])
idx_list = getattr(node.op, "idx_list", None) idx_list = getattr(node.op, "idx_list", None)
...@@ -756,7 +769,10 @@ def numba_funcify_AllocEmpty(op, node, **kwargs): ...@@ -756,7 +769,10 @@ def numba_funcify_AllocEmpty(op, node, **kwargs):
global_env = {"np": np, "to_scalar": to_scalar, "dtype": op.dtype} global_env = {"np": np, "to_scalar": to_scalar, "dtype": op.dtype}
shape_var_names = [v.auto_name for v in node.inputs] unique_names = unique_name_generator(
["np", "to_scalar", "dtype", "allocempty", "scalar_shape"], suffix_sep="_"
)
shape_var_names = [unique_names(v, force_unique=True) for v in node.inputs]
shape_var_item_names = [f"{name}_item" for name in shape_var_names] shape_var_item_names = [f"{name}_item" for name in shape_var_names]
shapes_to_items_src = indent( shapes_to_items_src = indent(
"\n".join( "\n".join(
...@@ -785,7 +801,11 @@ def numba_funcify_Alloc(op, node, **kwargs): ...@@ -785,7 +801,11 @@ def numba_funcify_Alloc(op, node, **kwargs):
global_env = {"np": np, "to_scalar": to_scalar} global_env = {"np": np, "to_scalar": to_scalar}
shape_var_names = [v.auto_name for v in node.inputs[1:]] unique_names = unique_name_generator(
["np", "to_scalar", "alloc", "val_np", "val", "scalar_shape", "res"],
suffix_sep="_",
)
shape_var_names = [unique_names(v, force_unique=True) for v in node.inputs[1:]]
shape_var_item_names = [f"{name}_item" for name in shape_var_names] shape_var_item_names = [f"{name}_item" for name in shape_var_names]
shapes_to_items_src = indent( shapes_to_items_src = indent(
"\n".join( "\n".join(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论