提交 65ae5df3 authored 作者: Shi Fan's avatar Shi Fan 提交者: Brandon T. Willard

Remove intermediate function call in numba_funcify_Elemwise

上级 da28a260
import inspect
import operator
import warnings
from functools import reduce, singledispatch
......@@ -438,34 +439,22 @@ def create_vectorize_func(op, node, use_signature=False, identity=None, **kwargs
signature = []
numba_vectorize = numba.vectorize(signature, identity=identity)
global_env = {"scalar_op": scalar_op_fn, "numba_vectorize": numba_vectorize}
elemwise_fn = numba_vectorize(scalar_op_fn)
elemwise_fn.py_scalar_func = scalar_op_fn
elemwise_fn_name = f"elemwise_{get_name_for_object(scalar_op_fn)}"
unique_names = unique_name_generator(
[elemwise_fn_name, "scalar_op", "scalar_op", "numba_vectorize"], suffix_sep="_"
)
input_names = [unique_names(v, force_unique=True) for v in node.inputs]
input_signature_str = ", ".join(input_names)
elemwise_src = f"""
@numba_vectorize
def {elemwise_fn_name}({input_signature_str}):
return scalar_op({input_signature_str})
"""
elemwise_fn = compile_function_src(elemwise_src, elemwise_fn_name, global_env)
return elemwise_fn, input_names
return elemwise_fn
@numba_funcify.register(Elemwise)
def numba_funcify_Elemwise(op, node, **kwargs):
elemwise_fn, input_names = create_vectorize_func(op, node, use_signature=False)
elemwise_fn = create_vectorize_func(op, node, use_signature=False)
elemwise_fn_name = elemwise_fn.__name__
if op.inplace_pattern:
sign_obj = inspect.signature(elemwise_fn.py_scalar_func)
input_names = list(sign_obj.parameters.keys())
input_idx = op.inplace_pattern[0]
updated_input_name = input_names[input_idx]
......@@ -481,7 +470,7 @@ def {inplace_elemwise_fn_name}({input_signature_str}):
inplace_elemwise_fn = compile_function_src(
inplace_elemwise_src, inplace_elemwise_fn_name, inplace_global_env
)
return numba.njit(inline="always")(inplace_elemwise_fn)
return numba.njit(inplace_elemwise_fn)
return elemwise_fn
......@@ -626,9 +615,7 @@ def numba_funcify_CAReduce(op, node, **kwargs):
)
# TODO: Use `scalar_op_identity`?
elemwise_fn, *_ = create_vectorize_func(
op, dummy_node, use_signature=True, **kwargs
)
elemwise_fn = create_vectorize_func(op, dummy_node, use_signature=True, **kwargs)
input_name = get_name_for_object(node.inputs[0])
ndim = node.inputs[0].ndim
......
......@@ -588,7 +588,9 @@ def compile_function_src(src, function_name, global_env=None, local_env=None):
mod_code = compile(src, filename, mode="exec")
exec(mod_code, global_env, local_env)
return local_env[function_name]
res = local_env[function_name]
res.__source__ = src
return res
def get_name_for_object(x: Any):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论