提交 d0c7808a authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Simplify in2out and out2in construction

上级 02910f82
......@@ -15,7 +15,7 @@ import traceback
import warnings
from collections import OrderedDict, UserList, defaultdict, deque
from collections.abc import Iterable
from functools import reduce
from functools import partial, reduce
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
......@@ -2172,34 +2172,11 @@ class TopoOptimizer(NavigatorOptimizer):
return getattr(self, "__name__", "<TopoOptimizer instance>")
def out2in(*local_opts, **kwargs):
"""
Uses the TopoOptimizer from the output nodes to input nodes of the graph.
"""
name = kwargs and kwargs.pop("name", None)
if len(local_opts) > 1:
# Don't wrap it uselessly if their is only 1 optimization.
local_opts = LocalOptGroup(*local_opts)
else:
(local_opts,) = local_opts
if not name:
name = local_opts.__name__
ret = TopoOptimizer(
local_opts,
order="out_to_in",
failure_callback=TopoOptimizer.warn_inplace,
**kwargs,
)
if name:
ret.__name__ = name
return ret
def topogroup_optimizer(order, *local_opts, name=None, **kwargs):
"""Apply `local_opts` from the input/output nodes to the output/input nodes of a graph.
def in2out(*local_opts, **kwargs):
"""
Uses the TopoOptimizer from the input nodes to output nodes of the graph.
This uses a combination of `LocalOptGroup` and `TopoOptimizer`.
"""
name = kwargs and kwargs.pop("name", None)
if len(local_opts) > 1:
# Don't wrap it uselessly if their is only 1 optimization.
local_opts = LocalOptGroup(*local_opts)
......@@ -2218,6 +2195,10 @@ def in2out(*local_opts, **kwargs):
return ret
in2out = partial(topogroup_optimizer, "in_to_out")
out2in = partial(topogroup_optimizer, "out_to_in")
class OpKeyOptimizer(NavigatorOptimizer):
r"""An optimizer that applies a `LocalOptimizer` to specific `Op`\s.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论