提交 40296322 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename LocalOptTracker to OpToRewriterTracker

上级 2929f425
...@@ -1152,8 +1152,8 @@ def node_rewriter( ...@@ -1152,8 +1152,8 @@ def node_rewriter(
return decorator return decorator
class LocalOptTracker: class OpToRewriterTracker:
r"""A container that maps rewrites to `Op` instances and `Op`-type inheritance.""" r"""A container that maps `NodeRewriter`\s to `Op` instances and `Op`-type inheritance."""
def __init__(self): def __init__(self):
self.tracked_instances: Dict[Op, List[NodeRewriter]] = {} self.tracked_instances: Dict[Op, List[NodeRewriter]] = {}
...@@ -1256,7 +1256,7 @@ class LocalOptGroup(NodeRewriter): ...@@ -1256,7 +1256,7 @@ class LocalOptGroup(NodeRewriter):
self.applied_true: Dict[Rewriter, int] = {} self.applied_true: Dict[Rewriter, int] = {}
self.node_created: Dict[Rewriter, int] = {} self.node_created: Dict[Rewriter, int] = {}
self.tracker = LocalOptTracker() self.tracker = OpToRewriterTracker()
for o in self.opts: for o in self.opts:
...@@ -2246,7 +2246,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2246,7 +2246,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
self.global_optimizers: List[GraphRewriter] = [] self.global_optimizers: List[GraphRewriter] = []
self.tracks_on_change_inputs = tracks_on_change_inputs self.tracks_on_change_inputs = tracks_on_change_inputs
self.node_tracker = LocalOptTracker() self.node_tracker = OpToRewriterTracker()
for opt in optimizers: for opt in optimizers:
if isinstance(opt, NodeRewriter): if isinstance(opt, NodeRewriter):
...@@ -3163,6 +3163,11 @@ DEPRECATED_NAMES = [ ...@@ -3163,6 +3163,11 @@ DEPRECATED_NAMES = [
"`FromFunctionLocalOptimizer` is deprecated: use `FromFunctionNodeRewriter` instead.", "`FromFunctionLocalOptimizer` is deprecated: use `FromFunctionNodeRewriter` instead.",
FromFunctionNodeRewriter, FromFunctionNodeRewriter,
), ),
(
"LocalOptTracker",
"`LocalOptTracker` is deprecated: use `OpToRewriterTracker` instead.",
OpToRewriterTracker,
),
] ]
......
...@@ -8,10 +8,10 @@ from aesara.graph.op import Op ...@@ -8,10 +8,10 @@ from aesara.graph.op import Op
from aesara.graph.opt import ( from aesara.graph.opt import (
EquilibriumOptimizer, EquilibriumOptimizer,
LocalOptGroup, LocalOptGroup,
LocalOptTracker,
MergeOptimizer, MergeOptimizer,
OpKeyOptimizer, OpKeyOptimizer,
OpSub, OpSub,
OpToRewriterTracker,
PatternSub, PatternSub,
TopoOptimizer, TopoOptimizer,
in2out, in2out,
...@@ -766,7 +766,7 @@ def test_node_rewriter(): ...@@ -766,7 +766,7 @@ def test_node_rewriter():
assert hits[0] == 2 assert hits[0] == 2
def test_TrackingNodeRewriter(): def test_OpToRewriterTracker():
@node_rewriter(None) @node_rewriter(None)
def local_opt_1(fgraph, node): def local_opt_1(fgraph, node):
pass pass
...@@ -787,7 +787,7 @@ def test_TrackingNodeRewriter(): ...@@ -787,7 +787,7 @@ def test_TrackingNodeRewriter():
def local_opt_5(fgraph, node): def local_opt_5(fgraph, node):
pass pass
tracker = LocalOptTracker() tracker = OpToRewriterTracker()
tracker.add_tracker(local_opt_1) tracker.add_tracker(local_opt_1)
tracker.add_tracker(local_opt_2) tracker.add_tracker(local_opt_2)
tracker.add_tracker(local_opt_3) tracker.add_tracker(local_opt_3)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论