提交 ba436b05 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use get_name_for_object in numba_funcify_ScalarOp and numba_funcify_Elemwise

上级 b3f294a1
...@@ -16,7 +16,11 @@ from aesara.compile.ops import DeepCopyOp, ViewOp ...@@ -16,7 +16,11 @@ from aesara.compile.ops import DeepCopyOp, ViewOp
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.type import Type from aesara.graph.type import Type
from aesara.link.utils import compile_function_src, fgraph_to_python from aesara.link.utils import (
compile_function_src,
fgraph_to_python,
get_name_for_object,
)
from aesara.scalar.basic import ( from aesara.scalar.basic import (
Cast, Cast,
Clip, Clip,
...@@ -205,7 +209,7 @@ def numba_funcify_ScalarOp(op, node, **kwargs): ...@@ -205,7 +209,7 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
global_env = {"scalar_func": scalar_func} global_env = {"scalar_func": scalar_func}
scalar_op_fn_name = scalar_func.__name__ scalar_op_fn_name = get_name_for_object(scalar_func)
scalar_op_src = f""" scalar_op_src = f"""
def {scalar_op_fn_name}({input_names}): def {scalar_op_fn_name}({input_names}):
return scalar_func({input_names}) return scalar_func({input_names})
...@@ -223,7 +227,7 @@ def numba_funcify_Elemwise(op, node, **kwargs): ...@@ -223,7 +227,7 @@ def numba_funcify_Elemwise(op, node, **kwargs):
global_env = {"scalar_op": scalar_op_fn, "numba_vectorize": numba.vectorize} global_env = {"scalar_op": scalar_op_fn, "numba_vectorize": numba.vectorize}
elemwise_fn_name = f"elemwise_{scalar_op_fn.__name__}" elemwise_fn_name = f"elemwise_{get_name_for_object(scalar_op_fn)}"
elemwise_src = f""" elemwise_src = f"""
@numba_vectorize @numba_vectorize
def {elemwise_fn_name}({input_names}): def {elemwise_fn_name}({input_names}):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论