提交 66974655 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename LocalOptGroup to SequentialNodeRewriter

上级 40296322
...@@ -1204,7 +1204,7 @@ class OpToRewriterTracker: ...@@ -1204,7 +1204,7 @@ class OpToRewriterTracker:
) )
class LocalOptGroup(NodeRewriter): class SequentialNodeRewriter(NodeRewriter):
r"""An optimizer that applies a list of `NodeRewriter`\s to a node. r"""An optimizer that applies a list of `NodeRewriter`\s to a node.
Attributes Attributes
...@@ -1272,7 +1272,7 @@ class LocalOptGroup(NodeRewriter): ...@@ -1272,7 +1272,7 @@ class LocalOptGroup(NodeRewriter):
return getattr( return getattr(
self, self,
"__name__", "__name__",
f"LocalOptGroup({','.join([str(o) for o in self.opts])})", f"{type(self).__name__}({','.join([str(o) for o in self.opts])})",
) )
def tracks(self): def tracks(self):
...@@ -1332,15 +1332,15 @@ class LocalOptGroup(NodeRewriter): ...@@ -1332,15 +1332,15 @@ class LocalOptGroup(NodeRewriter):
repl = new_repl repl = new_repl
node = new_vars[0].owner node = new_vars[0].owner
@staticmethod @classmethod
def print_profile(stream, prof, level=0): def print_profile(cls, stream, prof, level=0):
(time_opts, process_count, applied_true, node_created, profile) = prof (time_opts, process_count, applied_true, node_created, profile) = prof
if not profile: if not profile:
return return
blanc = " " * int(level) blanc = " " * int(level)
print(blanc, "LocalOptGroup", file=stream) print(blanc, f"{cls.__name__}", file=stream)
print(blanc, "---------------------", file=stream) print(blanc, "---------------------", file=stream)
count_opt = [] count_opt = []
not_used = [] not_used = []
...@@ -2064,7 +2064,7 @@ class TopoOptimizer(NavigatorOptimizer): ...@@ -2064,7 +2064,7 @@ class TopoOptimizer(NavigatorOptimizer):
print(blanc, " init io_toposort", io_t, file=stream) print(blanc, " init io_toposort", io_t, file=stream)
print(blanc, " loop time", loop_t, file=stream) print(blanc, " loop time", loop_t, file=stream)
print(blanc, " callback_time", callback_time, file=stream) print(blanc, " callback_time", callback_time, file=stream)
if isinstance(node_rewriter, LocalOptGroup): if isinstance(node_rewriter, SequentialNodeRewriter):
if node_rewriter.profile: if node_rewriter.profile:
node_rewriter.print_profile( node_rewriter.print_profile(
stream, stream,
...@@ -2089,14 +2089,14 @@ def topogroup_optimizer( ...@@ -2089,14 +2089,14 @@ def topogroup_optimizer(
failure_callback=TopoOptimizer.warn_inplace, failure_callback=TopoOptimizer.warn_inplace,
**kwargs, **kwargs,
): ):
"""Apply `node_rewriters` from the input/output nodes to the output/input nodes of a graph. r"""Apply `node_rewriters` from the input/output nodes to the output/input nodes of a graph.
This constructs `TopoOptimizer`s, and uses a `LocalOptGroup` when there's This constructs `TopoOptimizer`\s, and uses a `SequentialNodeRewriter` when there's
more than one entry in `node_rewriters`. more than one entry in `node_rewriters`.
""" """
if len(node_rewriters) > 1: if len(node_rewriters) > 1:
# Don't wrap it uselessly if their is only 1 optimization. # Don't wrap it uselessly if their is only 1 optimization.
node_rewriters = LocalOptGroup(*node_rewriters) node_rewriters = SequentialNodeRewriter(*node_rewriters)
else: else:
(node_rewriters,) = node_rewriters (node_rewriters,) = node_rewriters
if not name: if not name:
...@@ -3168,6 +3168,11 @@ DEPRECATED_NAMES = [ ...@@ -3168,6 +3168,11 @@ DEPRECATED_NAMES = [
"`LocalOptTracker` is deprecated: use `OpToRewriterTracker` instead.", "`LocalOptTracker` is deprecated: use `OpToRewriterTracker` instead.",
OpToRewriterTracker, OpToRewriterTracker,
), ),
(
"LocalOptGroup",
"`LocalOptGroup` is deprecated: use `SequentialNodeRewriter` instead.",
SequentialNodeRewriter,
),
] ]
......
...@@ -457,13 +457,13 @@ class SequenceDB(OptimizationDatabase): ...@@ -457,13 +457,13 @@ class SequenceDB(OptimizationDatabase):
class LocalGroupDB(SequenceDB): class LocalGroupDB(SequenceDB):
r"""A database that generates `NodeRewriter`\s of type `LocalOptGroup`.""" r"""A database that generates `NodeRewriter`\s of type `SequentialNodeRewriter`."""
def __init__( def __init__(
self, self,
apply_all_opts: bool = False, apply_all_opts: bool = False,
profile: bool = False, profile: bool = False,
node_rewriter=aesara_opt.LocalOptGroup, node_rewriter=aesara_opt.SequentialNodeRewriter,
): ):
super().__init__(failure_callback=None) super().__init__(failure_callback=None)
self.apply_all_opts = apply_all_opts self.apply_all_opts = apply_all_opts
......
...@@ -10,9 +10,9 @@ import aesara.scalar.basic as aes ...@@ -10,9 +10,9 @@ import aesara.scalar.basic as aes
import aesara.scalar.math as aes_math import aesara.scalar.math as aes_math
from aesara.graph.basic import Constant, Variable from aesara.graph.basic import Constant, Variable
from aesara.graph.opt import ( from aesara.graph.opt import (
LocalOptGroup,
NodeRewriter, NodeRewriter,
PatternSub, PatternSub,
SequentialNodeRewriter,
copy_stack_trace, copy_stack_trace,
in2out, in2out,
node_rewriter, node_rewriter,
...@@ -2117,7 +2117,7 @@ def local_add_specialize(fgraph, node): ...@@ -2117,7 +2117,7 @@ def local_add_specialize(fgraph, node):
mul_canonizer = in2out( mul_canonizer = in2out(
LocalOptGroup(local_mul_canonizer, local_fill_sink, apply_all_opts=True), SequentialNodeRewriter(local_mul_canonizer, local_fill_sink, apply_all_opts=True),
name="mul_canonizer_groups", name="mul_canonizer_groups",
) )
...@@ -2344,7 +2344,7 @@ def add_calculate(num, denum, aslist=False, out_type=None): ...@@ -2344,7 +2344,7 @@ def add_calculate(num, denum, aslist=False, out_type=None):
local_add_canonizer = AlgebraicCanonizer(add, sub, neg, add_calculate) local_add_canonizer = AlgebraicCanonizer(add, sub, neg, add_calculate)
add_canonizer = in2out( add_canonizer = in2out(
LocalOptGroup(local_add_canonizer, local_fill_sink, apply_all_opts=True), SequentialNodeRewriter(local_add_canonizer, local_fill_sink, apply_all_opts=True),
name="add_canonizer_group", name="add_canonizer_group",
) )
......
...@@ -7,12 +7,12 @@ from aesara.graph.fg import FunctionGraph ...@@ -7,12 +7,12 @@ from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.opt import ( from aesara.graph.opt import (
EquilibriumOptimizer, EquilibriumOptimizer,
LocalOptGroup,
MergeOptimizer, MergeOptimizer,
OpKeyOptimizer, OpKeyOptimizer,
OpSub, OpSub,
OpToRewriterTracker, OpToRewriterTracker,
PatternSub, PatternSub,
SequentialNodeRewriter,
TopoOptimizer, TopoOptimizer,
in2out, in2out,
logging, logging,
...@@ -664,7 +664,7 @@ def test_patternsub_different_output_lengths(): ...@@ -664,7 +664,7 @@ def test_patternsub_different_output_lengths():
assert fgraph.outputs[0].owner.op == op1 assert fgraph.outputs[0].owner.op == op1
class TestLocalOptGroup: class TestSequentialNodeRewriter:
def test_optimizer_verbose(self, capsys): def test_optimizer_verbose(self, capsys):
x = MyVariable("x") x = MyVariable("x")
...@@ -685,7 +685,7 @@ class TestLocalOptGroup: ...@@ -685,7 +685,7 @@ class TestLocalOptGroup:
res = op2(x, *node.inputs[1:]) res = op2(x, *node.inputs[1:])
return [res] return [res]
opt_group = LocalOptGroup(local_opt_1, local_opt_2) opt_group = SequentialNodeRewriter(local_opt_1, local_opt_2)
with config.change_flags(optimizer_verbose=True): with config.change_flags(optimizer_verbose=True):
(new_res,) = opt_group.transform(fgraph, o1.owner) (new_res,) = opt_group.transform(fgraph, o1.owner)
......
...@@ -19,7 +19,7 @@ from aesara.configdefaults import config ...@@ -19,7 +19,7 @@ from aesara.configdefaults import config
from aesara.graph.basic import Apply, Constant, equal_computations from aesara.graph.basic import Apply, Constant, equal_computations
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import ( from aesara.graph.opt import (
LocalOptGroup, SequentialNodeRewriter,
TopoOptimizer, TopoOptimizer,
check_stack_trace, check_stack_trace,
in2out, in2out,
...@@ -191,7 +191,7 @@ class TestGreedyDistribute: ...@@ -191,7 +191,7 @@ class TestGreedyDistribute:
g = FunctionGraph([a, b, c, d, x, y, z], [e]) g = FunctionGraph([a, b, c, d, x, y, z], [e])
mul_canonizer.optimize(g) mul_canonizer.optimize(g)
TopoOptimizer( TopoOptimizer(
LocalOptGroup(local_greedy_distributor), order="out_to_in" SequentialNodeRewriter(local_greedy_distributor), order="out_to_in"
).optimize(g) ).optimize(g)
assert str(pprint(g.outputs[0])) == "((a * x) + (b * z))" assert str(pprint(g.outputs[0])) == "((a * x) + (b * z))"
...@@ -200,7 +200,7 @@ class TestGreedyDistribute: ...@@ -200,7 +200,7 @@ class TestGreedyDistribute:
g = FunctionGraph([a, b, x], [e]) g = FunctionGraph([a, b, x], [e])
mul_canonizer.optimize(g) mul_canonizer.optimize(g)
TopoOptimizer( TopoOptimizer(
LocalOptGroup(local_greedy_distributor), order="out_to_in" SequentialNodeRewriter(local_greedy_distributor), order="out_to_in"
).optimize(g) ).optimize(g)
assert str(pprint(g.outputs[0])) == "(a + (b * x))" assert str(pprint(g.outputs[0])) == "(a + (b * x))"
...@@ -3053,7 +3053,7 @@ class TestLocalErfc: ...@@ -3053,7 +3053,7 @@ class TestLocalErfc:
fg = FunctionGraph(inputs, [no_match], clone=False) fg = FunctionGraph(inputs, [no_match], clone=False)
TopoOptimizer( TopoOptimizer(
LocalOptGroup(local_grad_log_erfc_neg), order="out_to_in" SequentialNodeRewriter(local_grad_log_erfc_neg), order="out_to_in"
).optimize(fg) ).optimize(fg)
# Make sure that the graph hasn't been changed # Make sure that the graph hasn't been changed
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论