提交 d0c7808a authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Simplify in2out and out2in construction

上级 02910f82
...@@ -15,7 +15,7 @@ import traceback ...@@ -15,7 +15,7 @@ import traceback
import warnings import warnings
from collections import OrderedDict, UserList, defaultdict, deque from collections import OrderedDict, UserList, defaultdict, deque
from collections.abc import Iterable from collections.abc import Iterable
from functools import reduce from functools import partial, reduce
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
...@@ -2172,34 +2172,11 @@ class TopoOptimizer(NavigatorOptimizer): ...@@ -2172,34 +2172,11 @@ class TopoOptimizer(NavigatorOptimizer):
return getattr(self, "__name__", "<TopoOptimizer instance>") return getattr(self, "__name__", "<TopoOptimizer instance>")
def out2in(*local_opts, **kwargs): def topogroup_optimizer(order, *local_opts, name=None, **kwargs):
""" """Apply `local_opts` from the input/output nodes to the output/input nodes of a graph.
Uses the TopoOptimizer from the output nodes to input nodes of the graph.
"""
name = kwargs and kwargs.pop("name", None)
if len(local_opts) > 1:
# Don't wrap it uselessly if their is only 1 optimization.
local_opts = LocalOptGroup(*local_opts)
else:
(local_opts,) = local_opts
if not name:
name = local_opts.__name__
ret = TopoOptimizer(
local_opts,
order="out_to_in",
failure_callback=TopoOptimizer.warn_inplace,
**kwargs,
)
if name:
ret.__name__ = name
return ret
def in2out(*local_opts, **kwargs): This uses a combination of `LocalOptGroup` and `TopoOptimizer`.
"""
Uses the TopoOptimizer from the input nodes to output nodes of the graph.
""" """
name = kwargs and kwargs.pop("name", None)
if len(local_opts) > 1: if len(local_opts) > 1:
# Don't wrap it uselessly if their is only 1 optimization. # Don't wrap it uselessly if their is only 1 optimization.
local_opts = LocalOptGroup(*local_opts) local_opts = LocalOptGroup(*local_opts)
...@@ -2218,6 +2195,10 @@ def in2out(*local_opts, **kwargs): ...@@ -2218,6 +2195,10 @@ def in2out(*local_opts, **kwargs):
return ret return ret
in2out = partial(topogroup_optimizer, "in_to_out")
out2in = partial(topogroup_optimizer, "out_to_in")
class OpKeyOptimizer(NavigatorOptimizer): class OpKeyOptimizer(NavigatorOptimizer):
r"""An optimizer that applies a `LocalOptimizer` to specific `Op`\s. r"""An optimizer that applies a `LocalOptimizer` to specific `Op`\s.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论