提交 0b3e6296 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use numba.vectorize for Elemwise conversions

上级 43829581
...@@ -49,7 +49,7 @@ def numba_funcify_FunctionGraph( ...@@ -49,7 +49,7 @@ def numba_funcify_FunctionGraph(
@numba_funcify.register(ScalarOp) @numba_funcify.register(ScalarOp)
def numba_funcify_ScalarOp(op, **kwargs): def numba_funcify_ScalarOp(op, node, **kwargs):
scalar_func_name = op.nfunc_spec[0] scalar_func_name = op.nfunc_spec[0]
...@@ -64,23 +64,40 @@ def numba_funcify_ScalarOp(op, **kwargs): ...@@ -64,23 +64,40 @@ def numba_funcify_ScalarOp(op, **kwargs):
else: else:
scalar_func = getattr(func_package, scalar_func_name) scalar_func = getattr(func_package, scalar_func_name)
@numba.njit input_names = ", ".join([v.auto_name for v in node.inputs])
def scalar_op(*args):
return scalar_func(*args) global_env = {"scalar_func": scalar_func}
return scalar_op scalar_op_src = f"""
def scalar_op({input_names}):
return scalar_func({input_names})
"""
scalar_op_fn = compile_function_src(scalar_op_src, "scalar_op", global_env)
return numba.njit(scalar_op_fn)
@numba_funcify.register(Elemwise) @numba_funcify.register(Elemwise)
def numba_funcify_Elemwise(op, **kwargs): def numba_funcify_Elemwise(op, node, **kwargs):
scalar_op = op.scalar_op scalar_op_fn = numba_funcify(op.scalar_op, node, **kwargs)
# TODO: Vectorize this
return numba_funcify(scalar_op) input_names = ", ".join([v.auto_name for v in node.inputs])
global_env = {"scalar_op": scalar_op_fn, "vectorize": numba.vectorize}
elemwise_src = f"""
@vectorize
def elemwise({input_names}):
return scalar_op({input_names})
"""
elemwise_fn = compile_function_src(elemwise_src, "elemwise", global_env)
return elemwise_fn
@numba_funcify.register(Composite) @numba_funcify.register(Composite)
def numba_funcify_Composite(op, vectorize=True, **kwargs): def numba_funcify_Composite(op, vectorize=True, **kwargs):
numba_impl = numba.njit(numba_funcify(op.fgraph)) numba_impl = numba.njit(numba_funcify(op.fgraph, **kwargs))
@numba.njit @numba.njit
def composite(*args): def composite(*args):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论