提交 3ce33a3d authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename TopoOptimizer to WalkingGraphRewriter

上级 fbe8c876
...@@ -1997,8 +1997,8 @@ class NodeProcessingGraphRewriter(GraphRewriter): ...@@ -1997,8 +1997,8 @@ class NodeProcessingGraphRewriter(GraphRewriter):
) )
class TopoOptimizer(NodeProcessingGraphRewriter): class WalkingGraphRewriter(NodeProcessingGraphRewriter):
"""An optimizer that applies a single `NodeRewriter` to each node in topological order (or reverse).""" """A rewriter that applies a single `NodeRewriter` to each node in topological order (or reverse)."""
def __init__( def __init__(
self, self,
...@@ -2057,11 +2057,11 @@ class TopoOptimizer(NodeProcessingGraphRewriter): ...@@ -2057,11 +2057,11 @@ class TopoOptimizer(NodeProcessingGraphRewriter):
self.node_rewriter, self.node_rewriter,
) )
@staticmethod @classmethod
def print_profile(stream, prof, level=0): def print_profile(cls, stream, prof, level=0):
blanc = " " * level blanc = " " * level
if prof is None: # Happen as merge_profile() isn't implemented if prof is None: # Happen as merge_profile() isn't implemented
print(blanc, "TopoOptimizer merge_profile not implemented", file=stream) print(blanc, f"{cls.__name__} merge_profile not implemented", file=stream)
return return
( (
...@@ -2077,7 +2077,7 @@ class TopoOptimizer(NodeProcessingGraphRewriter): ...@@ -2077,7 +2077,7 @@ class TopoOptimizer(NodeProcessingGraphRewriter):
print( print(
blanc, blanc,
"TopoOptimizer ", f"{cls.__name__} ",
getattr(opt, "name", getattr(opt, "__name__", "")), getattr(opt, "name", getattr(opt, "__name__", "")),
file=stream, file=stream,
) )
...@@ -2106,19 +2106,19 @@ class TopoOptimizer(NodeProcessingGraphRewriter): ...@@ -2106,19 +2106,19 @@ class TopoOptimizer(NodeProcessingGraphRewriter):
) )
def __str__(self): def __str__(self):
return getattr(self, "__name__", "<TopoOptimizer instance>") return getattr(self, "__name__", super().__str__())
def topogroup_optimizer( def topogroup_optimizer(
order, order,
*node_rewriters, *node_rewriters,
name=None, name=None,
failure_callback=TopoOptimizer.warn_inplace, failure_callback=WalkingGraphRewriter.warn_inplace,
**kwargs, **kwargs,
): ):
r"""Apply `node_rewriters` from the input/output nodes to the output/input nodes of a graph. r"""Apply `node_rewriters` from the input/output nodes to the output/input nodes of a graph.
This constructs `TopoOptimizer`\s, and uses a `SequentialNodeRewriter` when there's This constructs `WalkingGraphRewriter`\s, and uses a `SequentialNodeRewriter` when there's
more than one entry in `node_rewriters`. more than one entry in `node_rewriters`.
""" """
if len(node_rewriters) > 1: if len(node_rewriters) > 1:
...@@ -2128,7 +2128,7 @@ def topogroup_optimizer( ...@@ -2128,7 +2128,7 @@ def topogroup_optimizer(
(node_rewriters,) = node_rewriters (node_rewriters,) = node_rewriters
if not name: if not name:
name = node_rewriters.__name__ name = node_rewriters.__name__
ret = TopoOptimizer( ret = WalkingGraphRewriter(
node_rewriters, node_rewriters,
order=order, order=order,
failure_callback=failure_callback, failure_callback=failure_callback,
...@@ -3220,6 +3220,11 @@ DEPRECATED_NAMES = [ ...@@ -3220,6 +3220,11 @@ DEPRECATED_NAMES = [
"`NavigatorOptimizer` is deprecated: use `NodeProcessingGraphRewriter` instead.", "`NavigatorOptimizer` is deprecated: use `NodeProcessingGraphRewriter` instead.",
NodeProcessingGraphRewriter, NodeProcessingGraphRewriter,
), ),
(
"TopoOptimizer",
"`TopoOptimizer` is deprecated: use `WalkingGraphRewriter` instead.",
WalkingGraphRewriter,
),
] ]
......
...@@ -483,7 +483,7 @@ class LocalGroupDB(SequenceDB): ...@@ -483,7 +483,7 @@ class LocalGroupDB(SequenceDB):
class TopoDB(OptimizationDatabase): class TopoDB(OptimizationDatabase):
"""Generate a `GraphRewriter` of type TopoOptimizer.""" """Generate a `GraphRewriter` of type `WalkingGraphRewriter`."""
def __init__( def __init__(
self, db, order="in_to_out", ignore_newtrees=False, failure_callback=None self, db, order="in_to_out", ignore_newtrees=False, failure_callback=None
...@@ -495,7 +495,7 @@ class TopoDB(OptimizationDatabase): ...@@ -495,7 +495,7 @@ class TopoDB(OptimizationDatabase):
self.failure_callback = failure_callback self.failure_callback = failure_callback
def query(self, *tags, **kwtags): def query(self, *tags, **kwtags):
return aesara_opt.TopoOptimizer( return aesara_opt.WalkingGraphRewriter(
self.db.query(*tags, **kwtags), self.db.query(*tags, **kwtags),
self.order, self.order,
self.ignore_newtrees, self.ignore_newtrees,
......
...@@ -4,7 +4,7 @@ import aesara ...@@ -4,7 +4,7 @@ import aesara
import aesara.scalar as aes import aesara.scalar as aes
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
from aesara.graph.opt import PatternNodeRewriter, TopoOptimizer, node_rewriter from aesara.graph.opt import PatternNodeRewriter, WalkingGraphRewriter, node_rewriter
from aesara.link.c.op import COp, _NoPythonCOp from aesara.link.c.op import COp, _NoPythonCOp
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.sparse import basic as sparse from aesara.sparse import basic as sparse
...@@ -68,7 +68,9 @@ def local_inplace_remove0(fgraph, node): ...@@ -68,7 +68,9 @@ def local_inplace_remove0(fgraph, node):
aesara.compile.optdb.register( aesara.compile.optdb.register(
"local_inplace_remove0", "local_inplace_remove0",
TopoOptimizer(local_inplace_remove0, failure_callback=TopoOptimizer.warn_inplace), WalkingGraphRewriter(
local_inplace_remove0, failure_callback=WalkingGraphRewriter.warn_inplace
),
"fast_run", "fast_run",
"inplace", "inplace",
position=60, position=60,
...@@ -207,8 +209,8 @@ def local_inplace_addsd_ccode(fgraph, node): ...@@ -207,8 +209,8 @@ def local_inplace_addsd_ccode(fgraph, node):
aesara.compile.optdb.register( aesara.compile.optdb.register(
"local_inplace_addsd_ccode", "local_inplace_addsd_ccode",
TopoOptimizer( WalkingGraphRewriter(
local_inplace_addsd_ccode, failure_callback=TopoOptimizer.warn_inplace local_inplace_addsd_ccode, failure_callback=WalkingGraphRewriter.warn_inplace
), ),
"fast_run", "fast_run",
"inplace", "inplace",
...@@ -240,7 +242,7 @@ def local_addsd_ccode(fgraph, node): ...@@ -240,7 +242,7 @@ def local_addsd_ccode(fgraph, node):
aesara.compile.optdb.register( aesara.compile.optdb.register(
"local_addsd_ccode", "local_addsd_ccode",
TopoOptimizer(local_addsd_ccode), WalkingGraphRewriter(local_addsd_ccode),
# Must be after local_inplace_addsd_ccode at 60 # Must be after local_inplace_addsd_ccode at 60
"fast_run", "fast_run",
position=61, position=61,
......
...@@ -620,7 +620,7 @@ class AlgebraicCanonizer(NodeRewriter): ...@@ -620,7 +620,7 @@ class AlgebraicCanonizer(NodeRewriter):
r"""A `Rewriter` that rewrites algebraic expressions. r"""A `Rewriter` that rewrites algebraic expressions.
The variable is a `node_rewriter`. It is best used The variable is a `node_rewriter`. It is best used
with a `TopoOptimizer` in in-to-out order. with a `WalkingGraphRewriter` in in-to-out order.
Usage: ``AlgebraicCanonizer(main, inverse, reciprocal, calculate)`` Usage: ``AlgebraicCanonizer(main, inverse, reciprocal, calculate)``
......
...@@ -3,7 +3,7 @@ from aesara import tensor as at ...@@ -3,7 +3,7 @@ from aesara import tensor as at
from aesara.gradient import DisconnectedType from aesara.gradient import DisconnectedType
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.opt import TopoOptimizer, copy_stack_trace, node_rewriter from aesara.graph.opt import WalkingGraphRewriter, copy_stack_trace, node_rewriter
def get_diagonal_subtensor_view(x, i0, i1): def get_diagonal_subtensor_view(x, i0, i1):
...@@ -312,8 +312,9 @@ def local_inplace_DiagonalSubtensor(fgraph, node): ...@@ -312,8 +312,9 @@ def local_inplace_DiagonalSubtensor(fgraph, node):
aesara.compile.optdb.register( aesara.compile.optdb.register(
"local_inplace_DiagonalSubtensor", "local_inplace_DiagonalSubtensor",
TopoOptimizer( WalkingGraphRewriter(
local_inplace_DiagonalSubtensor, failure_callback=TopoOptimizer.warn_inplace local_inplace_DiagonalSubtensor,
failure_callback=WalkingGraphRewriter.warn_inplace,
), ),
"fast_run", "fast_run",
"inplace", "inplace",
......
...@@ -8,7 +8,7 @@ from aesara.compile import optdb ...@@ -8,7 +8,7 @@ from aesara.compile import optdb
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.opt import ( from aesara.graph.opt import (
MetaNodeRewriterSkip, MetaNodeRewriterSkip,
TopoOptimizer, WalkingGraphRewriter,
copy_stack_trace, copy_stack_trace,
in2out, in2out,
node_rewriter, node_rewriter,
...@@ -51,8 +51,9 @@ def local_inplace_sparse_block_gemv(fgraph, node): ...@@ -51,8 +51,9 @@ def local_inplace_sparse_block_gemv(fgraph, node):
compile.optdb.register( compile.optdb.register(
"local_inplace_sparse_block_gemv", "local_inplace_sparse_block_gemv",
TopoOptimizer( WalkingGraphRewriter(
local_inplace_sparse_block_gemv, failure_callback=TopoOptimizer.warn_inplace local_inplace_sparse_block_gemv,
failure_callback=WalkingGraphRewriter.warn_inplace,
), ),
"fast_run", "fast_run",
"inplace", "inplace",
...@@ -74,9 +75,9 @@ def local_inplace_sparse_block_outer(fgraph, node): ...@@ -74,9 +75,9 @@ def local_inplace_sparse_block_outer(fgraph, node):
compile.optdb.register( compile.optdb.register(
"local_inplace_sparse_block_outer", "local_inplace_sparse_block_outer",
TopoOptimizer( WalkingGraphRewriter(
local_inplace_sparse_block_outer, local_inplace_sparse_block_outer,
failure_callback=TopoOptimizer.warn_inplace, failure_callback=WalkingGraphRewriter.warn_inplace,
), ),
"fast_run", "fast_run",
"inplace", "inplace",
......
...@@ -7,7 +7,12 @@ import aesara ...@@ -7,7 +7,12 @@ import aesara
import aesara.scalar.basic as aes import aesara.scalar.basic as aes
from aesara import compile from aesara import compile
from aesara.graph.basic import Constant, Variable from aesara.graph.basic import Constant, Variable
from aesara.graph.opt import TopoOptimizer, copy_stack_trace, in2out, node_rewriter from aesara.graph.opt import (
WalkingGraphRewriter,
copy_stack_trace,
in2out,
node_rewriter,
)
from aesara.raise_op import Assert from aesara.raise_op import Assert
from aesara.tensor.basic import ( from aesara.tensor.basic import (
Alloc, Alloc,
...@@ -1200,9 +1205,9 @@ def local_IncSubtensor_serialize(fgraph, node): ...@@ -1200,9 +1205,9 @@ def local_IncSubtensor_serialize(fgraph, node):
# print incsub_inputs, [id(i.owner.inputs[0]) for i in incsub_inputs] # print incsub_inputs, [id(i.owner.inputs[0]) for i in incsub_inputs]
# We register it in a TopoOptimizer inside the canonizer EQ optimizer. # We register it in a WalkingGraphRewriter inside the canonizer EQ optimizer.
# Otherwise in some cases it was making the EQ optimizer use 45. In # Otherwise in some cases it was making the EQ optimizer use 45. In
# the TopoOptimizer, the EQ only use 5 passes. # the WalkingGraphRewriter, the EQ only use 5 passes.
compile.optdb.register( compile.optdb.register(
"pre_local_IncSubtensor_serialize", "pre_local_IncSubtensor_serialize",
in2out(local_IncSubtensor_serialize), in2out(local_IncSubtensor_serialize),
...@@ -1240,8 +1245,8 @@ def local_inplace_setsubtensor(fgraph, node): ...@@ -1240,8 +1245,8 @@ def local_inplace_setsubtensor(fgraph, node):
compile.optdb.register( compile.optdb.register(
"local_inplace_setsubtensor", "local_inplace_setsubtensor",
TopoOptimizer( WalkingGraphRewriter(
local_inplace_setsubtensor, failure_callback=TopoOptimizer.warn_inplace local_inplace_setsubtensor, failure_callback=WalkingGraphRewriter.warn_inplace
), ),
"fast_run", "fast_run",
"inplace", "inplace",
...@@ -1261,8 +1266,9 @@ def local_inplace_AdvancedIncSubtensor1(fgraph, node): ...@@ -1261,8 +1266,9 @@ def local_inplace_AdvancedIncSubtensor1(fgraph, node):
compile.optdb.register( compile.optdb.register(
"local_inplace_AdvancedIncSubtensor1", "local_inplace_AdvancedIncSubtensor1",
TopoOptimizer( WalkingGraphRewriter(
local_inplace_AdvancedIncSubtensor1, failure_callback=TopoOptimizer.warn_inplace local_inplace_AdvancedIncSubtensor1,
failure_callback=WalkingGraphRewriter.warn_inplace,
), ),
"fast_run", "fast_run",
"inplace", "inplace",
...@@ -1286,8 +1292,9 @@ def local_inplace_AdvancedIncSubtensor(fgraph, node): ...@@ -1286,8 +1292,9 @@ def local_inplace_AdvancedIncSubtensor(fgraph, node):
compile.optdb.register( compile.optdb.register(
"local_inplace_AdvancedIncSubtensor", "local_inplace_AdvancedIncSubtensor",
TopoOptimizer( WalkingGraphRewriter(
local_inplace_AdvancedIncSubtensor, failure_callback=TopoOptimizer.warn_inplace local_inplace_AdvancedIncSubtensor,
failure_callback=WalkingGraphRewriter.warn_inplace,
), ),
"fast_run", "fast_run",
"inplace", "inplace",
......
from aesara.compile import optdb from aesara.compile import optdb
from aesara.graph.opt import TopoOptimizer, node_rewriter from aesara.graph.opt import WalkingGraphRewriter, node_rewriter
from aesara.typed_list.basic import Append, Extend, Insert, Remove, Reverse from aesara.typed_list.basic import Append, Extend, Insert, Remove, Reverse
...@@ -18,7 +18,9 @@ def typed_list_inplace_opt(fgraph, node): ...@@ -18,7 +18,9 @@ def typed_list_inplace_opt(fgraph, node):
optdb.register( optdb.register(
"typed_list_inplace_opt", "typed_list_inplace_opt",
TopoOptimizer(typed_list_inplace_opt, failure_callback=TopoOptimizer.warn_inplace), WalkingGraphRewriter(
typed_list_inplace_opt, failure_callback=WalkingGraphRewriter.warn_inplace
),
"fast_run", "fast_run",
"inplace", "inplace",
position=60, position=60,
......
...@@ -264,9 +264,9 @@ subset of them) and applies one or several local optimizers. ...@@ -264,9 +264,9 @@ subset of them) and applies one or several local optimizers.
>>> e = aesara.graph.fg.FunctionGraph([x, y, z], [a]) >>> e = aesara.graph.fg.FunctionGraph([x, y, z], [a])
>>> e >>> e
FunctionGraph(add(z, mul(true_div(mul(y, x), y), true_div(z, x)))) FunctionGraph(add(z, mul(true_div(mul(y, x), y), true_div(z, x))))
>>> simplify = aesara.graph.opt.TopoOptimizer(local_simplify) >>> simplify = aesara.graph.opt.WalkingGraphRewriter(local_simplify)
>>> simplify.optimize(e) >>> simplify.optimize(e)
(<aesara.graph.opt.TopoOptimizer object at 0x...>, 1, 5, 3, ..., ..., ...) (<aesara.graph.opt.WalkingGraphRewriter object at 0x...>, 1, 5, 3, ..., ..., ...)
>>> e >>> e
FunctionGraph(add(z, mul(x, true_div(z, x)))) FunctionGraph(add(z, mul(x, true_div(z, x))))
...@@ -962,14 +962,14 @@ This will output something like this: ...@@ -962,14 +962,14 @@ This will output something like this:
time_toposort 0.00311398506165 time_toposort 0.00311398506165
validate_time 4.60147857666e-05 validate_time 4.60147857666e-05
callback_time 0.00174236297607 callback_time 0.00174236297607
0.004569s - ('local_dot_to_dot22', 'TopoOptimizer', 0) - 0.000s 0.004569s - ('local_dot_to_dot22', 'WalkingGraphRewriter', 0) - 0.000s
TopoOptimizer WalkingGraphRewriter
nb_node (start, end, changed) (81, 81, 5) nb_node (start, end, changed) (81, 81, 5)
init io_toposort 0.00139284133911 init io_toposort 0.00139284133911
loop time 0.00312399864197 loop time 0.00312399864197
callback_time 0.00172805786133 callback_time 0.00172805786133
0.002283s - ('local_dot22_to_dot22scalar', 'TopoOptimizer', 2) - 0.000s 0.002283s - ('local_dot22_to_dot22scalar', 'WalkingGraphRewriter', 2) - 0.000s
TopoOptimizer WalkingGraphRewriter
nb_node (start, end, changed) (80, 80, 0) nb_node (start, end, changed) (80, 80, 0)
init io_toposort 0.00171804428101 init io_toposort 0.00171804428101
loop time 0.000502109527588 loop time 0.000502109527588
...@@ -982,14 +982,14 @@ This will output something like this: ...@@ -982,14 +982,14 @@ This will output something like this:
time in local optimizers 0.000s time in local optimizers 0.000s
time in global optimizers 0.000s time in global optimizers 0.000s
0 - 0.002s 0 (0.000s in global opts, 0.001s io_toposort) - 80 nodes - 0 - 0.002s 0 (0.000s in global opts, 0.001s io_toposort) - 80 nodes -
0.002227s - ('use_c_blas', 'TopoOptimizer', 4) - 0.000s 0.002227s - ('use_c_blas', 'WalkingGraphRewriter', 4) - 0.000s
TopoOptimizer WalkingGraphRewriter
nb_node (start, end, changed) (80, 80, 0) nb_node (start, end, changed) (80, 80, 0)
init io_toposort 0.0014750957489 init io_toposort 0.0014750957489
loop time 0.00068998336792 loop time 0.00068998336792
callback_time 0.0 callback_time 0.0
0.001632s - ('use_scipy_ger', 'TopoOptimizer', 5) - 0.000s 0.001632s - ('use_scipy_ger', 'WalkingGraphRewriter', 5) - 0.000s
TopoOptimizer WalkingGraphRewriter
nb_node (start, end, changed) (80, 80, 0) nb_node (start, end, changed) (80, 80, 0)
init io_toposort 0.00138401985168 init io_toposort 0.00138401985168
loop time 0.000202178955078 loop time 0.000202178955078
......
...@@ -13,7 +13,7 @@ from aesara.graph.opt import ( ...@@ -13,7 +13,7 @@ from aesara.graph.opt import (
OpKeyOptimizer, OpKeyOptimizer,
PatternNodeRewriter, PatternNodeRewriter,
SubstitutionNodeRewriter, SubstitutionNodeRewriter,
TopoOptimizer, WalkingGraphRewriter,
) )
from aesara.graph.type import Type from aesara.graph.type import Type
from aesara.graph.utils import InconsistencyError from aesara.graph.utils import InconsistencyError
...@@ -27,7 +27,7 @@ def PatternOptimizer(p1, p2, ign=True): ...@@ -27,7 +27,7 @@ def PatternOptimizer(p1, p2, ign=True):
def TopoSubstitutionNodeRewriter( def TopoSubstitutionNodeRewriter(
op1, op2, fail=NodeProcessingGraphRewriter.warn_ignore, ign=True op1, op2, fail=NodeProcessingGraphRewriter.warn_ignore, ign=True
): ):
return TopoOptimizer( return WalkingGraphRewriter(
SubstitutionNodeRewriter(op1, op2), ignore_newtrees=ign, failure_callback=fail SubstitutionNodeRewriter(op1, op2), ignore_newtrees=ign, failure_callback=fail
) )
......
...@@ -13,7 +13,7 @@ from aesara.graph.opt import ( ...@@ -13,7 +13,7 @@ from aesara.graph.opt import (
PatternNodeRewriter, PatternNodeRewriter,
SequentialNodeRewriter, SequentialNodeRewriter,
SubstitutionNodeRewriter, SubstitutionNodeRewriter,
TopoOptimizer, WalkingGraphRewriter,
in2out, in2out,
logging, logging,
node_rewriter, node_rewriter,
...@@ -55,7 +55,7 @@ def PatternOptimizer(p1, p2, ign=False): ...@@ -55,7 +55,7 @@ def PatternOptimizer(p1, p2, ign=False):
def TopoPatternOptimizer(p1, p2, ign=True): def TopoPatternOptimizer(p1, p2, ign=True):
return TopoOptimizer(PatternNodeRewriter(p1, p2), ignore_newtrees=ign) return WalkingGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign)
class TestPatternOptimizer: class TestPatternOptimizer:
...@@ -148,7 +148,7 @@ class TestPatternOptimizer: ...@@ -148,7 +148,7 @@ class TestPatternOptimizer:
assert str(g) == "FunctionGraph(Op2(Op1(Op2(Op1(Op2(Op1(x)))))))" assert str(g) == "FunctionGraph(Op2(Op1(Op2(Op1(Op2(Op1(x)))))))"
def test_ambiguous(self): def test_ambiguous(self):
# this test should always work with TopoOptimizer and the # this test should always work with WalkingGraphRewriter and the
# ignore_newtrees flag set to False. Behavior with ignore_newtrees # ignore_newtrees flag set to False. Behavior with ignore_newtrees
# = True or with other NodeProcessingGraphRewriters may differ. # = True or with other NodeProcessingGraphRewriters may differ.
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
......
...@@ -20,7 +20,7 @@ from aesara.graph.basic import Apply, Constant, equal_computations ...@@ -20,7 +20,7 @@ from aesara.graph.basic import Apply, Constant, equal_computations
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import ( from aesara.graph.opt import (
SequentialNodeRewriter, SequentialNodeRewriter,
TopoOptimizer, WalkingGraphRewriter,
check_stack_trace, check_stack_trace,
in2out, in2out,
out2in, out2in,
...@@ -190,7 +190,7 @@ class TestGreedyDistribute: ...@@ -190,7 +190,7 @@ class TestGreedyDistribute:
e = (a / z + b / x) * x * z e = (a / z + b / x) * x * z
g = FunctionGraph([a, b, c, d, x, y, z], [e]) g = FunctionGraph([a, b, c, d, x, y, z], [e])
mul_canonizer.optimize(g) mul_canonizer.optimize(g)
TopoOptimizer( WalkingGraphRewriter(
SequentialNodeRewriter(local_greedy_distributor), order="out_to_in" SequentialNodeRewriter(local_greedy_distributor), order="out_to_in"
).optimize(g) ).optimize(g)
assert str(pprint(g.outputs[0])) == "((a * x) + (b * z))" assert str(pprint(g.outputs[0])) == "((a * x) + (b * z))"
...@@ -199,7 +199,7 @@ class TestGreedyDistribute: ...@@ -199,7 +199,7 @@ class TestGreedyDistribute:
e = (a / x + b) * x e = (a / x + b) * x
g = FunctionGraph([a, b, x], [e]) g = FunctionGraph([a, b, x], [e])
mul_canonizer.optimize(g) mul_canonizer.optimize(g)
TopoOptimizer( WalkingGraphRewriter(
SequentialNodeRewriter(local_greedy_distributor), order="out_to_in" SequentialNodeRewriter(local_greedy_distributor), order="out_to_in"
).optimize(g) ).optimize(g)
assert str(pprint(g.outputs[0])) == "(a + (b * x))" assert str(pprint(g.outputs[0])) == "(a + (b * x))"
...@@ -3052,7 +3052,7 @@ class TestLocalErfc: ...@@ -3052,7 +3052,7 @@ class TestLocalErfc:
for inputs, no_match in no_matches: for inputs, no_match in no_matches:
fg = FunctionGraph(inputs, [no_match], clone=False) fg = FunctionGraph(inputs, [no_match], clone=False)
TopoOptimizer( WalkingGraphRewriter(
SequentialNodeRewriter(local_grad_log_erfc_neg), order="out_to_in" SequentialNodeRewriter(local_grad_log_erfc_neg), order="out_to_in"
).optimize(fg) ).optimize(fg)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论