提交 b84ac43a authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Call axis reducer functions directly in create_multiaxis_reducer

上级 3a3adaee
...@@ -538,19 +538,21 @@ def create_multiaxis_reducer( ...@@ -538,19 +538,21 @@ def create_multiaxis_reducer(
reduce_fn, identity, axes, ndim, dtype, input_name="input" reduce_fn, identity, axes, ndim, dtype, input_name="input"
): ):
careduce_fn_name = f"careduce_{get_name_for_object(reduce_fn)}" careduce_fn_name = f"careduce_{get_name_for_object(reduce_fn)}"
careduce_axes_fns = () global_env = {}
to_reduce = reversed(sorted(axes)) to_reduce = reversed(sorted(axes))
careduce_lines_src = [] careduce_lines_src = []
var_name = input_name var_name = input_name
for i, axis in enumerate(to_reduce): for i, axis in enumerate(to_reduce):
careduce_axes_fns += ( careducer_axes_fn_name = f"careduce_axes_fn_{i}"
create_axis_reducer(reduce_fn, identity, axis - i, ndim, dtype), global_env[careducer_axes_fn_name] = create_axis_reducer(
reduce_fn, identity, axis - i, ndim, dtype
) )
ndim -= 1 ndim -= 1
last_var_name = var_name last_var_name = var_name
var_name = f"axis_{i}_res" var_name = f"axis_{i}_res"
careduce_lines_src.append( careduce_lines_src.append(
f"{var_name} = careduce_axes_fns[{i}]({last_var_name})" f"{var_name} = {careducer_axes_fn_name}({last_var_name})"
) )
careduce_assign_lines = indent("\n".join(careduce_lines_src), " " * 4) careduce_assign_lines = indent("\n".join(careduce_lines_src), " " * 4)
...@@ -560,8 +562,6 @@ def {careduce_fn_name}({input_name}): ...@@ -560,8 +562,6 @@ def {careduce_fn_name}({input_name}):
return {var_name} return {var_name}
""" """
global_env = {"careduce_axes_fns": careduce_axes_fns}
careduce_fn = compile_function_src(careduce_def_src, careduce_fn_name, global_env) careduce_fn = compile_function_src(careduce_def_src, careduce_fn_name, global_env)
return careduce_fn return careduce_fn
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论