提交 d6fecc46 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove unnecessary function call when there's only one reduce axis

上级 5213962b
......@@ -581,6 +581,9 @@ def create_axis_reducer(
def create_multiaxis_reducer(
reduce_fn, identity, axes, ndim, dtype, input_name="input"
):
if len(axes) == 1:
return create_axis_reducer(reduce_fn, identity, axes[0], ndim, dtype)
careduce_fn_name = f"careduce_{get_name_for_object(reduce_fn)}"
global_env = {}
to_reduce = reversed(sorted(axes))
......@@ -607,7 +610,7 @@ def {careduce_fn_name}({input_name}):
"""
careduce_fn = compile_function_src(careduce_def_src, careduce_fn_name, global_env)
return careduce_fn
return numba.njit(careduce_fn)
@numba_funcify.register(CAReduce)
......@@ -645,7 +648,7 @@ def numba_funcify_CAReduce(op, node, **kwargs):
elemwise_fn, scalar_op_identity, axes, ndim, np_acc_dtype, input_name=input_name
)
return numba.njit(careduce_fn)
return careduce_fn
@numba_funcify.register(Composite)
......@@ -1635,8 +1638,8 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs):
# work-around
keep_axes = tuple(i for i in range(x_ndim) if i not in axes)
reduce_max = numba.njit(
create_multiaxis_reducer(np.maximum, -np.inf, axes, x_ndim, x_dtype)
reduce_max = create_multiaxis_reducer(
np.maximum, -np.inf, axes, x_ndim, x_dtype
)
reduced_x_ndim = x_ndim - len(axes) + 1
argmax_axis = create_axis_apply_fn(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论