提交 45f48ae6 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename GraphRewriter.optimize to GraphRewriter.rewrite

上级 8db93fe1
...@@ -92,16 +92,16 @@ class Rewriter(abc.ABC): ...@@ -92,16 +92,16 @@ class Rewriter(abc.ABC):
class GraphRewriter(Rewriter): class GraphRewriter(Rewriter):
"""A optimizer that can be applied to a `FunctionGraph` in order to transform it. """A rewriter that can be applied to a `FunctionGraph` in order to transform it.
It can represent an optimization or, in general, any kind of transformation This class represents a generalized rewrite that includes the way a graph
one could apply to a `FunctionGraph`. is traversed and/or changed as a whole.
""" """
@abc.abstractmethod @abc.abstractmethod
def apply(self, fgraph): def apply(self, fgraph):
"""Apply the optimization to a `FunctionGraph`. """Apply the rewriter to a `FunctionGraph`.
It may use all the methods defined by the `FunctionGraph`. If the It may use all the methods defined by the `FunctionGraph`. If the
`GraphRewriter` needs to use a certain tool, such as an `GraphRewriter` needs to use a certain tool, such as an
...@@ -110,26 +110,29 @@ class GraphRewriter(Rewriter): ...@@ -110,26 +110,29 @@ class GraphRewriter(Rewriter):
""" """
raise NotImplementedError() raise NotImplementedError()
def optimize(self, fgraph, *args, **kwargs): def optimize(self, *args, **kwargs):
warnings.warn(
"`GraphRewriter.optimize` is deprecated; use `GraphRewriter.rewrite` instead.",
DeprecationWarning,
stacklevel=2,
)
self.rewrite(*args, **kwargs)
def rewrite(self, fgraph, *args, **kwargs):
""" """
This is meant as a shortcut for the following:: This is meant as a shortcut for the following::
opt.add_requirements(fgraph) self.add_requirements(fgraph)
opt.apply(fgraph) self.apply(fgraph)
""" """
self.add_requirements(fgraph) self.add_requirements(fgraph)
ret = self.apply(fgraph, *args, **kwargs) return self.apply(fgraph, *args, **kwargs)
return ret
def __call__(self, fgraph): def __call__(self, fgraph):
"""Optimize a `FunctionGraph`. """Rewrite a `FunctionGraph`."""
return self.rewrite(fgraph)
This is the same as ``self.optimize(fgraph)``.
"""
return self.optimize(fgraph)
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
... ...
...@@ -141,12 +144,12 @@ class GraphRewriter(Rewriter): ...@@ -141,12 +144,12 @@ class GraphRewriter(Rewriter):
file=stream, file=stream,
) )
@staticmethod @classmethod
def print_profile(stream, prof, level=0): def print_profile(cls, stream, prof, level=0):
if prof is not None: if prof is not None:
raise NotImplementedError( raise NotImplementedError(
"The function print_profile must be overridden if the" "The function `print_profile` must be overridden when the"
" optimizer return profiling information." " rewriter returns profiling information."
) )
......
...@@ -44,10 +44,10 @@ def optimize_graph( ...@@ -44,10 +44,10 @@ def optimize_graph(
return_only_out = True return_only_out = True
canonicalize_opt = optdb.query(RewriteDatabaseQuery(include=include, **kwargs)) canonicalize_opt = optdb.query(RewriteDatabaseQuery(include=include, **kwargs))
_ = canonicalize_opt.optimize(fgraph) _ = canonicalize_opt.rewrite(fgraph)
if custom_opt: if custom_opt:
custom_opt.optimize(fgraph) custom_opt.rewrite(fgraph)
if return_only_out: if return_only_out:
return fgraph.outputs[0] return fgraph.outputs[0]
...@@ -79,7 +79,7 @@ def is_same_graph_with_merge(var1, var2, givens=None): ...@@ -79,7 +79,7 @@ def is_same_graph_with_merge(var1, var2, givens=None):
for to_replace, replace_by in givens.items(): for to_replace, replace_by in givens.items():
fgraph.replace(to_replace, replace_by) fgraph.replace(to_replace, replace_by)
# Perform merge optimization. # Perform merge optimization.
MergeOptimizer().optimize(fgraph) MergeOptimizer().rewrite(fgraph)
# When two variables perform the same computations, they will have the same # When two variables perform the same computations, they will have the same
# owner in the optimized graph. # owner in the optimized graph.
# We need to be careful with the special case where the owner is None, # We need to be careful with the special case where the owner is None,
......
...@@ -4152,7 +4152,7 @@ class Composite(ScalarOp): ...@@ -4152,7 +4152,7 @@ class Composite(ScalarOp):
# the fgraph to be set to the variable as we need to pickle # the fgraph to be set to the variable as we need to pickle
# them for the cache of c module to work. # them for the cache of c module to work.
fgraph = FunctionGraph(self.inputs, self.outputs) fgraph = FunctionGraph(self.inputs, self.outputs)
MergeOptimizer().optimize(fgraph) MergeOptimizer().rewrite(fgraph)
for node in fgraph.apply_nodes: for node in fgraph.apply_nodes:
if not isinstance(node.op, ScalarOp): if not isinstance(node.op, ScalarOp):
raise ValueError( raise ValueError(
......
...@@ -56,7 +56,7 @@ Graph Rewriting ...@@ -56,7 +56,7 @@ Graph Rewriting
<libdoc_graph_fgraphfeature>` to it. These features are "plugins" that are needed <libdoc_graph_fgraphfeature>` to it. These features are "plugins" that are needed
for the :meth:`GraphRewriter.apply` method to do its job properly. for the :meth:`GraphRewriter.apply` method to do its job properly.
.. method:: optimize(fgraph) .. method:: rewrite(fgraph)
This is the interface function called by Aesara. It calls This is the interface function called by Aesara. It calls
:meth:`GraphRewriter.apply` by default. :meth:`GraphRewriter.apply` by default.
...@@ -159,7 +159,7 @@ Now, we test the optimization: ...@@ -159,7 +159,7 @@ Now, we test the optimization:
>>> e = aesara.graph.fg.FunctionGraph([x, y, z], [a]) >>> e = aesara.graph.fg.FunctionGraph([x, y, z], [a])
>>> e >>> e
FunctionGraph(add(z, mul(true_div(mul(y, x), y), true_div(z, x)))) FunctionGraph(add(z, mul(true_div(mul(y, x), y), true_div(z, x))))
>>> simplify.optimize(e) >>> simplify.rewrite(e)
>>> e >>> e
FunctionGraph(add(z, mul(x, true_div(z, x)))) FunctionGraph(add(z, mul(x, true_div(z, x))))
...@@ -175,7 +175,7 @@ optimization you wrote. For example, consider the following: ...@@ -175,7 +175,7 @@ optimization you wrote. For example, consider the following:
>>> e = aesara.graph.fg.FunctionGraph([x, y, z], [a]) >>> e = aesara.graph.fg.FunctionGraph([x, y, z], [a])
>>> e >>> e
FunctionGraph(true_div(mul(add(y, z), x), add(y, z))) FunctionGraph(true_div(mul(add(y, z), x), add(y, z)))
>>> simplify.optimize(e) >>> simplify.rewrite(e)
>>> e >>> e
FunctionGraph(true_div(mul(add(y, z), x), add(y, z))) FunctionGraph(true_div(mul(add(y, z), x), add(y, z)))
...@@ -186,11 +186,11 @@ computation, using the :class:`MergeOptimizer` defined in ...@@ -186,11 +186,11 @@ computation, using the :class:`MergeOptimizer` defined in
:mod:`aesara.graph.opt`. :mod:`aesara.graph.opt`.
>>> from aesara.graph.opt import MergeOptimizer >>> from aesara.graph.opt import MergeOptimizer
>>> MergeOptimizer().optimize(e) # doctest: +ELLIPSIS >>> MergeOptimizer().rewrite(e) # doctest: +ELLIPSIS
(0, ..., None, None, {}, 1, 0) (0, ..., None, None, {}, 1, 0)
>>> e >>> e
FunctionGraph(true_div(mul(*1 -> add(y, z), x), *1)) FunctionGraph(true_div(mul(*1 -> add(y, z), x), *1))
>>> simplify.optimize(e) >>> simplify.rewrite(e)
>>> e >>> e
FunctionGraph(x) FunctionGraph(x)
...@@ -265,7 +265,7 @@ subset of them) and applies one or several local optimizers. ...@@ -265,7 +265,7 @@ subset of them) and applies one or several local optimizers.
>>> e >>> e
FunctionGraph(add(z, mul(true_div(mul(y, x), y), true_div(z, x)))) FunctionGraph(add(z, mul(true_div(mul(y, x), y), true_div(z, x))))
>>> simplify = aesara.graph.opt.WalkingGraphRewriter(local_simplify) >>> simplify = aesara.graph.opt.WalkingGraphRewriter(local_simplify)
>>> simplify.optimize(e) >>> simplify.rewrite(e)
(<aesara.graph.opt.WalkingGraphRewriter object at 0x...>, 1, 5, 3, ..., ..., ...) (<aesara.graph.opt.WalkingGraphRewriter object at 0x...>, 1, 5, 3, ..., ..., ...)
>>> e >>> e
FunctionGraph(add(z, mul(x, true_div(z, x)))) FunctionGraph(add(z, mul(x, true_div(z, x))))
......
...@@ -151,7 +151,7 @@ def test_misc(): ...@@ -151,7 +151,7 @@ def test_misc():
e = transpose_view(transpose_view(transpose_view(transpose_view(x)))) e = transpose_view(transpose_view(transpose_view(transpose_view(x))))
g = create_fgraph([x, y, z], [e]) g = create_fgraph([x, y, z], [e])
assert g.consistent() assert g.consistent()
PatternOptimizer((transpose_view, (transpose_view, "x")), "x").optimize(g) PatternOptimizer((transpose_view, (transpose_view, "x")), "x").rewrite(g)
assert str(g) == "FunctionGraph(x)" assert str(g) == "FunctionGraph(x)"
new_e = add(x, y) new_e = add(x, y)
g.replace_validate(x, new_e) g.replace_validate(x, new_e)
...@@ -330,7 +330,7 @@ def test_long_destroyers_loop(): ...@@ -330,7 +330,7 @@ def test_long_destroyers_loop():
e = dot(dot(add_in_place(x, y), add_in_place(y, z)), add(z, x)) e = dot(dot(add_in_place(x, y), add_in_place(y, z)), add(z, x))
g = create_fgraph([x, y, z], [e]) g = create_fgraph([x, y, z], [e])
assert g.consistent() assert g.consistent()
TopoSubstitutionNodeRewriter(add, add_in_place).optimize(g) TopoSubstitutionNodeRewriter(add, add_in_place).rewrite(g)
assert g.consistent() assert g.consistent()
# we don't want to see that! # we don't want to see that!
assert ( assert (
...@@ -366,7 +366,7 @@ def test_multi_destroyers_through_views(): ...@@ -366,7 +366,7 @@ def test_multi_destroyers_through_views():
g = create_fgraph([x, y, z], [e]) g = create_fgraph([x, y, z], [e])
assert g.consistent() assert g.consistent()
fail = FailureWatch() fail = FailureWatch()
TopoSubstitutionNodeRewriter(add, add_in_place, fail).optimize(g) TopoSubstitutionNodeRewriter(add, add_in_place, fail).rewrite(g)
assert g.consistent() assert g.consistent()
assert fail.failures == 1 # should have succeeded once and failed once assert fail.failures == 1 # should have succeeded once and failed once
...@@ -388,7 +388,7 @@ def test_usage_loop(): ...@@ -388,7 +388,7 @@ def test_usage_loop():
g = create_fgraph([x, y, z], [dot(add_in_place(x, z), x)], False) g = create_fgraph([x, y, z], [dot(add_in_place(x, z), x)], False)
assert not g.consistent() assert not g.consistent()
# replace add_in_place with add # replace add_in_place with add
TopoSubstitutionNodeRewriter(add_in_place, add).optimize(g) TopoSubstitutionNodeRewriter(add_in_place, add).rewrite(g)
assert g.consistent() assert g.consistent()
...@@ -409,7 +409,7 @@ def test_usage_loop_insert_views(): ...@@ -409,7 +409,7 @@ def test_usage_loop_insert_views():
g = create_fgraph([x, y, z], [e]) g = create_fgraph([x, y, z], [e])
assert g.consistent() assert g.consistent()
fail = FailureWatch() fail = FailureWatch()
TopoSubstitutionNodeRewriter(sigmoid, transpose_view, fail).optimize(g) TopoSubstitutionNodeRewriter(sigmoid, transpose_view, fail).rewrite(g)
assert g.consistent() assert g.consistent()
# it must keep one sigmoid in the long sigmoid chain # it must keep one sigmoid in the long sigmoid chain
assert fail.failures == 1 assert fail.failures == 1
...@@ -454,19 +454,19 @@ def test_multiple_inplace(): ...@@ -454,19 +454,19 @@ def test_multiple_inplace():
# try to work in-place on x/0 and y/1 (this should fail) # try to work in-place on x/0 and y/1 (this should fail)
fail = FailureWatch() fail = FailureWatch()
TopoSubstitutionNodeRewriter(multiple, multiple_in_place_0_1, fail).optimize(g) TopoSubstitutionNodeRewriter(multiple, multiple_in_place_0_1, fail).rewrite(g)
assert g.consistent() assert g.consistent()
assert fail.failures == 1 assert fail.failures == 1
# try to work in-place on x/0 (this should fail) # try to work in-place on x/0 (this should fail)
fail = FailureWatch() fail = FailureWatch()
TopoSubstitutionNodeRewriter(multiple, multiple_in_place_0, fail).optimize(g) TopoSubstitutionNodeRewriter(multiple, multiple_in_place_0, fail).rewrite(g)
assert g.consistent() assert g.consistent()
assert fail.failures == 1 assert fail.failures == 1
# try to work in-place on y/1 (this should succeed) # try to work in-place on y/1 (this should succeed)
fail = FailureWatch() fail = FailureWatch()
TopoSubstitutionNodeRewriter(multiple, multiple_in_place_1, fail).optimize(g) TopoSubstitutionNodeRewriter(multiple, multiple_in_place_1, fail).rewrite(g)
assert g.consistent() assert g.consistent()
assert fail.failures == 0 assert fail.failures == 0
...@@ -474,6 +474,6 @@ def test_multiple_inplace(): ...@@ -474,6 +474,6 @@ def test_multiple_inplace():
fail = FailureWatch() fail = FailureWatch()
TopoSubstitutionNodeRewriter( TopoSubstitutionNodeRewriter(
multiple_in_place_1, multiple_in_place_0_1, fail multiple_in_place_1, multiple_in_place_0_1, fail
).optimize(g) ).rewrite(g)
assert g.consistent() assert g.consistent()
assert fail.failures == 1 assert fail.failures == 1
...@@ -64,7 +64,7 @@ class TestPatternOptimizer: ...@@ -64,7 +64,7 @@ class TestPatternOptimizer:
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op2(x, y), z) e = op1(op2(x, y), z)
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, (op2, "1", "2"), "3"), (op4, "3", "2")).optimize(g) PatternOptimizer((op1, (op2, "1", "2"), "3"), (op4, "3", "2")).rewrite(g)
assert str(g) == "FunctionGraph(Op4(z, y))" assert str(g) == "FunctionGraph(Op4(z, y))"
def test_nested_out_pattern(self): def test_nested_out_pattern(self):
...@@ -73,7 +73,7 @@ class TestPatternOptimizer: ...@@ -73,7 +73,7 @@ class TestPatternOptimizer:
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
PatternOptimizer( PatternOptimizer(
(op1, "1", "2"), (op4, (op1, "1"), (op2, "2"), (op3, "1", "2")) (op1, "1", "2"), (op4, (op1, "1"), (op2, "2"), (op3, "1", "2"))
).optimize(g) ).rewrite(g)
assert str(g) == "FunctionGraph(Op4(Op1(x), Op2(y), Op3(x, y)))" assert str(g) == "FunctionGraph(Op4(Op1(x), Op2(y), Op3(x, y)))"
def test_unification_1(self): def test_unification_1(self):
...@@ -83,7 +83,7 @@ class TestPatternOptimizer: ...@@ -83,7 +83,7 @@ class TestPatternOptimizer:
PatternOptimizer( PatternOptimizer(
(op1, (op2, "1", "1"), "2"), # they are the same in the pattern (op1, (op2, "1", "1"), "2"), # they are the same in the pattern
(op4, "2", "1"), (op4, "2", "1"),
).optimize(g) ).rewrite(g)
# So the replacement should occur # So the replacement should occur
assert str(g) == "FunctionGraph(Op4(z, x))" assert str(g) == "FunctionGraph(Op4(z, x))"
...@@ -94,7 +94,7 @@ class TestPatternOptimizer: ...@@ -94,7 +94,7 @@ class TestPatternOptimizer:
PatternOptimizer( PatternOptimizer(
(op1, (op2, "1", "1"), "2"), # they are the same in the pattern (op1, (op2, "1", "1"), "2"), # they are the same in the pattern
(op4, "2", "1"), (op4, "2", "1"),
).optimize(g) ).rewrite(g)
# The replacement should NOT occur # The replacement should NOT occur
assert str(g) == "FunctionGraph(Op1(Op2(x, y), z))" assert str(g) == "FunctionGraph(Op1(Op2(x, y), z))"
...@@ -103,7 +103,7 @@ class TestPatternOptimizer: ...@@ -103,7 +103,7 @@ class TestPatternOptimizer:
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op2(x, y), z) e = op1(op2(x, y), z)
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op2, "1", "2"), (op1, "2", "1")).optimize(g) PatternOptimizer((op2, "1", "2"), (op1, "2", "1")).rewrite(g)
assert str(g) == "FunctionGraph(Op1(Op1(y, x), z))" assert str(g) == "FunctionGraph(Op1(Op1(y, x), z))"
def test_no_recurse(self): def test_no_recurse(self):
...@@ -113,7 +113,7 @@ class TestPatternOptimizer: ...@@ -113,7 +113,7 @@ class TestPatternOptimizer:
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op2(x, y), z) e = op1(op2(x, y), z)
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op2, "1", "2"), (op2, "2", "1"), ign=True).optimize(g) PatternOptimizer((op2, "1", "2"), (op2, "2", "1"), ign=True).rewrite(g)
assert str(g) == "FunctionGraph(Op1(Op2(y, x), z))" assert str(g) == "FunctionGraph(Op1(Op2(y, x), z))"
def test_multiple(self): def test_multiple(self):
...@@ -121,30 +121,30 @@ class TestPatternOptimizer: ...@@ -121,30 +121,30 @@ class TestPatternOptimizer:
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op2(x, y), op2(x, y), op2(y, z)) e = op1(op2(x, y), op2(x, y), op2(y, z))
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op2, "1", "2"), (op4, "1")).optimize(g) PatternOptimizer((op2, "1", "2"), (op4, "1")).rewrite(g)
assert str(g) == "FunctionGraph(Op1(Op4(x), Op4(x), Op4(y)))" assert str(g) == "FunctionGraph(Op1(Op4(x), Op4(x), Op4(y)))"
def test_nested_even(self): def test_nested_even(self):
# regardless of the order in which we optimize, this # regardless of the order in which we rewrite, this
# should work # should work
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op1(op1(op1(x)))) e = op1(op1(op1(op1(x))))
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, (op1, "1")), "1").optimize(g) PatternOptimizer((op1, (op1, "1")), "1").rewrite(g)
assert str(g) == "FunctionGraph(x)" assert str(g) == "FunctionGraph(x)"
def test_nested_odd(self): def test_nested_odd(self):
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op1(op1(op1(op1(x))))) e = op1(op1(op1(op1(op1(x)))))
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, (op1, "1")), "1").optimize(g) PatternOptimizer((op1, (op1, "1")), "1").rewrite(g)
assert str(g) == "FunctionGraph(Op1(x))" assert str(g) == "FunctionGraph(Op1(x))"
def test_expand(self): def test_expand(self):
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op1(op1(x))) e = op1(op1(op1(x)))
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, "1"), (op2, (op1, "1")), ign=True).optimize(g) PatternOptimizer((op1, "1"), (op2, (op1, "1")), ign=True).rewrite(g)
assert str(g) == "FunctionGraph(Op2(Op1(Op2(Op1(Op2(Op1(x)))))))" assert str(g) == "FunctionGraph(Op2(Op1(Op2(Op1(Op2(Op1(x)))))))"
def test_ambiguous(self): def test_ambiguous(self):
...@@ -154,7 +154,7 @@ class TestPatternOptimizer: ...@@ -154,7 +154,7 @@ class TestPatternOptimizer:
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op1(op1(op1(op1(x))))) e = op1(op1(op1(op1(op1(x)))))
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
TopoPatternOptimizer((op1, (op1, "1")), (op1, "1"), ign=False).optimize(g) TopoPatternOptimizer((op1, (op1, "1")), (op1, "1"), ign=False).rewrite(g)
assert str(g) == "FunctionGraph(Op1(x))" assert str(g) == "FunctionGraph(Op1(x))"
def test_constant(self): def test_constant(self):
...@@ -163,7 +163,7 @@ class TestPatternOptimizer: ...@@ -163,7 +163,7 @@ class TestPatternOptimizer:
z = Constant(MyType(), 2, name="z") z = Constant(MyType(), 2, name="z")
e = op1(op1(x, y), y) e = op1(op1(x, y), y)
g = FunctionGraph([y], [e]) g = FunctionGraph([y], [e])
PatternOptimizer((op1, z, "1"), (op2, "1", z)).optimize(g) PatternOptimizer((op1, z, "1"), (op2, "1", z)).rewrite(g)
assert str(g) == "FunctionGraph(Op1(Op2(y, z), y))" assert str(g) == "FunctionGraph(Op1(Op2(y, z), y))"
def test_constraints(self): def test_constraints(self):
...@@ -177,14 +177,14 @@ class TestPatternOptimizer: ...@@ -177,14 +177,14 @@ class TestPatternOptimizer:
PatternOptimizer( PatternOptimizer(
(op1, {"pattern": "1", "constraint": constraint}), (op3, "1") (op1, {"pattern": "1", "constraint": constraint}), (op3, "1")
).optimize(g) ).rewrite(g)
assert str(g) == "FunctionGraph(Op4(Op3(Op2(x, y)), Op1(Op1(x, y))))" assert str(g) == "FunctionGraph(Op4(Op3(Op2(x, y)), Op1(Op1(x, y))))"
def test_match_same(self): def test_match_same(self):
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(x, x) e = op1(x, x)
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, "x", "y"), (op3, "x", "y")).optimize(g) PatternOptimizer((op1, "x", "y"), (op3, "x", "y")).rewrite(g)
assert str(g) == "FunctionGraph(Op3(x, x))" assert str(g) == "FunctionGraph(Op3(x, x))"
@pytest.mark.xfail( @pytest.mark.xfail(
...@@ -201,7 +201,7 @@ class TestPatternOptimizer: ...@@ -201,7 +201,7 @@ class TestPatternOptimizer:
PatternOptimizer( PatternOptimizer(
{"pattern": (op1, "x", "y"), "constraint": constraint}, (op3, "x", "y") {"pattern": (op1, "x", "y"), "constraint": constraint}, (op3, "x", "y")
).optimize(g) ).rewrite(g)
assert str(g) == "FunctionGraph(Op2(Op1(x, x), Op3(x, y)))" assert str(g) == "FunctionGraph(Op2(Op1(x, x), Op3(x, y)))"
def test_allow_multiple_clients(self): def test_allow_multiple_clients(self):
...@@ -210,7 +210,7 @@ class TestPatternOptimizer: ...@@ -210,7 +210,7 @@ class TestPatternOptimizer:
# `e0` has multiple clients (i.e. the `op4` and `op3` nodes) # `e0` has multiple clients (i.e. the `op4` and `op3` nodes)
e = op3(op4(e0), e0) e = op3(op4(e0), e0)
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op4, (op1, "x", "y")), (op3, "x", "y")).optimize(g) PatternOptimizer((op4, (op1, "x", "y")), (op3, "x", "y")).rewrite(g)
assert str(g) == "FunctionGraph(Op3(Op4(*1 -> Op1(x, y)), *1))" assert str(g) == "FunctionGraph(Op3(Op4(*1 -> Op1(x, y)), *1))"
def test_eq(self): def test_eq(self):
...@@ -218,7 +218,7 @@ class TestPatternOptimizer: ...@@ -218,7 +218,7 @@ class TestPatternOptimizer:
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op_y(x, y), z) e = op1(op_y(x, y), z)
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, (op_z, "1", "2"), "3"), (op4, "3", "2")).optimize(g) PatternOptimizer((op1, (op_z, "1", "2"), "3"), (op4, "3", "2")).rewrite(g)
str_g = str(g) str_g = str(g)
assert str_g == "FunctionGraph(Op4(z, y))" assert str_g == "FunctionGraph(Op4(z, y))"
...@@ -232,14 +232,14 @@ class TestSubstitutionNodeRewriter: ...@@ -232,14 +232,14 @@ class TestSubstitutionNodeRewriter:
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op1(op1(op1(op1(x))))) e = op1(op1(op1(op1(op1(x)))))
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
KeyedSubstitutionNodeRewriter(op1, op2).optimize(g) KeyedSubstitutionNodeRewriter(op1, op2).rewrite(g)
assert str(g) == "FunctionGraph(Op2(Op2(Op2(Op2(Op2(x))))))" assert str(g) == "FunctionGraph(Op2(Op2(Op2(Op2(Op2(x))))))"
def test_straightforward_2(self): def test_straightforward_2(self):
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op2(x), op3(y), op4(z)) e = op1(op2(x), op3(y), op4(z))
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
KeyedSubstitutionNodeRewriter(op3, op4).optimize(g) KeyedSubstitutionNodeRewriter(op3, op4).rewrite(g)
assert str(g) == "FunctionGraph(Op1(Op2(x), Op4(y), Op4(z)))" assert str(g) == "FunctionGraph(Op1(Op2(x), Op4(y), Op4(z)))"
...@@ -261,7 +261,7 @@ class TestMergeOptimizer: ...@@ -261,7 +261,7 @@ class TestMergeOptimizer:
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op2(x, y), op2(x, y), op2(x, z)) e = op1(op2(x, y), op2(x, y), op2(x, z))
g = FunctionGraph([x, y, z], [e], clone=False) g = FunctionGraph([x, y, z], [e], clone=False)
MergeOptimizer().optimize(g) MergeOptimizer().rewrite(g)
out_var = g.outputs[0] out_var = g.outputs[0]
var_1, var_2, var_3 = out_var.owner.inputs var_1, var_2, var_3 = out_var.owner.inputs
assert var_1 is var_2 assert var_1 is var_2
...@@ -273,7 +273,7 @@ class TestMergeOptimizer: ...@@ -273,7 +273,7 @@ class TestMergeOptimizer:
z = Constant(MyType(), 2, name="z") z = Constant(MyType(), 2, name="z")
e = op1(op2(x, y), op2(x, y), op2(x, z)) e = op1(op2(x, y), op2(x, y), op2(x, z))
g = FunctionGraph([x, y, z], [e], clone=False) g = FunctionGraph([x, y, z], [e], clone=False)
MergeOptimizer().optimize(g) MergeOptimizer().rewrite(g)
out_var = g.outputs[0] out_var = g.outputs[0]
var_1, var_2, var_3 = out_var.owner.inputs var_1, var_2, var_3 = out_var.owner.inputs
assert var_1 is var_2 assert var_1 is var_2
...@@ -283,7 +283,7 @@ class TestMergeOptimizer: ...@@ -283,7 +283,7 @@ class TestMergeOptimizer:
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op3(op2(x, y), z), op4(op3(op2(x, y), z))) e = op1(op3(op2(x, y), z), op4(op3(op2(x, y), z)))
g = FunctionGraph([x, y, z], [e], clone=False) g = FunctionGraph([x, y, z], [e], clone=False)
MergeOptimizer().optimize(g) MergeOptimizer().rewrite(g)
out_var = g.outputs[0] out_var = g.outputs[0]
var_1, var_2 = out_var.owner.inputs var_1, var_2 = out_var.owner.inputs
assert var_2.owner.inputs[0] is var_1 assert var_2.owner.inputs[0] is var_1
...@@ -293,14 +293,14 @@ class TestMergeOptimizer: ...@@ -293,14 +293,14 @@ class TestMergeOptimizer:
e = op1(op3(op2(x, y)), op3(op2(y, x))) e = op1(op3(op2(x, y)), op3(op2(y, x)))
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
g.attach_feature(AssertNoChanges()) g.attach_feature(AssertNoChanges())
MergeOptimizer().optimize(g) MergeOptimizer().rewrite(g)
def test_merge_outputs(self): def test_merge_outputs(self):
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e1 = op3(op2(x, y)) e1 = op3(op2(x, y))
e2 = op3(op2(x, y)) e2 = op3(op2(x, y))
g = FunctionGraph([x, y, z], [e1, e2], clone=False) g = FunctionGraph([x, y, z], [e1, e2], clone=False)
MergeOptimizer().optimize(g) MergeOptimizer().rewrite(g)
assert g.outputs[0] is g.outputs[1] assert g.outputs[0] is g.outputs[1]
def test_identical_constant_args(self): def test_identical_constant_args(self):
...@@ -309,7 +309,7 @@ class TestMergeOptimizer: ...@@ -309,7 +309,7 @@ class TestMergeOptimizer:
z = Constant(MyType(), 2, name="z") z = Constant(MyType(), 2, name="z")
e1 = op1(y, z) e1 = op1(y, z)
g = FunctionGraph([x, y, z], [e1], clone=False) g = FunctionGraph([x, y, z], [e1], clone=False)
MergeOptimizer().optimize(g) MergeOptimizer().rewrite(g)
assert g.outputs[0].owner.op == op1 assert g.outputs[0].owner.op == op1
input_1 = g.outputs[0].owner.inputs[0] input_1 = g.outputs[0].owner.inputs[0]
...@@ -322,7 +322,7 @@ class TestMergeOptimizer: ...@@ -322,7 +322,7 @@ class TestMergeOptimizer:
x2 = matrix("x2") x2 = matrix("x2")
e = dot(x1, x2) + dot(assert_op(x1, (x1 > x2).all()), x2) e = dot(x1, x2) + dot(assert_op(x1, (x1 > x2).all()), x2)
g = FunctionGraph([x1, x2], [e], clone=False) g = FunctionGraph([x1, x2], [e], clone=False)
MergeOptimizer().optimize(g) MergeOptimizer().rewrite(g)
assert g.outputs[0].owner.op == add assert g.outputs[0].owner.op == add
add_inputs = g.outputs[0].owner.inputs add_inputs = g.outputs[0].owner.inputs
...@@ -342,7 +342,7 @@ class TestMergeOptimizer: ...@@ -342,7 +342,7 @@ class TestMergeOptimizer:
assert_op(x1, (x1 > x2).all()), x2 assert_op(x1, (x1 > x2).all()), x2
) )
g = FunctionGraph([x1, x2], [e], clone=False) g = FunctionGraph([x1, x2], [e], clone=False)
MergeOptimizer().optimize(g) MergeOptimizer().rewrite(g)
assert g.outputs[0].owner.op == add assert g.outputs[0].owner.op == add
add_inputs = g.outputs[0].owner.inputs add_inputs = g.outputs[0].owner.inputs
...@@ -365,7 +365,7 @@ class TestMergeOptimizer: ...@@ -365,7 +365,7 @@ class TestMergeOptimizer:
assert_op(x1, (x1 > x2).all()), x2 assert_op(x1, (x1 > x2).all()), x2
) )
g = FunctionGraph([x1, x2, x3], [e], clone=False) g = FunctionGraph([x1, x2, x3], [e], clone=False)
MergeOptimizer().optimize(g) MergeOptimizer().rewrite(g)
assert g.outputs[0].owner.op == add assert g.outputs[0].owner.op == add
add_inputs = g.outputs[0].owner.inputs add_inputs = g.outputs[0].owner.inputs
...@@ -387,7 +387,7 @@ class TestMergeOptimizer: ...@@ -387,7 +387,7 @@ class TestMergeOptimizer:
x1, assert_op(x2, (x2 > x3).all()) x1, assert_op(x2, (x2 > x3).all())
) )
g = FunctionGraph([x1, x2, x3], [e], clone=False) g = FunctionGraph([x1, x2, x3], [e], clone=False)
MergeOptimizer().optimize(g) MergeOptimizer().rewrite(g)
assert g.outputs[0].owner.op == add assert g.outputs[0].owner.op == add
add_inputs = g.outputs[0].owner.inputs add_inputs = g.outputs[0].owner.inputs
...@@ -411,7 +411,7 @@ class TestMergeOptimizer: ...@@ -411,7 +411,7 @@ class TestMergeOptimizer:
assert_op(x1, (x1 > x3).all()), x2 assert_op(x1, (x1 > x3).all()), x2
) )
g = FunctionGraph([x1, x2, x3], [e], clone=False) g = FunctionGraph([x1, x2, x3], [e], clone=False)
MergeOptimizer().optimize(g) MergeOptimizer().rewrite(g)
assert g.outputs[0].owner.op == add assert g.outputs[0].owner.op == add
add_inputs = g.outputs[0].owner.inputs add_inputs = g.outputs[0].owner.inputs
...@@ -432,7 +432,7 @@ class TestMergeOptimizer: ...@@ -432,7 +432,7 @@ class TestMergeOptimizer:
z = NoInputOp(param=1)() z = NoInputOp(param=1)()
fg = FunctionGraph([], [x, y, z], clone=False) fg = FunctionGraph([], [x, y, z], clone=False)
MergeOptimizer().optimize(fg) MergeOptimizer().rewrite(fg)
assert fg.outputs[0] is fg.outputs[1] assert fg.outputs[0] is fg.outputs[1]
assert fg.outputs[0] is not fg.outputs[2] assert fg.outputs[0] is not fg.outputs[2]
...@@ -454,7 +454,7 @@ class TestEquilibrium: ...@@ -454,7 +454,7 @@ class TestEquilibrium:
], ],
max_use_ratio=10, max_use_ratio=10,
) )
opt.optimize(g) opt.rewrite(g)
# print g # print g
assert str(g) == "FunctionGraph(Op2(x, y))" assert str(g) == "FunctionGraph(Op2(x, y))"
...@@ -473,7 +473,7 @@ class TestEquilibrium: ...@@ -473,7 +473,7 @@ class TestEquilibrium:
], ],
max_use_ratio=10, max_use_ratio=10,
) )
opt.optimize(g) opt.rewrite(g)
assert str(g) == "FunctionGraph(Op2(x, y))" assert str(g) == "FunctionGraph(Op2(x, y))"
@config.change_flags(on_opt_error="ignore") @config.change_flags(on_opt_error="ignore")
...@@ -496,7 +496,7 @@ class TestEquilibrium: ...@@ -496,7 +496,7 @@ class TestEquilibrium:
], ],
max_use_ratio=1.0 / len(g.apply_nodes), max_use_ratio=1.0 / len(g.apply_nodes),
) # each opt can only be applied once ) # each opt can only be applied once
opt.optimize(g) opt.rewrite(g)
finally: finally:
_logger.setLevel(oldlevel) _logger.setLevel(oldlevel)
# print 'after', g # print 'after', g
...@@ -612,7 +612,7 @@ def test_patternsub_values_eq_approx(out_pattern, tracks): ...@@ -612,7 +612,7 @@ def test_patternsub_values_eq_approx(out_pattern, tracks):
], ],
max_use_ratio=1, max_use_ratio=1,
) )
opt.optimize(fg) opt.rewrite(fg)
output = fg.outputs[0] output = fg.outputs[0]
if isinstance(out_pattern, tuple): if isinstance(out_pattern, tuple):
assert output.owner.op == op2 assert output.owner.op == op2
...@@ -642,7 +642,7 @@ def test_patternsub_invalid_dtype(out_pattern): ...@@ -642,7 +642,7 @@ def test_patternsub_invalid_dtype(out_pattern):
], ],
max_use_ratio=1, max_use_ratio=1,
) )
opt.optimize(fg) opt.rewrite(fg)
assert e.type.is_super(fg.outputs[0].type) assert e.type.is_super(fg.outputs[0].type)
...@@ -660,7 +660,7 @@ def test_patternsub_different_output_lengths(): ...@@ -660,7 +660,7 @@ def test_patternsub_different_output_lengths():
o = op1(e1) o = op1(e1)
fgraph = FunctionGraph(inputs=[x], outputs=[o]) fgraph = FunctionGraph(inputs=[x], outputs=[o])
opt.optimize(fgraph) opt.rewrite(fgraph)
assert fgraph.outputs[0].owner.op == op1 assert fgraph.outputs[0].owner.op == op1
......
...@@ -824,7 +824,7 @@ class TestScan: ...@@ -824,7 +824,7 @@ class TestScan:
assert scan_c is not scan_a assert scan_c is not scan_a
g = FunctionGraph([x, y, c], [2 * scan_a, 2 * scan_b, 2 * scan_c], clone=False) g = FunctionGraph([x, y, c], [2 * scan_a, 2 * scan_b, 2 * scan_c], clone=False)
MergeOptimizer().optimize(g) MergeOptimizer().rewrite(g)
scan_a_out, scan_b_out, scan_c_out = g.outputs scan_a_out, scan_b_out, scan_c_out = g.outputs
assert scan_a_out is scan_b_out assert scan_a_out is scan_b_out
......
...@@ -340,7 +340,7 @@ class TestLogSoftmax(utt.InferShapeTester): ...@@ -340,7 +340,7 @@ class TestLogSoftmax(utt.InferShapeTester):
) )
fgraph = FunctionGraph([x], [new_g]) fgraph = FunctionGraph([x], [new_g])
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
assert softmax_grad_legacy in [n.op for n in fgraph.toposort()] assert softmax_grad_legacy in [n.op for n in fgraph.toposort()]
...@@ -647,7 +647,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -647,7 +647,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
fgraph = FunctionGraph([x, one_of_n], [op(softmax_legacy(x), one_of_n)]) fgraph = FunctionGraph([x, one_of_n], [op(softmax_legacy(x), one_of_n)])
assert fgraph.outputs[0].owner.op == op assert fgraph.outputs[0].owner.op == op
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
assert fgraph.outputs[0].owner.op == crossentropy_softmax_argmax_1hot_with_bias assert fgraph.outputs[0].owner.op == crossentropy_softmax_argmax_1hot_with_bias
def test_softmax_optimizations_w_bias(self): def test_softmax_optimizations_w_bias(self):
...@@ -659,7 +659,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -659,7 +659,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
fgraph = FunctionGraph([x, b, one_of_n], [op(softmax_legacy(x + b), one_of_n)]) fgraph = FunctionGraph([x, b, one_of_n], [op(softmax_legacy(x + b), one_of_n)])
assert fgraph.outputs[0].owner.op == op assert fgraph.outputs[0].owner.op == op
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
assert len(fgraph.toposort()) == 1 assert len(fgraph.toposort()) == 1
assert fgraph.outputs[0].owner.op == crossentropy_softmax_argmax_1hot_with_bias assert fgraph.outputs[0].owner.op == crossentropy_softmax_argmax_1hot_with_bias
...@@ -676,7 +676,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -676,7 +676,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
) )
assert fgraph.outputs[0].owner.op == op assert fgraph.outputs[0].owner.op == op
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
assert len(fgraph.toposort()) == 2 assert len(fgraph.toposort()) == 2
assert fgraph.outputs[0].owner.op == crossentropy_softmax_argmax_1hot_with_bias assert fgraph.outputs[0].owner.op == crossentropy_softmax_argmax_1hot_with_bias
...@@ -694,7 +694,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -694,7 +694,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
ops_to_check=[crossentropy_softmax_1hot_with_bias_dx, softmax_legacy], ops_to_check=[crossentropy_softmax_1hot_with_bias_dx, softmax_legacy],
) )
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = {node.op for node in fgraph.toposort()} ops = {node.op for node in fgraph.toposort()}
assert crossentropy_softmax_argmax_1hot_with_bias not in ops assert crossentropy_softmax_argmax_1hot_with_bias not in ops
...@@ -717,7 +717,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -717,7 +717,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
for expr in expressions: for expr in expressions:
fgraph = FunctionGraph([x, y], [expr]) fgraph = FunctionGraph([x, y], [expr])
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()] ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 4 assert len(ops) == 4
...@@ -726,7 +726,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -726,7 +726,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
# Also verify the gradient wrt x # Also verify the gradient wrt x
fgraph = FunctionGraph([x, y], [grad(expr, x)]) fgraph = FunctionGraph([x, y], [grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()] ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 2 assert len(ops) == 2
...@@ -744,14 +744,14 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -744,14 +744,14 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
for expr in bias_expressions: for expr in bias_expressions:
fgraph = FunctionGraph([x, b, y], [expr, x]) fgraph = FunctionGraph([x, b, y], [expr, x])
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()] ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 2 # [big_op, sum] assert len(ops) == 2 # [big_op, sum]
assert crossentropy_softmax_argmax_1hot_with_bias in ops assert crossentropy_softmax_argmax_1hot_with_bias in ops
fgraph = FunctionGraph([x, b, y], [grad(expr, x)]) fgraph = FunctionGraph([x, b, y], [grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()] ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 2 assert len(ops) == 2
...@@ -770,7 +770,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -770,7 +770,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
for expr in mean_expressions: for expr in mean_expressions:
fgraph = FunctionGraph([x, y], [expr]) fgraph = FunctionGraph([x, y], [expr])
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()] ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 6 assert len(ops) == 6
...@@ -778,7 +778,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -778,7 +778,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)] assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)]
fgraph = FunctionGraph([x, y], [grad(expr, x)]) fgraph = FunctionGraph([x, y], [grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()] ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 5 assert len(ops) == 5
...@@ -798,7 +798,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -798,7 +798,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
for expr in mean_bias_expressions: for expr in mean_bias_expressions:
fgraph = FunctionGraph([x, b, y], [expr]) fgraph = FunctionGraph([x, b, y], [expr])
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()] ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 4 assert len(ops) == 4
...@@ -806,7 +806,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -806,7 +806,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)] assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)]
fgraph = FunctionGraph([x, b, y], [grad(expr, x)]) fgraph = FunctionGraph([x, b, y], [grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()] ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 5 assert len(ops) == 5
...@@ -827,7 +827,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -827,7 +827,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
for expr in expressions: for expr in expressions:
fgraph = FunctionGraph([x, y], [expr]) fgraph = FunctionGraph([x, y], [expr])
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()] ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 5 assert len(ops) == 5
...@@ -836,7 +836,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -836,7 +836,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
# Also verify the gradient wrt x # Also verify the gradient wrt x
fgraph = FunctionGraph([x, y], [grad(expr, x)]) fgraph = FunctionGraph([x, y], [grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()] ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 3 assert len(ops) == 3
...@@ -888,7 +888,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -888,7 +888,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
for expr in expressions: for expr in expressions:
fgraph = FunctionGraph([x, y, a], [expr]) fgraph = FunctionGraph([x, y, a], [expr])
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
assert 5 <= len(fgraph.toposort()) <= 10 assert 5 <= len(fgraph.toposort()) <= 10
...@@ -898,7 +898,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -898,7 +898,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
# Verify the gradient wrt x # Verify the gradient wrt x
fgraph = FunctionGraph([x, y, a], [grad(expr, x)]) fgraph = FunctionGraph([x, y, a], [grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
assert 3 <= len(fgraph.toposort()) <= 6 assert 3 <= len(fgraph.toposort()) <= 6
...@@ -911,7 +911,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -911,7 +911,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
fgraph = FunctionGraph( fgraph = FunctionGraph(
[x, y, a], [grad(expr, x, known_grads={expr: a * x.sum()})] [x, y, a], [grad(expr, x, known_grads={expr: a * x.sum()})]
) )
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
assert 6 <= len(fgraph.toposort()) <= 8 assert 6 <= len(fgraph.toposort()) <= 8
...@@ -927,7 +927,7 @@ def test_argmax_pushdown(): ...@@ -927,7 +927,7 @@ def test_argmax_pushdown():
# test that the max_and_argmax is pushed down if the max is not used # test that the max_and_argmax is pushed down if the max is not used
out = max_and_argmax(sm(exp(tanh(sigmoid(x)))), axis=-1)[1] out = max_and_argmax(sm(exp(tanh(sigmoid(x)))), axis=-1)[1]
fgraph = FunctionGraph([x], [out]) fgraph = FunctionGraph([x], [out])
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
# print 'AFTER' # print 'AFTER'
# for node in fgraph.toposort(): # for node in fgraph.toposort():
...@@ -942,7 +942,7 @@ def test_argmax_pushdown(): ...@@ -942,7 +942,7 @@ def test_argmax_pushdown():
assert hasattr(fgraph.outputs[0].tag, "trace") assert hasattr(fgraph.outputs[0].tag, "trace")
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
# print 'AFTER' # print 'AFTER'
# for node in fgraph.toposort(): # for node in fgraph.toposort():
...@@ -963,7 +963,7 @@ def test_argmax_pushdown_bias(): ...@@ -963,7 +963,7 @@ def test_argmax_pushdown_bias():
out = argmax(softmax_with_bias(x, b), axis=-1) out = argmax(softmax_with_bias(x, b), axis=-1)
fgraph = FunctionGraph([x, b], [out]) fgraph = FunctionGraph([x, b], [out])
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
types_to_check = (DimShuffle, Elemwise, Argmax) types_to_check = (DimShuffle, Elemwise, Argmax)
assert len(fgraph.toposort()) == 3 assert len(fgraph.toposort()) == 3
...@@ -977,7 +977,7 @@ def test_argmax_pushdown_bias(): ...@@ -977,7 +977,7 @@ def test_argmax_pushdown_bias():
out = max_and_argmax(softmax_with_bias(x, b), axis=-1)[0] out = max_and_argmax(softmax_with_bias(x, b), axis=-1)[0]
fgraph = FunctionGraph([x, b], [out]) fgraph = FunctionGraph([x, b], [out])
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
assert len(fgraph.toposort()) == 2 assert len(fgraph.toposort()) == 2
assert isinstance(fgraph.toposort()[0].op, SoftmaxWithBias) assert isinstance(fgraph.toposort()[0].op, SoftmaxWithBias)
......
...@@ -159,11 +159,11 @@ def ds(x, y): ...@@ -159,11 +159,11 @@ def ds(x, y):
def optimize(g, level="fast_run"): def optimize(g, level="fast_run"):
if level == "fast_run": if level == "fast_run":
_optimizer_fast_run.optimize(g) _optimizer_fast_run.rewrite(g)
elif level == "specialize": elif level == "specialize":
_optimizer_specialize.optimize(g) _optimizer_specialize.rewrite(g)
elif level == "stabilize": elif level == "stabilize":
_optimizer_stabilize.optimize(g) _optimizer_stabilize.rewrite(g)
else: else:
raise ValueError(level) raise ValueError(level)
return g return g
...@@ -184,7 +184,7 @@ class TestDimshuffleLift: ...@@ -184,7 +184,7 @@ class TestDimshuffleLift:
assert ( assert (
str(g) == "FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{1,0}(x)))" str(g) == "FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{1,0}(x)))"
) )
dimshuffle_lift.optimize(g) dimshuffle_lift.rewrite(g)
assert str(g) == "FunctionGraph(x)" assert str(g) == "FunctionGraph(x)"
# no need to check_stack_trace as graph is supposed to be empty # no need to check_stack_trace as graph is supposed to be empty
...@@ -196,7 +196,7 @@ class TestDimshuffleLift: ...@@ -196,7 +196,7 @@ class TestDimshuffleLift:
str(g) str(g)
== "FunctionGraph(InplaceDimShuffle{2,0,x,1}(InplaceDimShuffle{1,x,0}(x)))" == "FunctionGraph(InplaceDimShuffle{2,0,x,1}(InplaceDimShuffle{1,x,0}(x)))"
), str(g) ), str(g)
dimshuffle_lift.optimize(g) dimshuffle_lift.rewrite(g)
assert str(g) == "FunctionGraph(InplaceDimShuffle{0,1,x,x}(x))", str(g) assert str(g) == "FunctionGraph(InplaceDimShuffle{0,1,x,x}(x))", str(g)
# Check stacktrace was copied over correctly after opt was applied # Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(g, ops_to_check="all") assert check_stack_trace(g, ops_to_check="all")
...@@ -209,7 +209,7 @@ class TestDimshuffleLift: ...@@ -209,7 +209,7 @@ class TestDimshuffleLift:
"FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{2,0,x,1}" "FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{2,0,x,1}"
"(InplaceDimShuffle{0,x,1}(x))))" "(InplaceDimShuffle{0,x,1}(x))))"
), str(g) ), str(g)
dimshuffle_lift.optimize(g) dimshuffle_lift.rewrite(g)
assert str(g) == "FunctionGraph(x)", str(g) assert str(g) == "FunctionGraph(x)", str(g)
# no need to check_stack_trace as graph is supposed to be empty # no need to check_stack_trace as graph is supposed to be empty
...@@ -238,7 +238,7 @@ class TestDimshuffleLift: ...@@ -238,7 +238,7 @@ class TestDimshuffleLift:
"FunctionGraph(Elemwise{add,no_inplace}(Elemwise{add,no_inplace}" "FunctionGraph(Elemwise{add,no_inplace}(Elemwise{add,no_inplace}"
"(DimShuffle{x,x,0}(x), DimShuffle{x,0,1}(y)), z))" "(DimShuffle{x,x,0}(x), DimShuffle{x,0,1}(y)), z))"
) )
dimshuffle_lift.optimize(g) dimshuffle_lift.rewrite(g)
assert str(g) in (opt_str_g_inplace, opt_str_g_noinplace), str(g) assert str(g) in (opt_str_g_inplace, opt_str_g_noinplace), str(g)
# Check stacktrace was copied over correctly after opt was applied # Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(g, ops_to_check="all") assert check_stack_trace(g, ops_to_check="all")
...@@ -277,7 +277,7 @@ class TestDimshuffleLift: ...@@ -277,7 +277,7 @@ class TestDimshuffleLift:
e = ds(x, (0, 1)) e = ds(x, (0, 1))
g = FunctionGraph([x], [e]) g = FunctionGraph([x], [e])
assert str(g) == "FunctionGraph(InplaceDimShuffle{0,1}(x))" assert str(g) == "FunctionGraph(InplaceDimShuffle{0,1}(x))"
dimshuffle_lift.optimize(g) dimshuffle_lift.rewrite(g)
assert str(g) == "FunctionGraph(x)" assert str(g) == "FunctionGraph(x)"
# Check stacktrace was copied over correctly after opt was applied # Check stacktrace was copied over correctly after opt was applied
assert hasattr(g.outputs[0].tag, "trace") assert hasattr(g.outputs[0].tag, "trace")
...@@ -294,7 +294,7 @@ class TestDimshuffleLift: ...@@ -294,7 +294,7 @@ class TestDimshuffleLift:
str(g) str(g)
== "FunctionGraph(InplaceDimShuffle{0,x}(x), InplaceDimShuffle{2,1,0}(y), InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))" == "FunctionGraph(InplaceDimShuffle{0,x}(x), InplaceDimShuffle{2,1,0}(y), InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))"
) )
dimshuffle_lift.optimize(g) dimshuffle_lift.rewrite(g)
assert ( assert (
str(g) str(g)
== "FunctionGraph(x, y, InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))" == "FunctionGraph(x, y, InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))"
...@@ -331,7 +331,7 @@ def test_local_useless_dimshuffle_in_reshape(): ...@@ -331,7 +331,7 @@ def test_local_useless_dimshuffle_in_reshape():
"Reshape{2}(InplaceDimShuffle{0}(col), Shape(col)))" "Reshape{2}(InplaceDimShuffle{0}(col), Shape(col)))"
) )
useless_dimshuffle_in_reshape = out2in(local_useless_dimshuffle_in_reshape) useless_dimshuffle_in_reshape = out2in(local_useless_dimshuffle_in_reshape)
useless_dimshuffle_in_reshape.optimize(g) useless_dimshuffle_in_reshape.rewrite(g)
assert str(g) == ( assert str(g) == (
"FunctionGraph(Reshape{1}(vector, Shape(vector)), " "FunctionGraph(Reshape{1}(vector, Shape(vector)), "
"Reshape{2}(mat, Shape(mat)), " "Reshape{2}(mat, Shape(mat)), "
...@@ -347,7 +347,7 @@ def test_local_useless_dimshuffle_in_reshape(): ...@@ -347,7 +347,7 @@ def test_local_useless_dimshuffle_in_reshape():
reshape_dimshuffle_mat2 = reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape) reshape_dimshuffle_mat2 = reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape)
h = FunctionGraph([mat], [reshape_dimshuffle_mat2]) h = FunctionGraph([mat], [reshape_dimshuffle_mat2])
str_h = str(h) str_h = str(h)
useless_dimshuffle_in_reshape.optimize(h) useless_dimshuffle_in_reshape.rewrite(h)
assert str(h) == str_h assert str(h) == str_h
...@@ -1505,7 +1505,7 @@ class TestLocalCanonicalizeAlloc: ...@@ -1505,7 +1505,7 @@ class TestLocalCanonicalizeAlloc:
assert any(isinstance(node.op, Alloc) for node in g.toposort()) assert any(isinstance(node.op, Alloc) for node in g.toposort())
alloc_lift = out2in(local_alloc_sink_dimshuffle) alloc_lift = out2in(local_alloc_sink_dimshuffle)
alloc_lift.optimize(g) alloc_lift.rewrite(g)
if has_alloc: if has_alloc:
assert any(isinstance(node.op, Alloc) for node in g.toposort()) assert any(isinstance(node.op, Alloc) for node in g.toposort())
...@@ -2849,8 +2849,8 @@ class TestLocalReshapeToDimshuffle: ...@@ -2849,8 +2849,8 @@ class TestLocalReshapeToDimshuffle:
"TensorConstant{[1 5 1 6 1 1]}))" "TensorConstant{[1 5 1 6 1 1]}))"
) )
reshape_lift.optimize(g) reshape_lift.rewrite(g)
useless_reshape.optimize(g) useless_reshape.rewrite(g)
assert str(g) == ( assert str(g) == (
"FunctionGraph(InplaceDimShuffle{x,0}" "FunctionGraph(InplaceDimShuffle{x,0}"
"(<TensorType(float64, (None,))>), " "(<TensorType(float64, (None,))>), "
...@@ -2880,9 +2880,9 @@ def test_local_reshape_lift(): ...@@ -2880,9 +2880,9 @@ def test_local_reshape_lift():
class TestLiftTransposeThroughDot: class TestLiftTransposeThroughDot:
def simple_optimize(self, g): def simple_optimize(self, g):
out2in(local_useless_elemwise).optimize(g) out2in(local_useless_elemwise).rewrite(g)
out2in(local_lift_transpose_through_dot).optimize(g) out2in(local_lift_transpose_through_dot).rewrite(g)
out2in(local_useless_elemwise).optimize(g) out2in(local_useless_elemwise).rewrite(g)
return g return g
def test_matrix_matrix(self): def test_matrix_matrix(self):
...@@ -3159,9 +3159,9 @@ def test_local_useless_alloc(): ...@@ -3159,9 +3159,9 @@ def test_local_useless_alloc():
output = at.alloc(at.alloc(m, 1, y, 1, 1), x, y, z, w) output = at.alloc(at.alloc(m, 1, y, 1, 1), x, y, z, w)
g = FunctionGraph([m, x, y, z, w], [output]) g = FunctionGraph([m, x, y, z, w], [output])
useless_alloc.optimize(g) useless_alloc.rewrite(g)
merge_alloc.optimize(g) merge_alloc.rewrite(g)
useless_alloc.optimize(g) useless_alloc.rewrite(g)
topo = g.toposort() topo = g.toposort()
assert len(topo) == 1 assert len(topo) == 1
...@@ -3172,9 +3172,9 @@ def test_local_useless_alloc(): ...@@ -3172,9 +3172,9 @@ def test_local_useless_alloc():
output = at.alloc(at.alloc(m, y, 1, 1), x, y, z, w) output = at.alloc(at.alloc(m, y, 1, 1), x, y, z, w)
g = FunctionGraph([m, x, y, z, w], [output]) g = FunctionGraph([m, x, y, z, w], [output])
useless_alloc.optimize(g) useless_alloc.rewrite(g)
merge_alloc.optimize(g) merge_alloc.rewrite(g)
useless_alloc.optimize(g) useless_alloc.rewrite(g)
topo = g.toposort() topo = g.toposort()
assert len(topo) == 1 assert len(topo) == 1
...@@ -3186,9 +3186,9 @@ def test_local_useless_alloc(): ...@@ -3186,9 +3186,9 @@ def test_local_useless_alloc():
output = at.alloc(at.alloc(m, y, 1, 1), x, y2, z, w) output = at.alloc(at.alloc(m, y, 1, 1), x, y2, z, w)
g = FunctionGraph([m, x, y, y2, z, w], [output]) g = FunctionGraph([m, x, y, y2, z, w], [output])
useless_alloc.optimize(g) useless_alloc.rewrite(g)
merge_alloc.optimize(g) merge_alloc.rewrite(g)
useless_alloc.optimize(g) useless_alloc.rewrite(g)
topo = g.toposort() topo = g.toposort()
assert len(topo) == 3 assert len(topo) == 3
......
...@@ -150,11 +150,11 @@ def ds(x, y): ...@@ -150,11 +150,11 @@ def ds(x, y):
def optimize(g, level="fast_run"): def optimize(g, level="fast_run"):
if level == "fast_run": if level == "fast_run":
_optimizer_fast_run.optimize(g) _optimizer_fast_run.rewrite(g)
elif level == "specialize": elif level == "specialize":
_optimizer_specialize.optimize(g) _optimizer_specialize.rewrite(g)
elif level == "stabilize": elif level == "stabilize":
_optimizer_stabilize.optimize(g) _optimizer_stabilize.rewrite(g)
else: else:
raise ValueError(level) raise ValueError(level)
return g return g
...@@ -189,19 +189,19 @@ class TestGreedyDistribute: ...@@ -189,19 +189,19 @@ class TestGreedyDistribute:
# 1. ((a/x + b/y) * x * y) --> a*y + b*x # 1. ((a/x + b/y) * x * y) --> a*y + b*x
e = (a / z + b / x) * x * z e = (a / z + b / x) * x * z
g = FunctionGraph([a, b, c, d, x, y, z], [e]) g = FunctionGraph([a, b, c, d, x, y, z], [e])
mul_canonizer.optimize(g) mul_canonizer.rewrite(g)
WalkingGraphRewriter( WalkingGraphRewriter(
SequentialNodeRewriter(local_greedy_distributor), order="out_to_in" SequentialNodeRewriter(local_greedy_distributor), order="out_to_in"
).optimize(g) ).rewrite(g)
assert str(pprint(g.outputs[0])) == "((a * x) + (b * z))" assert str(pprint(g.outputs[0])) == "((a * x) + (b * z))"
# 2. ((a/x + b) * x) --> a + b*x # 2. ((a/x + b) * x) --> a + b*x
e = (a / x + b) * x e = (a / x + b) * x
g = FunctionGraph([a, b, x], [e]) g = FunctionGraph([a, b, x], [e])
mul_canonizer.optimize(g) mul_canonizer.rewrite(g)
WalkingGraphRewriter( WalkingGraphRewriter(
SequentialNodeRewriter(local_greedy_distributor), order="out_to_in" SequentialNodeRewriter(local_greedy_distributor), order="out_to_in"
).optimize(g) ).rewrite(g)
assert str(pprint(g.outputs[0])) == "(a + (b * x))" assert str(pprint(g.outputs[0])) == "(a + (b * x))"
def test_kording_bug(self): def test_kording_bug(self):
...@@ -3054,7 +3054,7 @@ class TestLocalErfc: ...@@ -3054,7 +3054,7 @@ class TestLocalErfc:
WalkingGraphRewriter( WalkingGraphRewriter(
SequentialNodeRewriter(local_grad_log_erfc_neg), order="out_to_in" SequentialNodeRewriter(local_grad_log_erfc_neg), order="out_to_in"
).optimize(fg) ).rewrite(fg)
# Make sure that the graph hasn't been changed # Make sure that the graph hasn't been changed
assert fg.outputs[0] is no_match assert fg.outputs[0] is no_match
......
...@@ -72,7 +72,7 @@ def test_merge_with_weird_eq(): ...@@ -72,7 +72,7 @@ def test_merge_with_weird_eq():
x = at.constant(np.asarray(1), name="x") x = at.constant(np.asarray(1), name="x")
y = at.constant(np.asarray(1), name="y") y = at.constant(np.asarray(1), name="y")
g = FunctionGraph([x, y], [x + y]) g = FunctionGraph([x, y], [x + y])
MergeOptimizer().optimize(g) MergeOptimizer().rewrite(g)
assert len(g.apply_nodes) == 1 assert len(g.apply_nodes) == 1
node = list(g.apply_nodes)[0] node = list(g.apply_nodes)[0]
...@@ -84,7 +84,7 @@ def test_merge_with_weird_eq(): ...@@ -84,7 +84,7 @@ def test_merge_with_weird_eq():
x = at.constant(np.ones(5), name="x") x = at.constant(np.ones(5), name="x")
y = at.constant(np.ones(5), name="y") y = at.constant(np.ones(5), name="y")
g = FunctionGraph([x, y], [x + y]) g = FunctionGraph([x, y], [x + y])
MergeOptimizer().optimize(g) MergeOptimizer().rewrite(g)
assert len(g.apply_nodes) == 1 assert len(g.apply_nodes) == 1
node = list(g.apply_nodes)[0] node = list(g.apply_nodes)[0]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论