提交 66974655 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename LocalOptGroup to SequentialNodeRewriter

上级 40296322
......@@ -1204,7 +1204,7 @@ class OpToRewriterTracker:
)
class LocalOptGroup(NodeRewriter):
class SequentialNodeRewriter(NodeRewriter):
r"""An optimizer that applies a list of `NodeRewriter`\s to a node.
Attributes
......@@ -1272,7 +1272,7 @@ class LocalOptGroup(NodeRewriter):
return getattr(
self,
"__name__",
f"LocalOptGroup({','.join([str(o) for o in self.opts])})",
f"{type(self).__name__}({','.join([str(o) for o in self.opts])})",
)
def tracks(self):
......@@ -1332,15 +1332,15 @@ class LocalOptGroup(NodeRewriter):
repl = new_repl
node = new_vars[0].owner
@staticmethod
def print_profile(stream, prof, level=0):
@classmethod
def print_profile(cls, stream, prof, level=0):
(time_opts, process_count, applied_true, node_created, profile) = prof
if not profile:
return
blanc = " " * int(level)
print(blanc, "LocalOptGroup", file=stream)
print(blanc, f"{cls.__name__}", file=stream)
print(blanc, "---------------------", file=stream)
count_opt = []
not_used = []
......@@ -2064,7 +2064,7 @@ class TopoOptimizer(NavigatorOptimizer):
print(blanc, " init io_toposort", io_t, file=stream)
print(blanc, " loop time", loop_t, file=stream)
print(blanc, " callback_time", callback_time, file=stream)
if isinstance(node_rewriter, LocalOptGroup):
if isinstance(node_rewriter, SequentialNodeRewriter):
if node_rewriter.profile:
node_rewriter.print_profile(
stream,
......@@ -2089,14 +2089,14 @@ def topogroup_optimizer(
failure_callback=TopoOptimizer.warn_inplace,
**kwargs,
):
"""Apply `node_rewriters` from the input/output nodes to the output/input nodes of a graph.
r"""Apply `node_rewriters` from the input/output nodes to the output/input nodes of a graph.
This constructs `TopoOptimizer`s, and uses a `LocalOptGroup` when there's
This constructs `TopoOptimizer`\s, and uses a `SequentialNodeRewriter` when there's
more than one entry in `node_rewriters`.
"""
if len(node_rewriters) > 1:
# Don't wrap it uselessly if their is only 1 optimization.
node_rewriters = LocalOptGroup(*node_rewriters)
node_rewriters = SequentialNodeRewriter(*node_rewriters)
else:
(node_rewriters,) = node_rewriters
if not name:
......@@ -3168,6 +3168,11 @@ DEPRECATED_NAMES = [
"`LocalOptTracker` is deprecated: use `OpToRewriterTracker` instead.",
OpToRewriterTracker,
),
(
"LocalOptGroup",
"`LocalOptGroup` is deprecated: use `SequentialNodeRewriter` instead.",
SequentialNodeRewriter,
),
]
......
......@@ -457,13 +457,13 @@ class SequenceDB(OptimizationDatabase):
class LocalGroupDB(SequenceDB):
r"""A database that generates `NodeRewriter`\s of type `LocalOptGroup`."""
r"""A database that generates `NodeRewriter`\s of type `SequentialNodeRewriter`."""
def __init__(
self,
apply_all_opts: bool = False,
profile: bool = False,
node_rewriter=aesara_opt.LocalOptGroup,
node_rewriter=aesara_opt.SequentialNodeRewriter,
):
super().__init__(failure_callback=None)
self.apply_all_opts = apply_all_opts
......
......@@ -10,9 +10,9 @@ import aesara.scalar.basic as aes
import aesara.scalar.math as aes_math
from aesara.graph.basic import Constant, Variable
from aesara.graph.opt import (
LocalOptGroup,
NodeRewriter,
PatternSub,
SequentialNodeRewriter,
copy_stack_trace,
in2out,
node_rewriter,
......@@ -2117,7 +2117,7 @@ def local_add_specialize(fgraph, node):
mul_canonizer = in2out(
LocalOptGroup(local_mul_canonizer, local_fill_sink, apply_all_opts=True),
SequentialNodeRewriter(local_mul_canonizer, local_fill_sink, apply_all_opts=True),
name="mul_canonizer_groups",
)
......@@ -2344,7 +2344,7 @@ def add_calculate(num, denum, aslist=False, out_type=None):
local_add_canonizer = AlgebraicCanonizer(add, sub, neg, add_calculate)
add_canonizer = in2out(
LocalOptGroup(local_add_canonizer, local_fill_sink, apply_all_opts=True),
SequentialNodeRewriter(local_add_canonizer, local_fill_sink, apply_all_opts=True),
name="add_canonizer_group",
)
......
......@@ -7,12 +7,12 @@ from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.opt import (
EquilibriumOptimizer,
LocalOptGroup,
MergeOptimizer,
OpKeyOptimizer,
OpSub,
OpToRewriterTracker,
PatternSub,
SequentialNodeRewriter,
TopoOptimizer,
in2out,
logging,
......@@ -664,7 +664,7 @@ def test_patternsub_different_output_lengths():
assert fgraph.outputs[0].owner.op == op1
class TestLocalOptGroup:
class TestSequentialNodeRewriter:
def test_optimizer_verbose(self, capsys):
x = MyVariable("x")
......@@ -685,7 +685,7 @@ class TestLocalOptGroup:
res = op2(x, *node.inputs[1:])
return [res]
opt_group = LocalOptGroup(local_opt_1, local_opt_2)
opt_group = SequentialNodeRewriter(local_opt_1, local_opt_2)
with config.change_flags(optimizer_verbose=True):
(new_res,) = opt_group.transform(fgraph, o1.owner)
......
......@@ -19,7 +19,7 @@ from aesara.configdefaults import config
from aesara.graph.basic import Apply, Constant, equal_computations
from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import (
LocalOptGroup,
SequentialNodeRewriter,
TopoOptimizer,
check_stack_trace,
in2out,
......@@ -191,7 +191,7 @@ class TestGreedyDistribute:
g = FunctionGraph([a, b, c, d, x, y, z], [e])
mul_canonizer.optimize(g)
TopoOptimizer(
LocalOptGroup(local_greedy_distributor), order="out_to_in"
SequentialNodeRewriter(local_greedy_distributor), order="out_to_in"
).optimize(g)
assert str(pprint(g.outputs[0])) == "((a * x) + (b * z))"
......@@ -200,7 +200,7 @@ class TestGreedyDistribute:
g = FunctionGraph([a, b, x], [e])
mul_canonizer.optimize(g)
TopoOptimizer(
LocalOptGroup(local_greedy_distributor), order="out_to_in"
SequentialNodeRewriter(local_greedy_distributor), order="out_to_in"
).optimize(g)
assert str(pprint(g.outputs[0])) == "(a + (b * x))"
......@@ -3053,7 +3053,7 @@ class TestLocalErfc:
fg = FunctionGraph(inputs, [no_match], clone=False)
TopoOptimizer(
LocalOptGroup(local_grad_log_erfc_neg), order="out_to_in"
SequentialNodeRewriter(local_grad_log_erfc_neg), order="out_to_in"
).optimize(fg)
# Make sure that the graph hasn't been changed
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论