提交 3ce33a3d authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename TopoOptimizer to WalkingGraphRewriter

上级 fbe8c876
......@@ -1997,8 +1997,8 @@ class NodeProcessingGraphRewriter(GraphRewriter):
)
class TopoOptimizer(NodeProcessingGraphRewriter):
"""An optimizer that applies a single `NodeRewriter` to each node in topological order (or reverse)."""
class WalkingGraphRewriter(NodeProcessingGraphRewriter):
"""A rewriter that applies a single `NodeRewriter` to each node in topological order (or reverse)."""
def __init__(
self,
......@@ -2057,11 +2057,11 @@ class TopoOptimizer(NodeProcessingGraphRewriter):
self.node_rewriter,
)
@staticmethod
def print_profile(stream, prof, level=0):
@classmethod
def print_profile(cls, stream, prof, level=0):
blanc = " " * level
        if prof is None:  # Happens because merge_profile() isn't implemented
print(blanc, "TopoOptimizer merge_profile not implemented", file=stream)
print(blanc, f"{cls.__name__} merge_profile not implemented", file=stream)
return
(
......@@ -2077,7 +2077,7 @@ class TopoOptimizer(NodeProcessingGraphRewriter):
print(
blanc,
"TopoOptimizer ",
f"{cls.__name__} ",
getattr(opt, "name", getattr(opt, "__name__", "")),
file=stream,
)
......@@ -2106,19 +2106,19 @@ class TopoOptimizer(NodeProcessingGraphRewriter):
)
def __str__(self):
return getattr(self, "__name__", "<TopoOptimizer instance>")
return getattr(self, "__name__", super().__str__())
def topogroup_optimizer(
order,
*node_rewriters,
name=None,
failure_callback=TopoOptimizer.warn_inplace,
failure_callback=WalkingGraphRewriter.warn_inplace,
**kwargs,
):
r"""Apply `node_rewriters` from the input/output nodes to the output/input nodes of a graph.
This constructs `TopoOptimizer`\s, and uses a `SequentialNodeRewriter` when there's
This constructs `WalkingGraphRewriter`\s, and uses a `SequentialNodeRewriter` when there's
more than one entry in `node_rewriters`.
"""
if len(node_rewriters) > 1:
......@@ -2128,7 +2128,7 @@ def topogroup_optimizer(
(node_rewriters,) = node_rewriters
if not name:
name = node_rewriters.__name__
ret = TopoOptimizer(
ret = WalkingGraphRewriter(
node_rewriters,
order=order,
failure_callback=failure_callback,
......@@ -3220,6 +3220,11 @@ DEPRECATED_NAMES = [
"`NavigatorOptimizer` is deprecated: use `NodeProcessingGraphRewriter` instead.",
NodeProcessingGraphRewriter,
),
(
"TopoOptimizer",
"`TopoOptimizer` is deprecated: use `WalkingGraphRewriter` instead.",
WalkingGraphRewriter,
),
]
......
......@@ -483,7 +483,7 @@ class LocalGroupDB(SequenceDB):
class TopoDB(OptimizationDatabase):
"""Generate a `GraphRewriter` of type TopoOptimizer."""
"""Generate a `GraphRewriter` of type `WalkingGraphRewriter`."""
def __init__(
self, db, order="in_to_out", ignore_newtrees=False, failure_callback=None
......@@ -495,7 +495,7 @@ class TopoDB(OptimizationDatabase):
self.failure_callback = failure_callback
def query(self, *tags, **kwtags):
return aesara_opt.TopoOptimizer(
return aesara_opt.WalkingGraphRewriter(
self.db.query(*tags, **kwtags),
self.order,
self.ignore_newtrees,
......
......@@ -4,7 +4,7 @@ import aesara
import aesara.scalar as aes
from aesara.configdefaults import config
from aesara.graph.basic import Apply
from aesara.graph.opt import PatternNodeRewriter, TopoOptimizer, node_rewriter
from aesara.graph.opt import PatternNodeRewriter, WalkingGraphRewriter, node_rewriter
from aesara.link.c.op import COp, _NoPythonCOp
from aesara.misc.safe_asarray import _asarray
from aesara.sparse import basic as sparse
......@@ -68,7 +68,9 @@ def local_inplace_remove0(fgraph, node):
aesara.compile.optdb.register(
"local_inplace_remove0",
TopoOptimizer(local_inplace_remove0, failure_callback=TopoOptimizer.warn_inplace),
WalkingGraphRewriter(
local_inplace_remove0, failure_callback=WalkingGraphRewriter.warn_inplace
),
"fast_run",
"inplace",
position=60,
......@@ -207,8 +209,8 @@ def local_inplace_addsd_ccode(fgraph, node):
aesara.compile.optdb.register(
"local_inplace_addsd_ccode",
TopoOptimizer(
local_inplace_addsd_ccode, failure_callback=TopoOptimizer.warn_inplace
WalkingGraphRewriter(
local_inplace_addsd_ccode, failure_callback=WalkingGraphRewriter.warn_inplace
),
"fast_run",
"inplace",
......@@ -240,7 +242,7 @@ def local_addsd_ccode(fgraph, node):
aesara.compile.optdb.register(
"local_addsd_ccode",
TopoOptimizer(local_addsd_ccode),
WalkingGraphRewriter(local_addsd_ccode),
# Must be after local_inplace_addsd_ccode at 60
"fast_run",
position=61,
......
......@@ -620,7 +620,7 @@ class AlgebraicCanonizer(NodeRewriter):
r"""A `Rewriter` that rewrites algebraic expressions.
The variable is a `node_rewriter`. It is best used
with a `TopoOptimizer` in in-to-out order.
with a `WalkingGraphRewriter` in in-to-out order.
Usage: ``AlgebraicCanonizer(main, inverse, reciprocal, calculate)``
......
......@@ -3,7 +3,7 @@ from aesara import tensor as at
from aesara.gradient import DisconnectedType
from aesara.graph.basic import Apply
from aesara.graph.op import Op
from aesara.graph.opt import TopoOptimizer, copy_stack_trace, node_rewriter
from aesara.graph.opt import WalkingGraphRewriter, copy_stack_trace, node_rewriter
def get_diagonal_subtensor_view(x, i0, i1):
......@@ -312,8 +312,9 @@ def local_inplace_DiagonalSubtensor(fgraph, node):
aesara.compile.optdb.register(
"local_inplace_DiagonalSubtensor",
TopoOptimizer(
local_inplace_DiagonalSubtensor, failure_callback=TopoOptimizer.warn_inplace
WalkingGraphRewriter(
local_inplace_DiagonalSubtensor,
failure_callback=WalkingGraphRewriter.warn_inplace,
),
"fast_run",
"inplace",
......
......@@ -8,7 +8,7 @@ from aesara.compile import optdb
from aesara.configdefaults import config
from aesara.graph.opt import (
MetaNodeRewriterSkip,
TopoOptimizer,
WalkingGraphRewriter,
copy_stack_trace,
in2out,
node_rewriter,
......@@ -51,8 +51,9 @@ def local_inplace_sparse_block_gemv(fgraph, node):
compile.optdb.register(
"local_inplace_sparse_block_gemv",
TopoOptimizer(
local_inplace_sparse_block_gemv, failure_callback=TopoOptimizer.warn_inplace
WalkingGraphRewriter(
local_inplace_sparse_block_gemv,
failure_callback=WalkingGraphRewriter.warn_inplace,
),
"fast_run",
"inplace",
......@@ -74,9 +75,9 @@ def local_inplace_sparse_block_outer(fgraph, node):
compile.optdb.register(
"local_inplace_sparse_block_outer",
TopoOptimizer(
WalkingGraphRewriter(
local_inplace_sparse_block_outer,
failure_callback=TopoOptimizer.warn_inplace,
failure_callback=WalkingGraphRewriter.warn_inplace,
),
"fast_run",
"inplace",
......
......@@ -7,7 +7,12 @@ import aesara
import aesara.scalar.basic as aes
from aesara import compile
from aesara.graph.basic import Constant, Variable
from aesara.graph.opt import TopoOptimizer, copy_stack_trace, in2out, node_rewriter
from aesara.graph.opt import (
WalkingGraphRewriter,
copy_stack_trace,
in2out,
node_rewriter,
)
from aesara.raise_op import Assert
from aesara.tensor.basic import (
Alloc,
......@@ -1200,9 +1205,9 @@ def local_IncSubtensor_serialize(fgraph, node):
# print incsub_inputs, [id(i.owner.inputs[0]) for i in incsub_inputs]
# We register it in a TopoOptimizer inside the canonizer EQ optimizer.
# We register it in a WalkingGraphRewriter inside the canonizer EQ optimizer.
# Otherwise in some cases it was making the EQ optimizer use 45 passes. In
# the TopoOptimizer, the EQ optimizer only uses 5 passes.
# the WalkingGraphRewriter, the EQ optimizer only uses 5 passes.
compile.optdb.register(
"pre_local_IncSubtensor_serialize",
in2out(local_IncSubtensor_serialize),
......@@ -1240,8 +1245,8 @@ def local_inplace_setsubtensor(fgraph, node):
compile.optdb.register(
"local_inplace_setsubtensor",
TopoOptimizer(
local_inplace_setsubtensor, failure_callback=TopoOptimizer.warn_inplace
WalkingGraphRewriter(
local_inplace_setsubtensor, failure_callback=WalkingGraphRewriter.warn_inplace
),
"fast_run",
"inplace",
......@@ -1261,8 +1266,9 @@ def local_inplace_AdvancedIncSubtensor1(fgraph, node):
compile.optdb.register(
"local_inplace_AdvancedIncSubtensor1",
TopoOptimizer(
local_inplace_AdvancedIncSubtensor1, failure_callback=TopoOptimizer.warn_inplace
WalkingGraphRewriter(
local_inplace_AdvancedIncSubtensor1,
failure_callback=WalkingGraphRewriter.warn_inplace,
),
"fast_run",
"inplace",
......@@ -1286,8 +1292,9 @@ def local_inplace_AdvancedIncSubtensor(fgraph, node):
compile.optdb.register(
"local_inplace_AdvancedIncSubtensor",
TopoOptimizer(
local_inplace_AdvancedIncSubtensor, failure_callback=TopoOptimizer.warn_inplace
WalkingGraphRewriter(
local_inplace_AdvancedIncSubtensor,
failure_callback=WalkingGraphRewriter.warn_inplace,
),
"fast_run",
"inplace",
......
from aesara.compile import optdb
from aesara.graph.opt import TopoOptimizer, node_rewriter
from aesara.graph.opt import WalkingGraphRewriter, node_rewriter
from aesara.typed_list.basic import Append, Extend, Insert, Remove, Reverse
......@@ -18,7 +18,9 @@ def typed_list_inplace_opt(fgraph, node):
optdb.register(
"typed_list_inplace_opt",
TopoOptimizer(typed_list_inplace_opt, failure_callback=TopoOptimizer.warn_inplace),
WalkingGraphRewriter(
typed_list_inplace_opt, failure_callback=WalkingGraphRewriter.warn_inplace
),
"fast_run",
"inplace",
position=60,
......
......@@ -264,9 +264,9 @@ subset of them) and applies one or several local optimizers.
>>> e = aesara.graph.fg.FunctionGraph([x, y, z], [a])
>>> e
FunctionGraph(add(z, mul(true_div(mul(y, x), y), true_div(z, x))))
>>> simplify = aesara.graph.opt.TopoOptimizer(local_simplify)
>>> simplify = aesara.graph.opt.WalkingGraphRewriter(local_simplify)
>>> simplify.optimize(e)
(<aesara.graph.opt.TopoOptimizer object at 0x...>, 1, 5, 3, ..., ..., ...)
(<aesara.graph.opt.WalkingGraphRewriter object at 0x...>, 1, 5, 3, ..., ..., ...)
>>> e
FunctionGraph(add(z, mul(x, true_div(z, x))))
......@@ -962,14 +962,14 @@ This will output something like this:
time_toposort 0.00311398506165
validate_time 4.60147857666e-05
callback_time 0.00174236297607
0.004569s - ('local_dot_to_dot22', 'TopoOptimizer', 0) - 0.000s
TopoOptimizer
0.004569s - ('local_dot_to_dot22', 'WalkingGraphRewriter', 0) - 0.000s
WalkingGraphRewriter
nb_node (start, end, changed) (81, 81, 5)
init io_toposort 0.00139284133911
loop time 0.00312399864197
callback_time 0.00172805786133
0.002283s - ('local_dot22_to_dot22scalar', 'TopoOptimizer', 2) - 0.000s
TopoOptimizer
0.002283s - ('local_dot22_to_dot22scalar', 'WalkingGraphRewriter', 2) - 0.000s
WalkingGraphRewriter
nb_node (start, end, changed) (80, 80, 0)
init io_toposort 0.00171804428101
loop time 0.000502109527588
......@@ -982,14 +982,14 @@ This will output something like this:
time in local optimizers 0.000s
time in global optimizers 0.000s
0 - 0.002s 0 (0.000s in global opts, 0.001s io_toposort) - 80 nodes -
0.002227s - ('use_c_blas', 'TopoOptimizer', 4) - 0.000s
TopoOptimizer
0.002227s - ('use_c_blas', 'WalkingGraphRewriter', 4) - 0.000s
WalkingGraphRewriter
nb_node (start, end, changed) (80, 80, 0)
init io_toposort 0.0014750957489
loop time 0.00068998336792
callback_time 0.0
0.001632s - ('use_scipy_ger', 'TopoOptimizer', 5) - 0.000s
TopoOptimizer
0.001632s - ('use_scipy_ger', 'WalkingGraphRewriter', 5) - 0.000s
WalkingGraphRewriter
nb_node (start, end, changed) (80, 80, 0)
init io_toposort 0.00138401985168
loop time 0.000202178955078
......
......@@ -13,7 +13,7 @@ from aesara.graph.opt import (
OpKeyOptimizer,
PatternNodeRewriter,
SubstitutionNodeRewriter,
TopoOptimizer,
WalkingGraphRewriter,
)
from aesara.graph.type import Type
from aesara.graph.utils import InconsistencyError
......@@ -27,7 +27,7 @@ def PatternOptimizer(p1, p2, ign=True):
def TopoSubstitutionNodeRewriter(
op1, op2, fail=NodeProcessingGraphRewriter.warn_ignore, ign=True
):
return TopoOptimizer(
return WalkingGraphRewriter(
SubstitutionNodeRewriter(op1, op2), ignore_newtrees=ign, failure_callback=fail
)
......
......@@ -13,7 +13,7 @@ from aesara.graph.opt import (
PatternNodeRewriter,
SequentialNodeRewriter,
SubstitutionNodeRewriter,
TopoOptimizer,
WalkingGraphRewriter,
in2out,
logging,
node_rewriter,
......@@ -55,7 +55,7 @@ def PatternOptimizer(p1, p2, ign=False):
def TopoPatternOptimizer(p1, p2, ign=True):
return TopoOptimizer(PatternNodeRewriter(p1, p2), ignore_newtrees=ign)
return WalkingGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign)
class TestPatternOptimizer:
......@@ -148,7 +148,7 @@ class TestPatternOptimizer:
assert str(g) == "FunctionGraph(Op2(Op1(Op2(Op1(Op2(Op1(x)))))))"
def test_ambiguous(self):
# this test should always work with TopoOptimizer and the
# this test should always work with WalkingGraphRewriter and the
# ignore_newtrees flag set to False. Behavior with ignore_newtrees
# = True or with other NodeProcessingGraphRewriters may differ.
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
......
......@@ -20,7 +20,7 @@ from aesara.graph.basic import Apply, Constant, equal_computations
from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import (
SequentialNodeRewriter,
TopoOptimizer,
WalkingGraphRewriter,
check_stack_trace,
in2out,
out2in,
......@@ -190,7 +190,7 @@ class TestGreedyDistribute:
e = (a / z + b / x) * x * z
g = FunctionGraph([a, b, c, d, x, y, z], [e])
mul_canonizer.optimize(g)
TopoOptimizer(
WalkingGraphRewriter(
SequentialNodeRewriter(local_greedy_distributor), order="out_to_in"
).optimize(g)
assert str(pprint(g.outputs[0])) == "((a * x) + (b * z))"
......@@ -199,7 +199,7 @@ class TestGreedyDistribute:
e = (a / x + b) * x
g = FunctionGraph([a, b, x], [e])
mul_canonizer.optimize(g)
TopoOptimizer(
WalkingGraphRewriter(
SequentialNodeRewriter(local_greedy_distributor), order="out_to_in"
).optimize(g)
assert str(pprint(g.outputs[0])) == "(a + (b * x))"
......@@ -3052,7 +3052,7 @@ class TestLocalErfc:
for inputs, no_match in no_matches:
fg = FunctionGraph(inputs, [no_match], clone=False)
TopoOptimizer(
WalkingGraphRewriter(
SequentialNodeRewriter(local_grad_log_erfc_neg), order="out_to_in"
).optimize(fg)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论