提交 e6635af8 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename OpKeyOptimizer to OpKeyGraphRewriter

上级 9546f164
...@@ -2143,14 +2143,14 @@ in2out = partial(walking_rewriter, "in_to_out") ...@@ -2143,14 +2143,14 @@ in2out = partial(walking_rewriter, "in_to_out")
out2in = partial(walking_rewriter, "out_to_in") out2in = partial(walking_rewriter, "out_to_in")
class OpKeyOptimizer(NodeProcessingGraphRewriter): class OpKeyGraphRewriter(NodeProcessingGraphRewriter):
r"""An optimizer that applies a `NodeRewriter` to specific `Op`\s. r"""A rewriter that applies a `NodeRewriter` to specific `Op`\s.
The `Op`\s are provided by a :meth:`NodeRewriter.op_key` method (either The `Op`\s are provided by a :meth:`NodeRewriter.op_key` method (either
as a list of `Op`\s or a single `Op`), and discovered within a as a list of `Op`\s or a single `Op`), and discovered within a
`FunctionGraph` using the `NodeFinder` `Feature`. `FunctionGraph` using the `NodeFinder` `Feature`.
This is similar to the ``tracks`` feature used by other optimizers. This is similar to the `Op`-based tracking feature used by other rewriters.
""" """
...@@ -3230,6 +3230,11 @@ DEPRECATED_NAMES = [ ...@@ -3230,6 +3230,11 @@ DEPRECATED_NAMES = [
"`topogroup_optimizer` is deprecated: use `walking_rewriter` instead.", "`topogroup_optimizer` is deprecated: use `walking_rewriter` instead.",
walking_rewriter, walking_rewriter,
), ),
(
"OpKeyOptimizer",
"`OpKeyOptimizer` is deprecated: use `OpKeyGraphRewriter` instead.",
OpKeyGraphRewriter,
),
] ]
......
...@@ -13,7 +13,7 @@ from aesara.compile.io import In, Out ...@@ -13,7 +13,7 @@ from aesara.compile.io import In, Out
from aesara.compile.mode import Mode, get_default_mode from aesara.compile.mode import Mode, get_default_mode
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Constant from aesara.graph.basic import Constant
from aesara.graph.opt import OpKeyOptimizer, PatternNodeRewriter from aesara.graph.opt import OpKeyGraphRewriter, PatternNodeRewriter
from aesara.graph.utils import MissingInputError from aesara.graph.utils import MissingInputError
from aesara.link.vm import VMLinker from aesara.link.vm import VMLinker
from aesara.tensor.math import dot from aesara.tensor.math import dot
...@@ -35,7 +35,7 @@ from aesara.utils import exc_message ...@@ -35,7 +35,7 @@ from aesara.utils import exc_message
def PatternOptimizer(p1, p2, ign=True): def PatternOptimizer(p1, p2, ign=True):
return OpKeyOptimizer(PatternNodeRewriter(p1, p2), ignore_newtrees=ign) return OpKeyGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign)
class TestFunction: class TestFunction:
......
...@@ -10,7 +10,7 @@ from aesara.graph.fg import FunctionGraph ...@@ -10,7 +10,7 @@ from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.opt import ( from aesara.graph.opt import (
NodeProcessingGraphRewriter, NodeProcessingGraphRewriter,
OpKeyOptimizer, OpKeyGraphRewriter,
PatternNodeRewriter, PatternNodeRewriter,
SubstitutionNodeRewriter, SubstitutionNodeRewriter,
WalkingGraphRewriter, WalkingGraphRewriter,
...@@ -21,7 +21,7 @@ from tests.unittest_tools import assertFailure_fast ...@@ -21,7 +21,7 @@ from tests.unittest_tools import assertFailure_fast
def PatternOptimizer(p1, p2, ign=True): def PatternOptimizer(p1, p2, ign=True):
return OpKeyOptimizer(PatternNodeRewriter(p1, p2), ignore_newtrees=ign) return OpKeyGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign)
def TopoSubstitutionNodeRewriter( def TopoSubstitutionNodeRewriter(
......
...@@ -8,7 +8,7 @@ from aesara.graph.op import Op ...@@ -8,7 +8,7 @@ from aesara.graph.op import Op
from aesara.graph.opt import ( from aesara.graph.opt import (
EquilibriumOptimizer, EquilibriumOptimizer,
MergeOptimizer, MergeOptimizer,
OpKeyOptimizer, OpKeyGraphRewriter,
OpToRewriterTracker, OpToRewriterTracker,
PatternNodeRewriter, PatternNodeRewriter,
SequentialNodeRewriter, SequentialNodeRewriter,
...@@ -51,7 +51,7 @@ class AssertNoChanges(Feature): ...@@ -51,7 +51,7 @@ class AssertNoChanges(Feature):
def PatternOptimizer(p1, p2, ign=False): def PatternOptimizer(p1, p2, ign=False):
return OpKeyOptimizer(PatternNodeRewriter(p1, p2), ignore_newtrees=ign) return OpKeyGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign)
def TopoPatternOptimizer(p1, p2, ign=True): def TopoPatternOptimizer(p1, p2, ign=True):
...@@ -224,7 +224,7 @@ class TestPatternOptimizer: ...@@ -224,7 +224,7 @@ class TestPatternOptimizer:
def KeyedSubstitutionNodeRewriter(op1, op2): def KeyedSubstitutionNodeRewriter(op1, op2):
return OpKeyOptimizer(SubstitutionNodeRewriter(op1, op2)) return OpKeyGraphRewriter(SubstitutionNodeRewriter(op1, op2))
class TestSubstitutionNodeRewriter: class TestSubstitutionNodeRewriter:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论