提交 45f48ae6 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename GraphRewriter.optimize to GraphRewriter.rewrite

上级 8db93fe1
......@@ -92,16 +92,16 @@ class Rewriter(abc.ABC):
class GraphRewriter(Rewriter):
    """A rewriter that can be applied to a `FunctionGraph` in order to transform it.

    This class represents a generalized rewrite that includes the way a graph
    is traversed and/or changed as a whole.

    """

    @abc.abstractmethod
    def apply(self, fgraph):
        """Apply the rewriter to a `FunctionGraph`.

        It may use all the methods defined by the `FunctionGraph`. If the
        `GraphRewriter` needs to use a certain tool, it can request it via
        `add_requirements`.

        """
        # Subclasses must provide the actual traversal/transformation logic.
        raise NotImplementedError()
def optimize(self, *args, **kwargs):
    """Deprecated alias for `GraphRewriter.rewrite`.

    Emits a `DeprecationWarning` and forwards all arguments to
    `GraphRewriter.rewrite`, preserving its return value.
    """
    warnings.warn(
        "`GraphRewriter.optimize` is deprecated; use `GraphRewriter.rewrite` instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    # Return the result: `rewrite` returns whatever `apply` produces, and the
    # pre-rename `optimize` returned it too, so the shim must not drop it.
    return self.rewrite(*args, **kwargs)
def rewrite(self, fgraph, *args, **kwargs):
    """Add this rewriter's requirements to `fgraph`, then apply it.

    This is meant as a shortcut for the following::

        self.add_requirements(fgraph)
        self.apply(fgraph)

    Extra positional/keyword arguments are forwarded to `apply`, and the
    result of `apply` is returned.
    """
    self.add_requirements(fgraph)
    return self.apply(fgraph, *args, **kwargs)
def __call__(self, fgraph):
    """Rewrite a `FunctionGraph`.

    Same as ``self.rewrite(fgraph)``; returns its result.
    """
    return self.rewrite(fgraph)
def add_requirements(self, fgraph):
...
......@@ -141,12 +144,12 @@ class GraphRewriter(Rewriter):
file=stream,
)
@classmethod
def print_profile(cls, stream, prof, level=0):
    """Print profiling information `prof` to `stream`.

    Subclasses whose `apply` returns profiling information must override
    this; the base implementation only accepts ``prof is None``.
    """
    if prof is not None:
        raise NotImplementedError(
            "The function `print_profile` must be overridden when the"
            " rewriter returns profiling information."
        )
......
......@@ -44,10 +44,10 @@ def optimize_graph(
return_only_out = True
canonicalize_opt = optdb.query(RewriteDatabaseQuery(include=include, **kwargs))
_ = canonicalize_opt.optimize(fgraph)
_ = canonicalize_opt.rewrite(fgraph)
if custom_opt:
custom_opt.optimize(fgraph)
custom_opt.rewrite(fgraph)
if return_only_out:
return fgraph.outputs[0]
......@@ -79,7 +79,7 @@ def is_same_graph_with_merge(var1, var2, givens=None):
for to_replace, replace_by in givens.items():
fgraph.replace(to_replace, replace_by)
# Perform merge optimization.
MergeOptimizer().optimize(fgraph)
MergeOptimizer().rewrite(fgraph)
# When two variables perform the same computations, they will have the same
# owner in the optimized graph.
# We need to be careful with the special case where the owner is None,
......
......@@ -4152,7 +4152,7 @@ class Composite(ScalarOp):
# the fgraph to be set to the variable as we need to pickle
# them for the cache of c module to work.
fgraph = FunctionGraph(self.inputs, self.outputs)
MergeOptimizer().optimize(fgraph)
MergeOptimizer().rewrite(fgraph)
for node in fgraph.apply_nodes:
if not isinstance(node.op, ScalarOp):
raise ValueError(
......
......@@ -56,7 +56,7 @@ Graph Rewriting
<libdoc_graph_fgraphfeature>` to it. These features are "plugins" that are needed
for the :meth:`GraphRewriter.apply` method to do its job properly.
.. method:: optimize(fgraph)
.. method:: rewrite(fgraph)
This is the interface function called by Aesara. It calls
:meth:`GraphRewriter.apply` by default.
......@@ -159,7 +159,7 @@ Now, we test the optimization:
>>> e = aesara.graph.fg.FunctionGraph([x, y, z], [a])
>>> e
FunctionGraph(add(z, mul(true_div(mul(y, x), y), true_div(z, x))))
>>> simplify.optimize(e)
>>> simplify.rewrite(e)
>>> e
FunctionGraph(add(z, mul(x, true_div(z, x))))
......@@ -175,7 +175,7 @@ optimization you wrote. For example, consider the following:
>>> e = aesara.graph.fg.FunctionGraph([x, y, z], [a])
>>> e
FunctionGraph(true_div(mul(add(y, z), x), add(y, z)))
>>> simplify.optimize(e)
>>> simplify.rewrite(e)
>>> e
FunctionGraph(true_div(mul(add(y, z), x), add(y, z)))
......@@ -186,11 +186,11 @@ computation, using the :class:`MergeOptimizer` defined in
:mod:`aesara.graph.opt`.
>>> from aesara.graph.opt import MergeOptimizer
>>> MergeOptimizer().optimize(e) # doctest: +ELLIPSIS
>>> MergeOptimizer().rewrite(e) # doctest: +ELLIPSIS
(0, ..., None, None, {}, 1, 0)
>>> e
FunctionGraph(true_div(mul(*1 -> add(y, z), x), *1))
>>> simplify.optimize(e)
>>> simplify.rewrite(e)
>>> e
FunctionGraph(x)
......@@ -265,7 +265,7 @@ subset of them) and applies one or several local optimizers.
>>> e
FunctionGraph(add(z, mul(true_div(mul(y, x), y), true_div(z, x))))
>>> simplify = aesara.graph.opt.WalkingGraphRewriter(local_simplify)
>>> simplify.optimize(e)
>>> simplify.rewrite(e)
(<aesara.graph.opt.WalkingGraphRewriter object at 0x...>, 1, 5, 3, ..., ..., ...)
>>> e
FunctionGraph(add(z, mul(x, true_div(z, x))))
......
......@@ -151,7 +151,7 @@ def test_misc():
e = transpose_view(transpose_view(transpose_view(transpose_view(x))))
g = create_fgraph([x, y, z], [e])
assert g.consistent()
PatternOptimizer((transpose_view, (transpose_view, "x")), "x").optimize(g)
PatternOptimizer((transpose_view, (transpose_view, "x")), "x").rewrite(g)
assert str(g) == "FunctionGraph(x)"
new_e = add(x, y)
g.replace_validate(x, new_e)
......@@ -330,7 +330,7 @@ def test_long_destroyers_loop():
e = dot(dot(add_in_place(x, y), add_in_place(y, z)), add(z, x))
g = create_fgraph([x, y, z], [e])
assert g.consistent()
TopoSubstitutionNodeRewriter(add, add_in_place).optimize(g)
TopoSubstitutionNodeRewriter(add, add_in_place).rewrite(g)
assert g.consistent()
# we don't want to see that!
assert (
......@@ -366,7 +366,7 @@ def test_multi_destroyers_through_views():
g = create_fgraph([x, y, z], [e])
assert g.consistent()
fail = FailureWatch()
TopoSubstitutionNodeRewriter(add, add_in_place, fail).optimize(g)
TopoSubstitutionNodeRewriter(add, add_in_place, fail).rewrite(g)
assert g.consistent()
assert fail.failures == 1 # should have succeeded once and failed once
......@@ -388,7 +388,7 @@ def test_usage_loop():
g = create_fgraph([x, y, z], [dot(add_in_place(x, z), x)], False)
assert not g.consistent()
# replace add_in_place with add
TopoSubstitutionNodeRewriter(add_in_place, add).optimize(g)
TopoSubstitutionNodeRewriter(add_in_place, add).rewrite(g)
assert g.consistent()
......@@ -409,7 +409,7 @@ def test_usage_loop_insert_views():
g = create_fgraph([x, y, z], [e])
assert g.consistent()
fail = FailureWatch()
TopoSubstitutionNodeRewriter(sigmoid, transpose_view, fail).optimize(g)
TopoSubstitutionNodeRewriter(sigmoid, transpose_view, fail).rewrite(g)
assert g.consistent()
# it must keep one sigmoid in the long sigmoid chain
assert fail.failures == 1
......@@ -454,19 +454,19 @@ def test_multiple_inplace():
# try to work in-place on x/0 and y/1 (this should fail)
fail = FailureWatch()
TopoSubstitutionNodeRewriter(multiple, multiple_in_place_0_1, fail).optimize(g)
TopoSubstitutionNodeRewriter(multiple, multiple_in_place_0_1, fail).rewrite(g)
assert g.consistent()
assert fail.failures == 1
# try to work in-place on x/0 (this should fail)
fail = FailureWatch()
TopoSubstitutionNodeRewriter(multiple, multiple_in_place_0, fail).optimize(g)
TopoSubstitutionNodeRewriter(multiple, multiple_in_place_0, fail).rewrite(g)
assert g.consistent()
assert fail.failures == 1
# try to work in-place on y/1 (this should succeed)
fail = FailureWatch()
TopoSubstitutionNodeRewriter(multiple, multiple_in_place_1, fail).optimize(g)
TopoSubstitutionNodeRewriter(multiple, multiple_in_place_1, fail).rewrite(g)
assert g.consistent()
assert fail.failures == 0
......@@ -474,6 +474,6 @@ def test_multiple_inplace():
fail = FailureWatch()
TopoSubstitutionNodeRewriter(
multiple_in_place_1, multiple_in_place_0_1, fail
).optimize(g)
).rewrite(g)
assert g.consistent()
assert fail.failures == 1
......@@ -64,7 +64,7 @@ class TestPatternOptimizer:
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op2(x, y), z)
g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, (op2, "1", "2"), "3"), (op4, "3", "2")).optimize(g)
PatternOptimizer((op1, (op2, "1", "2"), "3"), (op4, "3", "2")).rewrite(g)
assert str(g) == "FunctionGraph(Op4(z, y))"
def test_nested_out_pattern(self):
......@@ -73,7 +73,7 @@ class TestPatternOptimizer:
g = FunctionGraph([x, y, z], [e])
PatternOptimizer(
(op1, "1", "2"), (op4, (op1, "1"), (op2, "2"), (op3, "1", "2"))
).optimize(g)
).rewrite(g)
assert str(g) == "FunctionGraph(Op4(Op1(x), Op2(y), Op3(x, y)))"
def test_unification_1(self):
......@@ -83,7 +83,7 @@ class TestPatternOptimizer:
PatternOptimizer(
(op1, (op2, "1", "1"), "2"), # they are the same in the pattern
(op4, "2", "1"),
).optimize(g)
).rewrite(g)
# So the replacement should occur
assert str(g) == "FunctionGraph(Op4(z, x))"
......@@ -94,7 +94,7 @@ class TestPatternOptimizer:
PatternOptimizer(
(op1, (op2, "1", "1"), "2"), # they are the same in the pattern
(op4, "2", "1"),
).optimize(g)
).rewrite(g)
# The replacement should NOT occur
assert str(g) == "FunctionGraph(Op1(Op2(x, y), z))"
......@@ -103,7 +103,7 @@ class TestPatternOptimizer:
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op2(x, y), z)
g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op2, "1", "2"), (op1, "2", "1")).optimize(g)
PatternOptimizer((op2, "1", "2"), (op1, "2", "1")).rewrite(g)
assert str(g) == "FunctionGraph(Op1(Op1(y, x), z))"
def test_no_recurse(self):
......@@ -113,7 +113,7 @@ class TestPatternOptimizer:
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op2(x, y), z)
g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op2, "1", "2"), (op2, "2", "1"), ign=True).optimize(g)
PatternOptimizer((op2, "1", "2"), (op2, "2", "1"), ign=True).rewrite(g)
assert str(g) == "FunctionGraph(Op1(Op2(y, x), z))"
def test_multiple(self):
......@@ -121,30 +121,30 @@ class TestPatternOptimizer:
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op2(x, y), op2(x, y), op2(y, z))
g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op2, "1", "2"), (op4, "1")).optimize(g)
PatternOptimizer((op2, "1", "2"), (op4, "1")).rewrite(g)
assert str(g) == "FunctionGraph(Op1(Op4(x), Op4(x), Op4(y)))"
def test_nested_even(self):
# regardless of the order in which we optimize, this
# regardless of the order in which we rewrite, this
# should work
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op1(op1(op1(x))))
g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, (op1, "1")), "1").optimize(g)
PatternOptimizer((op1, (op1, "1")), "1").rewrite(g)
assert str(g) == "FunctionGraph(x)"
def test_nested_odd(self):
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op1(op1(op1(op1(x)))))
g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, (op1, "1")), "1").optimize(g)
PatternOptimizer((op1, (op1, "1")), "1").rewrite(g)
assert str(g) == "FunctionGraph(Op1(x))"
def test_expand(self):
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op1(op1(x)))
g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, "1"), (op2, (op1, "1")), ign=True).optimize(g)
PatternOptimizer((op1, "1"), (op2, (op1, "1")), ign=True).rewrite(g)
assert str(g) == "FunctionGraph(Op2(Op1(Op2(Op1(Op2(Op1(x)))))))"
def test_ambiguous(self):
......@@ -154,7 +154,7 @@ class TestPatternOptimizer:
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op1(op1(op1(op1(x)))))
g = FunctionGraph([x, y, z], [e])
TopoPatternOptimizer((op1, (op1, "1")), (op1, "1"), ign=False).optimize(g)
TopoPatternOptimizer((op1, (op1, "1")), (op1, "1"), ign=False).rewrite(g)
assert str(g) == "FunctionGraph(Op1(x))"
def test_constant(self):
......@@ -163,7 +163,7 @@ class TestPatternOptimizer:
z = Constant(MyType(), 2, name="z")
e = op1(op1(x, y), y)
g = FunctionGraph([y], [e])
PatternOptimizer((op1, z, "1"), (op2, "1", z)).optimize(g)
PatternOptimizer((op1, z, "1"), (op2, "1", z)).rewrite(g)
assert str(g) == "FunctionGraph(Op1(Op2(y, z), y))"
def test_constraints(self):
......@@ -177,14 +177,14 @@ class TestPatternOptimizer:
PatternOptimizer(
(op1, {"pattern": "1", "constraint": constraint}), (op3, "1")
).optimize(g)
).rewrite(g)
assert str(g) == "FunctionGraph(Op4(Op3(Op2(x, y)), Op1(Op1(x, y))))"
def test_match_same(self):
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(x, x)
g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, "x", "y"), (op3, "x", "y")).optimize(g)
PatternOptimizer((op1, "x", "y"), (op3, "x", "y")).rewrite(g)
assert str(g) == "FunctionGraph(Op3(x, x))"
@pytest.mark.xfail(
......@@ -201,7 +201,7 @@ class TestPatternOptimizer:
PatternOptimizer(
{"pattern": (op1, "x", "y"), "constraint": constraint}, (op3, "x", "y")
).optimize(g)
).rewrite(g)
assert str(g) == "FunctionGraph(Op2(Op1(x, x), Op3(x, y)))"
def test_allow_multiple_clients(self):
......@@ -210,7 +210,7 @@ class TestPatternOptimizer:
# `e0` has multiple clients (i.e. the `op4` and `op3` nodes)
e = op3(op4(e0), e0)
g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op4, (op1, "x", "y")), (op3, "x", "y")).optimize(g)
PatternOptimizer((op4, (op1, "x", "y")), (op3, "x", "y")).rewrite(g)
assert str(g) == "FunctionGraph(Op3(Op4(*1 -> Op1(x, y)), *1))"
def test_eq(self):
......@@ -218,7 +218,7 @@ class TestPatternOptimizer:
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op_y(x, y), z)
g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, (op_z, "1", "2"), "3"), (op4, "3", "2")).optimize(g)
PatternOptimizer((op1, (op_z, "1", "2"), "3"), (op4, "3", "2")).rewrite(g)
str_g = str(g)
assert str_g == "FunctionGraph(Op4(z, y))"
......@@ -232,14 +232,14 @@ class TestSubstitutionNodeRewriter:
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op1(op1(op1(op1(x)))))
g = FunctionGraph([x, y, z], [e])
KeyedSubstitutionNodeRewriter(op1, op2).optimize(g)
KeyedSubstitutionNodeRewriter(op1, op2).rewrite(g)
assert str(g) == "FunctionGraph(Op2(Op2(Op2(Op2(Op2(x))))))"
def test_straightforward_2(self):
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op2(x), op3(y), op4(z))
g = FunctionGraph([x, y, z], [e])
KeyedSubstitutionNodeRewriter(op3, op4).optimize(g)
KeyedSubstitutionNodeRewriter(op3, op4).rewrite(g)
assert str(g) == "FunctionGraph(Op1(Op2(x), Op4(y), Op4(z)))"
......@@ -261,7 +261,7 @@ class TestMergeOptimizer:
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op2(x, y), op2(x, y), op2(x, z))
g = FunctionGraph([x, y, z], [e], clone=False)
MergeOptimizer().optimize(g)
MergeOptimizer().rewrite(g)
out_var = g.outputs[0]
var_1, var_2, var_3 = out_var.owner.inputs
assert var_1 is var_2
......@@ -273,7 +273,7 @@ class TestMergeOptimizer:
z = Constant(MyType(), 2, name="z")
e = op1(op2(x, y), op2(x, y), op2(x, z))
g = FunctionGraph([x, y, z], [e], clone=False)
MergeOptimizer().optimize(g)
MergeOptimizer().rewrite(g)
out_var = g.outputs[0]
var_1, var_2, var_3 = out_var.owner.inputs
assert var_1 is var_2
......@@ -283,7 +283,7 @@ class TestMergeOptimizer:
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op3(op2(x, y), z), op4(op3(op2(x, y), z)))
g = FunctionGraph([x, y, z], [e], clone=False)
MergeOptimizer().optimize(g)
MergeOptimizer().rewrite(g)
out_var = g.outputs[0]
var_1, var_2 = out_var.owner.inputs
assert var_2.owner.inputs[0] is var_1
......@@ -293,14 +293,14 @@ class TestMergeOptimizer:
e = op1(op3(op2(x, y)), op3(op2(y, x)))
g = FunctionGraph([x, y, z], [e])
g.attach_feature(AssertNoChanges())
MergeOptimizer().optimize(g)
MergeOptimizer().rewrite(g)
def test_merge_outputs(self):
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e1 = op3(op2(x, y))
e2 = op3(op2(x, y))
g = FunctionGraph([x, y, z], [e1, e2], clone=False)
MergeOptimizer().optimize(g)
MergeOptimizer().rewrite(g)
assert g.outputs[0] is g.outputs[1]
def test_identical_constant_args(self):
......@@ -309,7 +309,7 @@ class TestMergeOptimizer:
z = Constant(MyType(), 2, name="z")
e1 = op1(y, z)
g = FunctionGraph([x, y, z], [e1], clone=False)
MergeOptimizer().optimize(g)
MergeOptimizer().rewrite(g)
assert g.outputs[0].owner.op == op1
input_1 = g.outputs[0].owner.inputs[0]
......@@ -322,7 +322,7 @@ class TestMergeOptimizer:
x2 = matrix("x2")
e = dot(x1, x2) + dot(assert_op(x1, (x1 > x2).all()), x2)
g = FunctionGraph([x1, x2], [e], clone=False)
MergeOptimizer().optimize(g)
MergeOptimizer().rewrite(g)
assert g.outputs[0].owner.op == add
add_inputs = g.outputs[0].owner.inputs
......@@ -342,7 +342,7 @@ class TestMergeOptimizer:
assert_op(x1, (x1 > x2).all()), x2
)
g = FunctionGraph([x1, x2], [e], clone=False)
MergeOptimizer().optimize(g)
MergeOptimizer().rewrite(g)
assert g.outputs[0].owner.op == add
add_inputs = g.outputs[0].owner.inputs
......@@ -365,7 +365,7 @@ class TestMergeOptimizer:
assert_op(x1, (x1 > x2).all()), x2
)
g = FunctionGraph([x1, x2, x3], [e], clone=False)
MergeOptimizer().optimize(g)
MergeOptimizer().rewrite(g)
assert g.outputs[0].owner.op == add
add_inputs = g.outputs[0].owner.inputs
......@@ -387,7 +387,7 @@ class TestMergeOptimizer:
x1, assert_op(x2, (x2 > x3).all())
)
g = FunctionGraph([x1, x2, x3], [e], clone=False)
MergeOptimizer().optimize(g)
MergeOptimizer().rewrite(g)
assert g.outputs[0].owner.op == add
add_inputs = g.outputs[0].owner.inputs
......@@ -411,7 +411,7 @@ class TestMergeOptimizer:
assert_op(x1, (x1 > x3).all()), x2
)
g = FunctionGraph([x1, x2, x3], [e], clone=False)
MergeOptimizer().optimize(g)
MergeOptimizer().rewrite(g)
assert g.outputs[0].owner.op == add
add_inputs = g.outputs[0].owner.inputs
......@@ -432,7 +432,7 @@ class TestMergeOptimizer:
z = NoInputOp(param=1)()
fg = FunctionGraph([], [x, y, z], clone=False)
MergeOptimizer().optimize(fg)
MergeOptimizer().rewrite(fg)
assert fg.outputs[0] is fg.outputs[1]
assert fg.outputs[0] is not fg.outputs[2]
......@@ -454,7 +454,7 @@ class TestEquilibrium:
],
max_use_ratio=10,
)
opt.optimize(g)
opt.rewrite(g)
# print g
assert str(g) == "FunctionGraph(Op2(x, y))"
......@@ -473,7 +473,7 @@ class TestEquilibrium:
],
max_use_ratio=10,
)
opt.optimize(g)
opt.rewrite(g)
assert str(g) == "FunctionGraph(Op2(x, y))"
@config.change_flags(on_opt_error="ignore")
......@@ -496,7 +496,7 @@ class TestEquilibrium:
],
max_use_ratio=1.0 / len(g.apply_nodes),
) # each opt can only be applied once
opt.optimize(g)
opt.rewrite(g)
finally:
_logger.setLevel(oldlevel)
# print 'after', g
......@@ -612,7 +612,7 @@ def test_patternsub_values_eq_approx(out_pattern, tracks):
],
max_use_ratio=1,
)
opt.optimize(fg)
opt.rewrite(fg)
output = fg.outputs[0]
if isinstance(out_pattern, tuple):
assert output.owner.op == op2
......@@ -642,7 +642,7 @@ def test_patternsub_invalid_dtype(out_pattern):
],
max_use_ratio=1,
)
opt.optimize(fg)
opt.rewrite(fg)
assert e.type.is_super(fg.outputs[0].type)
......@@ -660,7 +660,7 @@ def test_patternsub_different_output_lengths():
o = op1(e1)
fgraph = FunctionGraph(inputs=[x], outputs=[o])
opt.optimize(fgraph)
opt.rewrite(fgraph)
assert fgraph.outputs[0].owner.op == op1
......
......@@ -824,7 +824,7 @@ class TestScan:
assert scan_c is not scan_a
g = FunctionGraph([x, y, c], [2 * scan_a, 2 * scan_b, 2 * scan_c], clone=False)
MergeOptimizer().optimize(g)
MergeOptimizer().rewrite(g)
scan_a_out, scan_b_out, scan_c_out = g.outputs
assert scan_a_out is scan_b_out
......
......@@ -340,7 +340,7 @@ class TestLogSoftmax(utt.InferShapeTester):
)
fgraph = FunctionGraph([x], [new_g])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
assert softmax_grad_legacy in [n.op for n in fgraph.toposort()]
......@@ -647,7 +647,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
fgraph = FunctionGraph([x, one_of_n], [op(softmax_legacy(x), one_of_n)])
assert fgraph.outputs[0].owner.op == op
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
assert fgraph.outputs[0].owner.op == crossentropy_softmax_argmax_1hot_with_bias
def test_softmax_optimizations_w_bias(self):
......@@ -659,7 +659,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
fgraph = FunctionGraph([x, b, one_of_n], [op(softmax_legacy(x + b), one_of_n)])
assert fgraph.outputs[0].owner.op == op
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
assert len(fgraph.toposort()) == 1
assert fgraph.outputs[0].owner.op == crossentropy_softmax_argmax_1hot_with_bias
......@@ -676,7 +676,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
)
assert fgraph.outputs[0].owner.op == op
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
assert len(fgraph.toposort()) == 2
assert fgraph.outputs[0].owner.op == crossentropy_softmax_argmax_1hot_with_bias
......@@ -694,7 +694,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
ops_to_check=[crossentropy_softmax_1hot_with_bias_dx, softmax_legacy],
)
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = {node.op for node in fgraph.toposort()}
assert crossentropy_softmax_argmax_1hot_with_bias not in ops
......@@ -717,7 +717,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
for expr in expressions:
fgraph = FunctionGraph([x, y], [expr])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 4
......@@ -726,7 +726,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
# Also verify the gradient wrt x
fgraph = FunctionGraph([x, y], [grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 2
......@@ -744,14 +744,14 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
for expr in bias_expressions:
fgraph = FunctionGraph([x, b, y], [expr, x])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 2 # [big_op, sum]
assert crossentropy_softmax_argmax_1hot_with_bias in ops
fgraph = FunctionGraph([x, b, y], [grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 2
......@@ -770,7 +770,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
for expr in mean_expressions:
fgraph = FunctionGraph([x, y], [expr])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 6
......@@ -778,7 +778,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)]
fgraph = FunctionGraph([x, y], [grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 5
......@@ -798,7 +798,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
for expr in mean_bias_expressions:
fgraph = FunctionGraph([x, b, y], [expr])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 4
......@@ -806,7 +806,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)]
fgraph = FunctionGraph([x, b, y], [grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 5
......@@ -827,7 +827,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
for expr in expressions:
fgraph = FunctionGraph([x, y], [expr])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 5
......@@ -836,7 +836,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
# Also verify the gradient wrt x
fgraph = FunctionGraph([x, y], [grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 3
......@@ -888,7 +888,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
for expr in expressions:
fgraph = FunctionGraph([x, y, a], [expr])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
assert 5 <= len(fgraph.toposort()) <= 10
......@@ -898,7 +898,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
# Verify the gradient wrt x
fgraph = FunctionGraph([x, y, a], [grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
assert 3 <= len(fgraph.toposort()) <= 6
......@@ -911,7 +911,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
fgraph = FunctionGraph(
[x, y, a], [grad(expr, x, known_grads={expr: a * x.sum()})]
)
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
assert 6 <= len(fgraph.toposort()) <= 8
......@@ -927,7 +927,7 @@ def test_argmax_pushdown():
# test that the max_and_argmax is pushed down if the max is not used
out = max_and_argmax(sm(exp(tanh(sigmoid(x)))), axis=-1)[1]
fgraph = FunctionGraph([x], [out])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
# print 'AFTER'
# for node in fgraph.toposort():
......@@ -942,7 +942,7 @@ def test_argmax_pushdown():
assert hasattr(fgraph.outputs[0].tag, "trace")
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
# print 'AFTER'
# for node in fgraph.toposort():
......@@ -963,7 +963,7 @@ def test_argmax_pushdown_bias():
out = argmax(softmax_with_bias(x, b), axis=-1)
fgraph = FunctionGraph([x, b], [out])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
types_to_check = (DimShuffle, Elemwise, Argmax)
assert len(fgraph.toposort()) == 3
......@@ -977,7 +977,7 @@ def test_argmax_pushdown_bias():
out = max_and_argmax(softmax_with_bias(x, b), axis=-1)[0]
fgraph = FunctionGraph([x, b], [out])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
assert len(fgraph.toposort()) == 2
assert isinstance(fgraph.toposort()[0].op, SoftmaxWithBias)
......
......@@ -159,11 +159,11 @@ def ds(x, y):
def optimize(g, level="fast_run"):
    """Apply the rewriter for `level` to the graph `g` in place.

    Parameters
    ----------
    g
        The `FunctionGraph` to rewrite.
    level : str
        One of ``"fast_run"``, ``"specialize"``, or ``"stabilize"``.

    Returns
    -------
    The same graph `g`, after rewriting.

    Raises
    ------
    ValueError
        If `level` is not one of the recognized names.
    """
    if level == "fast_run":
        _optimizer_fast_run.rewrite(g)
    elif level == "specialize":
        _optimizer_specialize.rewrite(g)
    elif level == "stabilize":
        _optimizer_stabilize.rewrite(g)
    else:
        raise ValueError(level)
    return g
......@@ -184,7 +184,7 @@ class TestDimshuffleLift:
assert (
str(g) == "FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{1,0}(x)))"
)
dimshuffle_lift.optimize(g)
dimshuffle_lift.rewrite(g)
assert str(g) == "FunctionGraph(x)"
# no need to check_stack_trace as graph is supposed to be empty
......@@ -196,7 +196,7 @@ class TestDimshuffleLift:
str(g)
== "FunctionGraph(InplaceDimShuffle{2,0,x,1}(InplaceDimShuffle{1,x,0}(x)))"
), str(g)
dimshuffle_lift.optimize(g)
dimshuffle_lift.rewrite(g)
assert str(g) == "FunctionGraph(InplaceDimShuffle{0,1,x,x}(x))", str(g)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(g, ops_to_check="all")
......@@ -209,7 +209,7 @@ class TestDimshuffleLift:
"FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{2,0,x,1}"
"(InplaceDimShuffle{0,x,1}(x))))"
), str(g)
dimshuffle_lift.optimize(g)
dimshuffle_lift.rewrite(g)
assert str(g) == "FunctionGraph(x)", str(g)
# no need to check_stack_trace as graph is supposed to be empty
......@@ -238,7 +238,7 @@ class TestDimshuffleLift:
"FunctionGraph(Elemwise{add,no_inplace}(Elemwise{add,no_inplace}"
"(DimShuffle{x,x,0}(x), DimShuffle{x,0,1}(y)), z))"
)
dimshuffle_lift.optimize(g)
dimshuffle_lift.rewrite(g)
assert str(g) in (opt_str_g_inplace, opt_str_g_noinplace), str(g)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(g, ops_to_check="all")
......@@ -277,7 +277,7 @@ class TestDimshuffleLift:
e = ds(x, (0, 1))
g = FunctionGraph([x], [e])
assert str(g) == "FunctionGraph(InplaceDimShuffle{0,1}(x))"
dimshuffle_lift.optimize(g)
dimshuffle_lift.rewrite(g)
assert str(g) == "FunctionGraph(x)"
# Check stacktrace was copied over correctly after opt was applied
assert hasattr(g.outputs[0].tag, "trace")
......@@ -294,7 +294,7 @@ class TestDimshuffleLift:
str(g)
== "FunctionGraph(InplaceDimShuffle{0,x}(x), InplaceDimShuffle{2,1,0}(y), InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))"
)
dimshuffle_lift.optimize(g)
dimshuffle_lift.rewrite(g)
assert (
str(g)
== "FunctionGraph(x, y, InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))"
......@@ -331,7 +331,7 @@ def test_local_useless_dimshuffle_in_reshape():
"Reshape{2}(InplaceDimShuffle{0}(col), Shape(col)))"
)
useless_dimshuffle_in_reshape = out2in(local_useless_dimshuffle_in_reshape)
useless_dimshuffle_in_reshape.optimize(g)
useless_dimshuffle_in_reshape.rewrite(g)
assert str(g) == (
"FunctionGraph(Reshape{1}(vector, Shape(vector)), "
"Reshape{2}(mat, Shape(mat)), "
......@@ -347,7 +347,7 @@ def test_local_useless_dimshuffle_in_reshape():
reshape_dimshuffle_mat2 = reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape)
h = FunctionGraph([mat], [reshape_dimshuffle_mat2])
str_h = str(h)
useless_dimshuffle_in_reshape.optimize(h)
useless_dimshuffle_in_reshape.rewrite(h)
assert str(h) == str_h
......@@ -1505,7 +1505,7 @@ class TestLocalCanonicalizeAlloc:
assert any(isinstance(node.op, Alloc) for node in g.toposort())
alloc_lift = out2in(local_alloc_sink_dimshuffle)
alloc_lift.optimize(g)
alloc_lift.rewrite(g)
if has_alloc:
assert any(isinstance(node.op, Alloc) for node in g.toposort())
......@@ -2849,8 +2849,8 @@ class TestLocalReshapeToDimshuffle:
"TensorConstant{[1 5 1 6 1 1]}))"
)
reshape_lift.optimize(g)
useless_reshape.optimize(g)
reshape_lift.rewrite(g)
useless_reshape.rewrite(g)
assert str(g) == (
"FunctionGraph(InplaceDimShuffle{x,0}"
"(<TensorType(float64, (None,))>), "
......@@ -2880,9 +2880,9 @@ def test_local_reshape_lift():
class TestLiftTransposeThroughDot:
def simple_optimize(self, g):
    """Apply a small fixed rewrite pipeline to `g` and return it.

    Runs useless-elemwise removal, lifts transposes through `dot`, then
    cleans up any useless elemwises the lift introduced.
    """
    out2in(local_useless_elemwise).rewrite(g)
    out2in(local_lift_transpose_through_dot).rewrite(g)
    out2in(local_useless_elemwise).rewrite(g)
    return g
def test_matrix_matrix(self):
......@@ -3159,9 +3159,9 @@ def test_local_useless_alloc():
output = at.alloc(at.alloc(m, 1, y, 1, 1), x, y, z, w)
g = FunctionGraph([m, x, y, z, w], [output])
useless_alloc.optimize(g)
merge_alloc.optimize(g)
useless_alloc.optimize(g)
useless_alloc.rewrite(g)
merge_alloc.rewrite(g)
useless_alloc.rewrite(g)
topo = g.toposort()
assert len(topo) == 1
......@@ -3172,9 +3172,9 @@ def test_local_useless_alloc():
output = at.alloc(at.alloc(m, y, 1, 1), x, y, z, w)
g = FunctionGraph([m, x, y, z, w], [output])
useless_alloc.optimize(g)
merge_alloc.optimize(g)
useless_alloc.optimize(g)
useless_alloc.rewrite(g)
merge_alloc.rewrite(g)
useless_alloc.rewrite(g)
topo = g.toposort()
assert len(topo) == 1
......@@ -3186,9 +3186,9 @@ def test_local_useless_alloc():
output = at.alloc(at.alloc(m, y, 1, 1), x, y2, z, w)
g = FunctionGraph([m, x, y, y2, z, w], [output])
useless_alloc.optimize(g)
merge_alloc.optimize(g)
useless_alloc.optimize(g)
useless_alloc.rewrite(g)
merge_alloc.rewrite(g)
useless_alloc.rewrite(g)
topo = g.toposort()
assert len(topo) == 3
......
......@@ -150,11 +150,11 @@ def ds(x, y):
def optimize(g, level="fast_run"):
    """Rewrite the graph `g` in place at the requested `level`.

    `level` selects one of the module-level rewriters: ``"fast_run"``,
    ``"specialize"``, or ``"stabilize"``. Returns `g` for chaining and
    raises `ValueError` for an unrecognized `level`.
    """
    if level == "fast_run":
        _optimizer_fast_run.rewrite(g)
    elif level == "specialize":
        _optimizer_specialize.rewrite(g)
    elif level == "stabilize":
        _optimizer_stabilize.rewrite(g)
    else:
        raise ValueError(level)
    return g
......@@ -189,19 +189,19 @@ class TestGreedyDistribute:
# 1. ((a/x + b/y) * x * y) --> a*y + b*x
e = (a / z + b / x) * x * z
g = FunctionGraph([a, b, c, d, x, y, z], [e])
mul_canonizer.optimize(g)
mul_canonizer.rewrite(g)
WalkingGraphRewriter(
SequentialNodeRewriter(local_greedy_distributor), order="out_to_in"
).optimize(g)
).rewrite(g)
assert str(pprint(g.outputs[0])) == "((a * x) + (b * z))"
# 2. ((a/x + b) * x) --> a + b*x
e = (a / x + b) * x
g = FunctionGraph([a, b, x], [e])
mul_canonizer.optimize(g)
mul_canonizer.rewrite(g)
WalkingGraphRewriter(
SequentialNodeRewriter(local_greedy_distributor), order="out_to_in"
).optimize(g)
).rewrite(g)
assert str(pprint(g.outputs[0])) == "(a + (b * x))"
def test_kording_bug(self):
......@@ -3054,7 +3054,7 @@ class TestLocalErfc:
WalkingGraphRewriter(
SequentialNodeRewriter(local_grad_log_erfc_neg), order="out_to_in"
).optimize(fg)
).rewrite(fg)
# Make sure that the graph hasn't been changed
assert fg.outputs[0] is no_match
......
......@@ -72,7 +72,7 @@ def test_merge_with_weird_eq():
x = at.constant(np.asarray(1), name="x")
y = at.constant(np.asarray(1), name="y")
g = FunctionGraph([x, y], [x + y])
MergeOptimizer().optimize(g)
MergeOptimizer().rewrite(g)
assert len(g.apply_nodes) == 1
node = list(g.apply_nodes)[0]
......@@ -84,7 +84,7 @@ def test_merge_with_weird_eq():
x = at.constant(np.ones(5), name="x")
y = at.constant(np.ones(5), name="y")
g = FunctionGraph([x, y], [x + y])
MergeOptimizer().optimize(g)
MergeOptimizer().rewrite(g)
assert len(g.apply_nodes) == 1
node = list(g.apply_nodes)[0]
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论