提交 b84ac43a authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Call axis reducer functions directly in create_multiaxis_reducer

上级 3a3adaee
......@@ -538,19 +538,21 @@ def create_multiaxis_reducer(
reduce_fn, identity, axes, ndim, dtype, input_name="input"
):
careduce_fn_name = f"careduce_{get_name_for_object(reduce_fn)}"
careduce_axes_fns = ()
global_env = {}
to_reduce = reversed(sorted(axes))
careduce_lines_src = []
var_name = input_name
for i, axis in enumerate(to_reduce):
careduce_axes_fns += (
create_axis_reducer(reduce_fn, identity, axis - i, ndim, dtype),
careducer_axes_fn_name = f"careduce_axes_fn_{i}"
global_env[careducer_axes_fn_name] = create_axis_reducer(
reduce_fn, identity, axis - i, ndim, dtype
)
ndim -= 1
last_var_name = var_name
var_name = f"axis_{i}_res"
careduce_lines_src.append(
f"{var_name} = careduce_axes_fns[{i}]({last_var_name})"
f"{var_name} = {careducer_axes_fn_name}({last_var_name})"
)
careduce_assign_lines = indent("\n".join(careduce_lines_src), " " * 4)
......@@ -560,8 +562,6 @@ def {careduce_fn_name}({input_name}):
return {var_name}
"""
global_env = {"careduce_axes_fns": careduce_axes_fns}
careduce_fn = compile_function_src(careduce_def_src, careduce_fn_name, global_env)
return careduce_fn
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论