提交 cac73bae authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make topogroup_optimizer use order argument

上级 be222f0c
...@@ -2038,10 +2038,13 @@ class TopoOptimizer(NavigatorOptimizer): ...@@ -2038,10 +2038,13 @@ class TopoOptimizer(NavigatorOptimizer):
return getattr(self, "__name__", "<TopoOptimizer instance>") return getattr(self, "__name__", "<TopoOptimizer instance>")
def topogroup_optimizer(order, *local_opts, name=None, **kwargs): def topogroup_optimizer(
order, *local_opts, name=None, failure_callback=TopoOptimizer.warn_inplace, **kwargs
):
"""Apply `local_opts` from the input/output nodes to the output/input nodes of a graph. """Apply `local_opts` from the input/output nodes to the output/input nodes of a graph.
This uses a combination of `LocalOptGroup` and `TopoOptimizer`. This constructs `TopoOptimizer`s, and uses a `LocalOptGroup` when there's
more than one entry in `local_opts`.
""" """
if len(local_opts) > 1: if len(local_opts) > 1:
# Don't wrap it uselessly if their is only 1 optimization. # Don't wrap it uselessly if their is only 1 optimization.
...@@ -2052,8 +2055,8 @@ def topogroup_optimizer(order, *local_opts, name=None, **kwargs): ...@@ -2052,8 +2055,8 @@ def topogroup_optimizer(order, *local_opts, name=None, **kwargs):
name = local_opts.__name__ name = local_opts.__name__
ret = TopoOptimizer( ret = TopoOptimizer(
local_opts, local_opts,
order="in_to_out", order=order,
failure_callback=TopoOptimizer.warn_inplace, failure_callback=failure_callback,
**kwargs, **kwargs,
) )
if name: if name:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论