提交 45f48ae6 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename GraphRewriter.optimize to GraphRewriter.rewrite

上级 8db93fe1
...@@ -92,16 +92,16 @@ class Rewriter(abc.ABC): ...@@ -92,16 +92,16 @@ class Rewriter(abc.ABC):
class GraphRewriter(Rewriter): class GraphRewriter(Rewriter):
"""A optimizer that can be applied to a `FunctionGraph` in order to transform it. """A rewriter that can be applied to a `FunctionGraph` in order to transform it.
It can represent an optimization or, in general, any kind of transformation This class represents a generalized rewrite that includes the way a graph
one could apply to a `FunctionGraph`. is traversed and/or changed as a whole.
""" """
@abc.abstractmethod @abc.abstractmethod
def apply(self, fgraph): def apply(self, fgraph):
"""Apply the optimization to a `FunctionGraph`. """Apply the rewriter to a `FunctionGraph`.
It may use all the methods defined by the `FunctionGraph`. If the It may use all the methods defined by the `FunctionGraph`. If the
`GraphRewriter` needs to use a certain tool, such as an `GraphRewriter` needs to use a certain tool, such as an
...@@ -110,26 +110,29 @@ class GraphRewriter(Rewriter): ...@@ -110,26 +110,29 @@ class GraphRewriter(Rewriter):
""" """
raise NotImplementedError() raise NotImplementedError()
def optimize(self, fgraph, *args, **kwargs): def optimize(self, *args, **kwargs):
warnings.warn(
"`GraphRewriter.optimize` is deprecated; use `GraphRewriter.rewrite` instead.",
DeprecationWarning,
stacklevel=2,
)
self.rewrite(*args, **kwargs)
def rewrite(self, fgraph, *args, **kwargs):
""" """
This is meant as a shortcut for the following:: This is meant as a shortcut for the following::
opt.add_requirements(fgraph) self.add_requirements(fgraph)
opt.apply(fgraph) self.apply(fgraph)
""" """
self.add_requirements(fgraph) self.add_requirements(fgraph)
ret = self.apply(fgraph, *args, **kwargs) return self.apply(fgraph, *args, **kwargs)
return ret
def __call__(self, fgraph): def __call__(self, fgraph):
"""Optimize a `FunctionGraph`. """Rewrite a `FunctionGraph`."""
return self.rewrite(fgraph)
This is the same as ``self.optimize(fgraph)``.
"""
return self.optimize(fgraph)
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
... ...
...@@ -141,12 +144,12 @@ class GraphRewriter(Rewriter): ...@@ -141,12 +144,12 @@ class GraphRewriter(Rewriter):
file=stream, file=stream,
) )
@staticmethod @classmethod
def print_profile(stream, prof, level=0): def print_profile(cls, stream, prof, level=0):
if prof is not None: if prof is not None:
raise NotImplementedError( raise NotImplementedError(
"The function print_profile must be overridden if the" "The function `print_profile` must be overridden when the"
" optimizer return profiling information." " rewriter returns profiling information."
) )
......
...@@ -44,10 +44,10 @@ def optimize_graph( ...@@ -44,10 +44,10 @@ def optimize_graph(
return_only_out = True return_only_out = True
canonicalize_opt = optdb.query(RewriteDatabaseQuery(include=include, **kwargs)) canonicalize_opt = optdb.query(RewriteDatabaseQuery(include=include, **kwargs))
_ = canonicalize_opt.optimize(fgraph) _ = canonicalize_opt.rewrite(fgraph)
if custom_opt: if custom_opt:
custom_opt.optimize(fgraph) custom_opt.rewrite(fgraph)
if return_only_out: if return_only_out:
return fgraph.outputs[0] return fgraph.outputs[0]
...@@ -79,7 +79,7 @@ def is_same_graph_with_merge(var1, var2, givens=None): ...@@ -79,7 +79,7 @@ def is_same_graph_with_merge(var1, var2, givens=None):
for to_replace, replace_by in givens.items(): for to_replace, replace_by in givens.items():
fgraph.replace(to_replace, replace_by) fgraph.replace(to_replace, replace_by)
# Perform merge optimization. # Perform merge optimization.
MergeOptimizer().optimize(fgraph) MergeOptimizer().rewrite(fgraph)
# When two variables perform the same computations, they will have the same # When two variables perform the same computations, they will have the same
# owner in the optimized graph. # owner in the optimized graph.
# We need to be careful with the special case where the owner is None, # We need to be careful with the special case where the owner is None,
......
...@@ -4152,7 +4152,7 @@ class Composite(ScalarOp): ...@@ -4152,7 +4152,7 @@ class Composite(ScalarOp):
# the fgraph to be set to the variable as we need to pickle # the fgraph to be set to the variable as we need to pickle
# them for the cache of c module to work. # them for the cache of c module to work.
fgraph = FunctionGraph(self.inputs, self.outputs) fgraph = FunctionGraph(self.inputs, self.outputs)
MergeOptimizer().optimize(fgraph) MergeOptimizer().rewrite(fgraph)
for node in fgraph.apply_nodes: for node in fgraph.apply_nodes:
if not isinstance(node.op, ScalarOp): if not isinstance(node.op, ScalarOp):
raise ValueError( raise ValueError(
......
...@@ -56,7 +56,7 @@ Graph Rewriting ...@@ -56,7 +56,7 @@ Graph Rewriting
<libdoc_graph_fgraphfeature>` to it. These features are "plugins" that are needed <libdoc_graph_fgraphfeature>` to it. These features are "plugins" that are needed
for the :meth:`GraphRewriter.apply` method to do its job properly. for the :meth:`GraphRewriter.apply` method to do its job properly.
.. method:: optimize(fgraph) .. method:: rewrite(fgraph)
This is the interface function called by Aesara. It calls This is the interface function called by Aesara. It calls
:meth:`GraphRewriter.apply` by default. :meth:`GraphRewriter.apply` by default.
...@@ -159,7 +159,7 @@ Now, we test the optimization: ...@@ -159,7 +159,7 @@ Now, we test the optimization:
>>> e = aesara.graph.fg.FunctionGraph([x, y, z], [a]) >>> e = aesara.graph.fg.FunctionGraph([x, y, z], [a])
>>> e >>> e
FunctionGraph(add(z, mul(true_div(mul(y, x), y), true_div(z, x)))) FunctionGraph(add(z, mul(true_div(mul(y, x), y), true_div(z, x))))
>>> simplify.optimize(e) >>> simplify.rewrite(e)
>>> e >>> e
FunctionGraph(add(z, mul(x, true_div(z, x)))) FunctionGraph(add(z, mul(x, true_div(z, x))))
...@@ -175,7 +175,7 @@ optimization you wrote. For example, consider the following: ...@@ -175,7 +175,7 @@ optimization you wrote. For example, consider the following:
>>> e = aesara.graph.fg.FunctionGraph([x, y, z], [a]) >>> e = aesara.graph.fg.FunctionGraph([x, y, z], [a])
>>> e >>> e
FunctionGraph(true_div(mul(add(y, z), x), add(y, z))) FunctionGraph(true_div(mul(add(y, z), x), add(y, z)))
>>> simplify.optimize(e) >>> simplify.rewrite(e)
>>> e >>> e
FunctionGraph(true_div(mul(add(y, z), x), add(y, z))) FunctionGraph(true_div(mul(add(y, z), x), add(y, z)))
...@@ -186,11 +186,11 @@ computation, using the :class:`MergeOptimizer` defined in ...@@ -186,11 +186,11 @@ computation, using the :class:`MergeOptimizer` defined in
:mod:`aesara.graph.opt`. :mod:`aesara.graph.opt`.
>>> from aesara.graph.opt import MergeOptimizer >>> from aesara.graph.opt import MergeOptimizer
>>> MergeOptimizer().optimize(e) # doctest: +ELLIPSIS >>> MergeOptimizer().rewrite(e) # doctest: +ELLIPSIS
(0, ..., None, None, {}, 1, 0) (0, ..., None, None, {}, 1, 0)
>>> e >>> e
FunctionGraph(true_div(mul(*1 -> add(y, z), x), *1)) FunctionGraph(true_div(mul(*1 -> add(y, z), x), *1))
>>> simplify.optimize(e) >>> simplify.rewrite(e)
>>> e >>> e
FunctionGraph(x) FunctionGraph(x)
...@@ -265,7 +265,7 @@ subset of them) and applies one or several local optimizers. ...@@ -265,7 +265,7 @@ subset of them) and applies one or several local optimizers.
>>> e >>> e
FunctionGraph(add(z, mul(true_div(mul(y, x), y), true_div(z, x)))) FunctionGraph(add(z, mul(true_div(mul(y, x), y), true_div(z, x))))
>>> simplify = aesara.graph.opt.WalkingGraphRewriter(local_simplify) >>> simplify = aesara.graph.opt.WalkingGraphRewriter(local_simplify)
>>> simplify.optimize(e) >>> simplify.rewrite(e)
(<aesara.graph.opt.WalkingGraphRewriter object at 0x...>, 1, 5, 3, ..., ..., ...) (<aesara.graph.opt.WalkingGraphRewriter object at 0x...>, 1, 5, 3, ..., ..., ...)
>>> e >>> e
FunctionGraph(add(z, mul(x, true_div(z, x)))) FunctionGraph(add(z, mul(x, true_div(z, x))))
......
...@@ -151,7 +151,7 @@ def test_misc(): ...@@ -151,7 +151,7 @@ def test_misc():
e = transpose_view(transpose_view(transpose_view(transpose_view(x)))) e = transpose_view(transpose_view(transpose_view(transpose_view(x))))
g = create_fgraph([x, y, z], [e]) g = create_fgraph([x, y, z], [e])
assert g.consistent() assert g.consistent()
PatternOptimizer((transpose_view, (transpose_view, "x")), "x").optimize(g) PatternOptimizer((transpose_view, (transpose_view, "x")), "x").rewrite(g)
assert str(g) == "FunctionGraph(x)" assert str(g) == "FunctionGraph(x)"
new_e = add(x, y) new_e = add(x, y)
g.replace_validate(x, new_e) g.replace_validate(x, new_e)
...@@ -330,7 +330,7 @@ def test_long_destroyers_loop(): ...@@ -330,7 +330,7 @@ def test_long_destroyers_loop():
e = dot(dot(add_in_place(x, y), add_in_place(y, z)), add(z, x)) e = dot(dot(add_in_place(x, y), add_in_place(y, z)), add(z, x))
g = create_fgraph([x, y, z], [e]) g = create_fgraph([x, y, z], [e])
assert g.consistent() assert g.consistent()
TopoSubstitutionNodeRewriter(add, add_in_place).optimize(g) TopoSubstitutionNodeRewriter(add, add_in_place).rewrite(g)
assert g.consistent() assert g.consistent()
# we don't want to see that! # we don't want to see that!
assert ( assert (
...@@ -366,7 +366,7 @@ def test_multi_destroyers_through_views(): ...@@ -366,7 +366,7 @@ def test_multi_destroyers_through_views():
g = create_fgraph([x, y, z], [e]) g = create_fgraph([x, y, z], [e])
assert g.consistent() assert g.consistent()
fail = FailureWatch() fail = FailureWatch()
TopoSubstitutionNodeRewriter(add, add_in_place, fail).optimize(g) TopoSubstitutionNodeRewriter(add, add_in_place, fail).rewrite(g)
assert g.consistent() assert g.consistent()
assert fail.failures == 1 # should have succeeded once and failed once assert fail.failures == 1 # should have succeeded once and failed once
...@@ -388,7 +388,7 @@ def test_usage_loop(): ...@@ -388,7 +388,7 @@ def test_usage_loop():
g = create_fgraph([x, y, z], [dot(add_in_place(x, z), x)], False) g = create_fgraph([x, y, z], [dot(add_in_place(x, z), x)], False)
assert not g.consistent() assert not g.consistent()
# replace add_in_place with add # replace add_in_place with add
TopoSubstitutionNodeRewriter(add_in_place, add).optimize(g) TopoSubstitutionNodeRewriter(add_in_place, add).rewrite(g)
assert g.consistent() assert g.consistent()
...@@ -409,7 +409,7 @@ def test_usage_loop_insert_views(): ...@@ -409,7 +409,7 @@ def test_usage_loop_insert_views():
g = create_fgraph([x, y, z], [e]) g = create_fgraph([x, y, z], [e])
assert g.consistent() assert g.consistent()
fail = FailureWatch() fail = FailureWatch()
TopoSubstitutionNodeRewriter(sigmoid, transpose_view, fail).optimize(g) TopoSubstitutionNodeRewriter(sigmoid, transpose_view, fail).rewrite(g)
assert g.consistent() assert g.consistent()
# it must keep one sigmoid in the long sigmoid chain # it must keep one sigmoid in the long sigmoid chain
assert fail.failures == 1 assert fail.failures == 1
...@@ -454,19 +454,19 @@ def test_multiple_inplace(): ...@@ -454,19 +454,19 @@ def test_multiple_inplace():
# try to work in-place on x/0 and y/1 (this should fail) # try to work in-place on x/0 and y/1 (this should fail)
fail = FailureWatch() fail = FailureWatch()
TopoSubstitutionNodeRewriter(multiple, multiple_in_place_0_1, fail).optimize(g) TopoSubstitutionNodeRewriter(multiple, multiple_in_place_0_1, fail).rewrite(g)
assert g.consistent() assert g.consistent()
assert fail.failures == 1 assert fail.failures == 1
# try to work in-place on x/0 (this should fail) # try to work in-place on x/0 (this should fail)
fail = FailureWatch() fail = FailureWatch()
TopoSubstitutionNodeRewriter(multiple, multiple_in_place_0, fail).optimize(g) TopoSubstitutionNodeRewriter(multiple, multiple_in_place_0, fail).rewrite(g)
assert g.consistent() assert g.consistent()
assert fail.failures == 1 assert fail.failures == 1
# try to work in-place on y/1 (this should succeed) # try to work in-place on y/1 (this should succeed)
fail = FailureWatch() fail = FailureWatch()
TopoSubstitutionNodeRewriter(multiple, multiple_in_place_1, fail).optimize(g) TopoSubstitutionNodeRewriter(multiple, multiple_in_place_1, fail).rewrite(g)
assert g.consistent() assert g.consistent()
assert fail.failures == 0 assert fail.failures == 0
...@@ -474,6 +474,6 @@ def test_multiple_inplace(): ...@@ -474,6 +474,6 @@ def test_multiple_inplace():
fail = FailureWatch() fail = FailureWatch()
TopoSubstitutionNodeRewriter( TopoSubstitutionNodeRewriter(
multiple_in_place_1, multiple_in_place_0_1, fail multiple_in_place_1, multiple_in_place_0_1, fail
).optimize(g) ).rewrite(g)
assert g.consistent() assert g.consistent()
assert fail.failures == 1 assert fail.failures == 1
差异被折叠。
...@@ -824,7 +824,7 @@ class TestScan: ...@@ -824,7 +824,7 @@ class TestScan:
assert scan_c is not scan_a assert scan_c is not scan_a
g = FunctionGraph([x, y, c], [2 * scan_a, 2 * scan_b, 2 * scan_c], clone=False) g = FunctionGraph([x, y, c], [2 * scan_a, 2 * scan_b, 2 * scan_c], clone=False)
MergeOptimizer().optimize(g) MergeOptimizer().rewrite(g)
scan_a_out, scan_b_out, scan_c_out = g.outputs scan_a_out, scan_b_out, scan_c_out = g.outputs
assert scan_a_out is scan_b_out assert scan_a_out is scan_b_out
......
...@@ -340,7 +340,7 @@ class TestLogSoftmax(utt.InferShapeTester): ...@@ -340,7 +340,7 @@ class TestLogSoftmax(utt.InferShapeTester):
) )
fgraph = FunctionGraph([x], [new_g]) fgraph = FunctionGraph([x], [new_g])
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
assert softmax_grad_legacy in [n.op for n in fgraph.toposort()] assert softmax_grad_legacy in [n.op for n in fgraph.toposort()]
...@@ -647,7 +647,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -647,7 +647,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
fgraph = FunctionGraph([x, one_of_n], [op(softmax_legacy(x), one_of_n)]) fgraph = FunctionGraph([x, one_of_n], [op(softmax_legacy(x), one_of_n)])
assert fgraph.outputs[0].owner.op == op assert fgraph.outputs[0].owner.op == op
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
assert fgraph.outputs[0].owner.op == crossentropy_softmax_argmax_1hot_with_bias assert fgraph.outputs[0].owner.op == crossentropy_softmax_argmax_1hot_with_bias
def test_softmax_optimizations_w_bias(self): def test_softmax_optimizations_w_bias(self):
...@@ -659,7 +659,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -659,7 +659,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
fgraph = FunctionGraph([x, b, one_of_n], [op(softmax_legacy(x + b), one_of_n)]) fgraph = FunctionGraph([x, b, one_of_n], [op(softmax_legacy(x + b), one_of_n)])
assert fgraph.outputs[0].owner.op == op assert fgraph.outputs[0].owner.op == op
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
assert len(fgraph.toposort()) == 1 assert len(fgraph.toposort()) == 1
assert fgraph.outputs[0].owner.op == crossentropy_softmax_argmax_1hot_with_bias assert fgraph.outputs[0].owner.op == crossentropy_softmax_argmax_1hot_with_bias
...@@ -676,7 +676,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -676,7 +676,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
) )
assert fgraph.outputs[0].owner.op == op assert fgraph.outputs[0].owner.op == op
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
assert len(fgraph.toposort()) == 2 assert len(fgraph.toposort()) == 2
assert fgraph.outputs[0].owner.op == crossentropy_softmax_argmax_1hot_with_bias assert fgraph.outputs[0].owner.op == crossentropy_softmax_argmax_1hot_with_bias
...@@ -694,7 +694,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -694,7 +694,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
ops_to_check=[crossentropy_softmax_1hot_with_bias_dx, softmax_legacy], ops_to_check=[crossentropy_softmax_1hot_with_bias_dx, softmax_legacy],
) )
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = {node.op for node in fgraph.toposort()} ops = {node.op for node in fgraph.toposort()}
assert crossentropy_softmax_argmax_1hot_with_bias not in ops assert crossentropy_softmax_argmax_1hot_with_bias not in ops
...@@ -717,7 +717,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -717,7 +717,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
for expr in expressions: for expr in expressions:
fgraph = FunctionGraph([x, y], [expr]) fgraph = FunctionGraph([x, y], [expr])
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()] ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 4 assert len(ops) == 4
...@@ -726,7 +726,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -726,7 +726,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
# Also verify the gradient wrt x # Also verify the gradient wrt x
fgraph = FunctionGraph([x, y], [grad(expr, x)]) fgraph = FunctionGraph([x, y], [grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()] ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 2 assert len(ops) == 2
...@@ -744,14 +744,14 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -744,14 +744,14 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
for expr in bias_expressions: for expr in bias_expressions:
fgraph = FunctionGraph([x, b, y], [expr, x]) fgraph = FunctionGraph([x, b, y], [expr, x])
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()] ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 2 # [big_op, sum] assert len(ops) == 2 # [big_op, sum]
assert crossentropy_softmax_argmax_1hot_with_bias in ops assert crossentropy_softmax_argmax_1hot_with_bias in ops
fgraph = FunctionGraph([x, b, y], [grad(expr, x)]) fgraph = FunctionGraph([x, b, y], [grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()] ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 2 assert len(ops) == 2
...@@ -770,7 +770,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -770,7 +770,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
for expr in mean_expressions: for expr in mean_expressions:
fgraph = FunctionGraph([x, y], [expr]) fgraph = FunctionGraph([x, y], [expr])
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()] ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 6 assert len(ops) == 6
...@@ -778,7 +778,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -778,7 +778,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)] assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)]
fgraph = FunctionGraph([x, y], [grad(expr, x)]) fgraph = FunctionGraph([x, y], [grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()] ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 5 assert len(ops) == 5
...@@ -798,7 +798,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -798,7 +798,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
for expr in mean_bias_expressions: for expr in mean_bias_expressions:
fgraph = FunctionGraph([x, b, y], [expr]) fgraph = FunctionGraph([x, b, y], [expr])
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()] ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 4 assert len(ops) == 4
...@@ -806,7 +806,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -806,7 +806,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)] assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)]
fgraph = FunctionGraph([x, b, y], [grad(expr, x)]) fgraph = FunctionGraph([x, b, y], [grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()] ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 5 assert len(ops) == 5
...@@ -827,7 +827,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -827,7 +827,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
for expr in expressions: for expr in expressions:
fgraph = FunctionGraph([x, y], [expr]) fgraph = FunctionGraph([x, y], [expr])
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()] ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 5 assert len(ops) == 5
...@@ -836,7 +836,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -836,7 +836,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
# Also verify the gradient wrt x # Also verify the gradient wrt x
fgraph = FunctionGraph([x, y], [grad(expr, x)]) fgraph = FunctionGraph([x, y], [grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()] ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 3 assert len(ops) == 3
...@@ -888,7 +888,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -888,7 +888,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
for expr in expressions: for expr in expressions:
fgraph = FunctionGraph([x, y, a], [expr]) fgraph = FunctionGraph([x, y, a], [expr])
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
assert 5 <= len(fgraph.toposort()) <= 10 assert 5 <= len(fgraph.toposort()) <= 10
...@@ -898,7 +898,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -898,7 +898,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
# Verify the gradient wrt x # Verify the gradient wrt x
fgraph = FunctionGraph([x, y, a], [grad(expr, x)]) fgraph = FunctionGraph([x, y, a], [grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
assert 3 <= len(fgraph.toposort()) <= 6 assert 3 <= len(fgraph.toposort()) <= 6
...@@ -911,7 +911,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -911,7 +911,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
fgraph = FunctionGraph( fgraph = FunctionGraph(
[x, y, a], [grad(expr, x, known_grads={expr: a * x.sum()})] [x, y, a], [grad(expr, x, known_grads={expr: a * x.sum()})]
) )
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
assert 6 <= len(fgraph.toposort()) <= 8 assert 6 <= len(fgraph.toposort()) <= 8
...@@ -927,7 +927,7 @@ def test_argmax_pushdown(): ...@@ -927,7 +927,7 @@ def test_argmax_pushdown():
# test that the max_and_argmax is pushed down if the max is not used # test that the max_and_argmax is pushed down if the max is not used
out = max_and_argmax(sm(exp(tanh(sigmoid(x)))), axis=-1)[1] out = max_and_argmax(sm(exp(tanh(sigmoid(x)))), axis=-1)[1]
fgraph = FunctionGraph([x], [out]) fgraph = FunctionGraph([x], [out])
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
# print 'AFTER' # print 'AFTER'
# for node in fgraph.toposort(): # for node in fgraph.toposort():
...@@ -942,7 +942,7 @@ def test_argmax_pushdown(): ...@@ -942,7 +942,7 @@ def test_argmax_pushdown():
assert hasattr(fgraph.outputs[0].tag, "trace") assert hasattr(fgraph.outputs[0].tag, "trace")
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
# print 'AFTER' # print 'AFTER'
# for node in fgraph.toposort(): # for node in fgraph.toposort():
...@@ -963,7 +963,7 @@ def test_argmax_pushdown_bias(): ...@@ -963,7 +963,7 @@ def test_argmax_pushdown_bias():
out = argmax(softmax_with_bias(x, b), axis=-1) out = argmax(softmax_with_bias(x, b), axis=-1)
fgraph = FunctionGraph([x, b], [out]) fgraph = FunctionGraph([x, b], [out])
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
types_to_check = (DimShuffle, Elemwise, Argmax) types_to_check = (DimShuffle, Elemwise, Argmax)
assert len(fgraph.toposort()) == 3 assert len(fgraph.toposort()) == 3
...@@ -977,7 +977,7 @@ def test_argmax_pushdown_bias(): ...@@ -977,7 +977,7 @@ def test_argmax_pushdown_bias():
out = max_and_argmax(softmax_with_bias(x, b), axis=-1)[0] out = max_and_argmax(softmax_with_bias(x, b), axis=-1)[0]
fgraph = FunctionGraph([x, b], [out]) fgraph = FunctionGraph([x, b], [out])
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).rewrite(fgraph)
assert len(fgraph.toposort()) == 2 assert len(fgraph.toposort()) == 2
assert isinstance(fgraph.toposort()[0].op, SoftmaxWithBias) assert isinstance(fgraph.toposort()[0].op, SoftmaxWithBias)
......
...@@ -159,11 +159,11 @@ def ds(x, y): ...@@ -159,11 +159,11 @@ def ds(x, y):
def optimize(g, level="fast_run"): def optimize(g, level="fast_run"):
if level == "fast_run": if level == "fast_run":
_optimizer_fast_run.optimize(g) _optimizer_fast_run.rewrite(g)
elif level == "specialize": elif level == "specialize":
_optimizer_specialize.optimize(g) _optimizer_specialize.rewrite(g)
elif level == "stabilize": elif level == "stabilize":
_optimizer_stabilize.optimize(g) _optimizer_stabilize.rewrite(g)
else: else:
raise ValueError(level) raise ValueError(level)
return g return g
...@@ -184,7 +184,7 @@ class TestDimshuffleLift: ...@@ -184,7 +184,7 @@ class TestDimshuffleLift:
assert ( assert (
str(g) == "FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{1,0}(x)))" str(g) == "FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{1,0}(x)))"
) )
dimshuffle_lift.optimize(g) dimshuffle_lift.rewrite(g)
assert str(g) == "FunctionGraph(x)" assert str(g) == "FunctionGraph(x)"
# no need to check_stack_trace as graph is supposed to be empty # no need to check_stack_trace as graph is supposed to be empty
...@@ -196,7 +196,7 @@ class TestDimshuffleLift: ...@@ -196,7 +196,7 @@ class TestDimshuffleLift:
str(g) str(g)
== "FunctionGraph(InplaceDimShuffle{2,0,x,1}(InplaceDimShuffle{1,x,0}(x)))" == "FunctionGraph(InplaceDimShuffle{2,0,x,1}(InplaceDimShuffle{1,x,0}(x)))"
), str(g) ), str(g)
dimshuffle_lift.optimize(g) dimshuffle_lift.rewrite(g)
assert str(g) == "FunctionGraph(InplaceDimShuffle{0,1,x,x}(x))", str(g) assert str(g) == "FunctionGraph(InplaceDimShuffle{0,1,x,x}(x))", str(g)
# Check stacktrace was copied over correctly after opt was applied # Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(g, ops_to_check="all") assert check_stack_trace(g, ops_to_check="all")
...@@ -209,7 +209,7 @@ class TestDimshuffleLift: ...@@ -209,7 +209,7 @@ class TestDimshuffleLift:
"FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{2,0,x,1}" "FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{2,0,x,1}"
"(InplaceDimShuffle{0,x,1}(x))))" "(InplaceDimShuffle{0,x,1}(x))))"
), str(g) ), str(g)
dimshuffle_lift.optimize(g) dimshuffle_lift.rewrite(g)
assert str(g) == "FunctionGraph(x)", str(g) assert str(g) == "FunctionGraph(x)", str(g)
# no need to check_stack_trace as graph is supposed to be empty # no need to check_stack_trace as graph is supposed to be empty
...@@ -238,7 +238,7 @@ class TestDimshuffleLift: ...@@ -238,7 +238,7 @@ class TestDimshuffleLift:
"FunctionGraph(Elemwise{add,no_inplace}(Elemwise{add,no_inplace}" "FunctionGraph(Elemwise{add,no_inplace}(Elemwise{add,no_inplace}"
"(DimShuffle{x,x,0}(x), DimShuffle{x,0,1}(y)), z))" "(DimShuffle{x,x,0}(x), DimShuffle{x,0,1}(y)), z))"
) )
dimshuffle_lift.optimize(g) dimshuffle_lift.rewrite(g)
assert str(g) in (opt_str_g_inplace, opt_str_g_noinplace), str(g) assert str(g) in (opt_str_g_inplace, opt_str_g_noinplace), str(g)
# Check stacktrace was copied over correctly after opt was applied # Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(g, ops_to_check="all") assert check_stack_trace(g, ops_to_check="all")
...@@ -277,7 +277,7 @@ class TestDimshuffleLift: ...@@ -277,7 +277,7 @@ class TestDimshuffleLift:
e = ds(x, (0, 1)) e = ds(x, (0, 1))
g = FunctionGraph([x], [e]) g = FunctionGraph([x], [e])
assert str(g) == "FunctionGraph(InplaceDimShuffle{0,1}(x))" assert str(g) == "FunctionGraph(InplaceDimShuffle{0,1}(x))"
dimshuffle_lift.optimize(g) dimshuffle_lift.rewrite(g)
assert str(g) == "FunctionGraph(x)" assert str(g) == "FunctionGraph(x)"
# Check stacktrace was copied over correctly after opt was applied # Check stacktrace was copied over correctly after opt was applied
assert hasattr(g.outputs[0].tag, "trace") assert hasattr(g.outputs[0].tag, "trace")
...@@ -294,7 +294,7 @@ class TestDimshuffleLift: ...@@ -294,7 +294,7 @@ class TestDimshuffleLift:
str(g) str(g)
== "FunctionGraph(InplaceDimShuffle{0,x}(x), InplaceDimShuffle{2,1,0}(y), InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))" == "FunctionGraph(InplaceDimShuffle{0,x}(x), InplaceDimShuffle{2,1,0}(y), InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))"
) )
dimshuffle_lift.optimize(g) dimshuffle_lift.rewrite(g)
assert ( assert (
str(g) str(g)
== "FunctionGraph(x, y, InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))" == "FunctionGraph(x, y, InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))"
...@@ -331,7 +331,7 @@ def test_local_useless_dimshuffle_in_reshape(): ...@@ -331,7 +331,7 @@ def test_local_useless_dimshuffle_in_reshape():
"Reshape{2}(InplaceDimShuffle{0}(col), Shape(col)))" "Reshape{2}(InplaceDimShuffle{0}(col), Shape(col)))"
) )
useless_dimshuffle_in_reshape = out2in(local_useless_dimshuffle_in_reshape) useless_dimshuffle_in_reshape = out2in(local_useless_dimshuffle_in_reshape)
useless_dimshuffle_in_reshape.optimize(g) useless_dimshuffle_in_reshape.rewrite(g)
assert str(g) == ( assert str(g) == (
"FunctionGraph(Reshape{1}(vector, Shape(vector)), " "FunctionGraph(Reshape{1}(vector, Shape(vector)), "
"Reshape{2}(mat, Shape(mat)), " "Reshape{2}(mat, Shape(mat)), "
...@@ -347,7 +347,7 @@ def test_local_useless_dimshuffle_in_reshape(): ...@@ -347,7 +347,7 @@ def test_local_useless_dimshuffle_in_reshape():
reshape_dimshuffle_mat2 = reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape) reshape_dimshuffle_mat2 = reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape)
h = FunctionGraph([mat], [reshape_dimshuffle_mat2]) h = FunctionGraph([mat], [reshape_dimshuffle_mat2])
str_h = str(h) str_h = str(h)
useless_dimshuffle_in_reshape.optimize(h) useless_dimshuffle_in_reshape.rewrite(h)
assert str(h) == str_h assert str(h) == str_h
...@@ -1505,7 +1505,7 @@ class TestLocalCanonicalizeAlloc: ...@@ -1505,7 +1505,7 @@ class TestLocalCanonicalizeAlloc:
assert any(isinstance(node.op, Alloc) for node in g.toposort()) assert any(isinstance(node.op, Alloc) for node in g.toposort())
alloc_lift = out2in(local_alloc_sink_dimshuffle) alloc_lift = out2in(local_alloc_sink_dimshuffle)
alloc_lift.optimize(g) alloc_lift.rewrite(g)
if has_alloc: if has_alloc:
assert any(isinstance(node.op, Alloc) for node in g.toposort()) assert any(isinstance(node.op, Alloc) for node in g.toposort())
...@@ -2849,8 +2849,8 @@ class TestLocalReshapeToDimshuffle: ...@@ -2849,8 +2849,8 @@ class TestLocalReshapeToDimshuffle:
"TensorConstant{[1 5 1 6 1 1]}))" "TensorConstant{[1 5 1 6 1 1]}))"
) )
reshape_lift.optimize(g) reshape_lift.rewrite(g)
useless_reshape.optimize(g) useless_reshape.rewrite(g)
assert str(g) == ( assert str(g) == (
"FunctionGraph(InplaceDimShuffle{x,0}" "FunctionGraph(InplaceDimShuffle{x,0}"
"(<TensorType(float64, (None,))>), " "(<TensorType(float64, (None,))>), "
...@@ -2880,9 +2880,9 @@ def test_local_reshape_lift(): ...@@ -2880,9 +2880,9 @@ def test_local_reshape_lift():
class TestLiftTransposeThroughDot: class TestLiftTransposeThroughDot:
def simple_optimize(self, g): def simple_optimize(self, g):
out2in(local_useless_elemwise).optimize(g) out2in(local_useless_elemwise).rewrite(g)
out2in(local_lift_transpose_through_dot).optimize(g) out2in(local_lift_transpose_through_dot).rewrite(g)
out2in(local_useless_elemwise).optimize(g) out2in(local_useless_elemwise).rewrite(g)
return g return g
def test_matrix_matrix(self): def test_matrix_matrix(self):
...@@ -3159,9 +3159,9 @@ def test_local_useless_alloc(): ...@@ -3159,9 +3159,9 @@ def test_local_useless_alloc():
output = at.alloc(at.alloc(m, 1, y, 1, 1), x, y, z, w) output = at.alloc(at.alloc(m, 1, y, 1, 1), x, y, z, w)
g = FunctionGraph([m, x, y, z, w], [output]) g = FunctionGraph([m, x, y, z, w], [output])
useless_alloc.optimize(g) useless_alloc.rewrite(g)
merge_alloc.optimize(g) merge_alloc.rewrite(g)
useless_alloc.optimize(g) useless_alloc.rewrite(g)
topo = g.toposort() topo = g.toposort()
assert len(topo) == 1 assert len(topo) == 1
...@@ -3172,9 +3172,9 @@ def test_local_useless_alloc(): ...@@ -3172,9 +3172,9 @@ def test_local_useless_alloc():
output = at.alloc(at.alloc(m, y, 1, 1), x, y, z, w) output = at.alloc(at.alloc(m, y, 1, 1), x, y, z, w)
g = FunctionGraph([m, x, y, z, w], [output]) g = FunctionGraph([m, x, y, z, w], [output])
useless_alloc.optimize(g) useless_alloc.rewrite(g)
merge_alloc.optimize(g) merge_alloc.rewrite(g)
useless_alloc.optimize(g) useless_alloc.rewrite(g)
topo = g.toposort() topo = g.toposort()
assert len(topo) == 1 assert len(topo) == 1
...@@ -3186,9 +3186,9 @@ def test_local_useless_alloc(): ...@@ -3186,9 +3186,9 @@ def test_local_useless_alloc():
output = at.alloc(at.alloc(m, y, 1, 1), x, y2, z, w) output = at.alloc(at.alloc(m, y, 1, 1), x, y2, z, w)
g = FunctionGraph([m, x, y, y2, z, w], [output]) g = FunctionGraph([m, x, y, y2, z, w], [output])
useless_alloc.optimize(g) useless_alloc.rewrite(g)
merge_alloc.optimize(g) merge_alloc.rewrite(g)
useless_alloc.optimize(g) useless_alloc.rewrite(g)
topo = g.toposort() topo = g.toposort()
assert len(topo) == 3 assert len(topo) == 3
......
...@@ -150,11 +150,11 @@ def ds(x, y): ...@@ -150,11 +150,11 @@ def ds(x, y):
def optimize(g, level="fast_run"): def optimize(g, level="fast_run"):
if level == "fast_run": if level == "fast_run":
_optimizer_fast_run.optimize(g) _optimizer_fast_run.rewrite(g)
elif level == "specialize": elif level == "specialize":
_optimizer_specialize.optimize(g) _optimizer_specialize.rewrite(g)
elif level == "stabilize": elif level == "stabilize":
_optimizer_stabilize.optimize(g) _optimizer_stabilize.rewrite(g)
else: else:
raise ValueError(level) raise ValueError(level)
return g return g
...@@ -189,19 +189,19 @@ class TestGreedyDistribute: ...@@ -189,19 +189,19 @@ class TestGreedyDistribute:
# 1. ((a/x + b/y) * x * y) --> a*y + b*x # 1. ((a/x + b/y) * x * y) --> a*y + b*x
e = (a / z + b / x) * x * z e = (a / z + b / x) * x * z
g = FunctionGraph([a, b, c, d, x, y, z], [e]) g = FunctionGraph([a, b, c, d, x, y, z], [e])
mul_canonizer.optimize(g) mul_canonizer.rewrite(g)
WalkingGraphRewriter( WalkingGraphRewriter(
SequentialNodeRewriter(local_greedy_distributor), order="out_to_in" SequentialNodeRewriter(local_greedy_distributor), order="out_to_in"
).optimize(g) ).rewrite(g)
assert str(pprint(g.outputs[0])) == "((a * x) + (b * z))" assert str(pprint(g.outputs[0])) == "((a * x) + (b * z))"
# 2. ((a/x + b) * x) --> a + b*x # 2. ((a/x + b) * x) --> a + b*x
e = (a / x + b) * x e = (a / x + b) * x
g = FunctionGraph([a, b, x], [e]) g = FunctionGraph([a, b, x], [e])
mul_canonizer.optimize(g) mul_canonizer.rewrite(g)
WalkingGraphRewriter( WalkingGraphRewriter(
SequentialNodeRewriter(local_greedy_distributor), order="out_to_in" SequentialNodeRewriter(local_greedy_distributor), order="out_to_in"
).optimize(g) ).rewrite(g)
assert str(pprint(g.outputs[0])) == "(a + (b * x))" assert str(pprint(g.outputs[0])) == "(a + (b * x))"
def test_kording_bug(self): def test_kording_bug(self):
...@@ -3054,7 +3054,7 @@ class TestLocalErfc: ...@@ -3054,7 +3054,7 @@ class TestLocalErfc:
WalkingGraphRewriter( WalkingGraphRewriter(
SequentialNodeRewriter(local_grad_log_erfc_neg), order="out_to_in" SequentialNodeRewriter(local_grad_log_erfc_neg), order="out_to_in"
).optimize(fg) ).rewrite(fg)
# Make sure that the graph hasn't been changed # Make sure that the graph hasn't been changed
assert fg.outputs[0] is no_match assert fg.outputs[0] is no_match
......
...@@ -72,7 +72,7 @@ def test_merge_with_weird_eq(): ...@@ -72,7 +72,7 @@ def test_merge_with_weird_eq():
x = at.constant(np.asarray(1), name="x") x = at.constant(np.asarray(1), name="x")
y = at.constant(np.asarray(1), name="y") y = at.constant(np.asarray(1), name="y")
g = FunctionGraph([x, y], [x + y]) g = FunctionGraph([x, y], [x + y])
MergeOptimizer().optimize(g) MergeOptimizer().rewrite(g)
assert len(g.apply_nodes) == 1 assert len(g.apply_nodes) == 1
node = list(g.apply_nodes)[0] node = list(g.apply_nodes)[0]
...@@ -84,7 +84,7 @@ def test_merge_with_weird_eq(): ...@@ -84,7 +84,7 @@ def test_merge_with_weird_eq():
x = at.constant(np.ones(5), name="x") x = at.constant(np.ones(5), name="x")
y = at.constant(np.ones(5), name="y") y = at.constant(np.ones(5), name="y")
g = FunctionGraph([x, y], [x + y]) g = FunctionGraph([x, y], [x + y])
MergeOptimizer().optimize(g) MergeOptimizer().rewrite(g)
assert len(g.apply_nodes) == 1 assert len(g.apply_nodes) == 1
node = list(g.apply_nodes)[0] node = list(g.apply_nodes)[0]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论