提交 45f48ae6 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename GraphRewriter.optimize to GraphRewriter.rewrite

上级 8db93fe1
......@@ -92,16 +92,16 @@ class Rewriter(abc.ABC):
class GraphRewriter(Rewriter):
"""An optimizer that can be applied to a `FunctionGraph` in order to transform it.
"""A rewriter that can be applied to a `FunctionGraph` in order to transform it.
It can represent an optimization or, in general, any kind of transformation
one could apply to a `FunctionGraph`.
This class represents a generalized rewrite that includes the way a graph
is traversed and/or changed as a whole.
"""
@abc.abstractmethod
def apply(self, fgraph):
"""Apply the optimization to a `FunctionGraph`.
"""Apply the rewriter to a `FunctionGraph`.
It may use all the methods defined by the `FunctionGraph`. If the
`GraphRewriter` needs to use a certain tool, such as an
......@@ -110,26 +110,29 @@ class GraphRewriter(Rewriter):
"""
raise NotImplementedError()
def optimize(self, fgraph, *args, **kwargs):
def optimize(self, *args, **kwargs):
warnings.warn(
"`GraphRewriter.optimize` is deprecated; use `GraphRewriter.rewrite` instead.",
DeprecationWarning,
stacklevel=2,
)
self.rewrite(*args, **kwargs)
def rewrite(self, fgraph, *args, **kwargs):
"""
This is meant as a shortcut for the following::
opt.add_requirements(fgraph)
opt.apply(fgraph)
self.add_requirements(fgraph)
self.apply(fgraph)
"""
self.add_requirements(fgraph)
ret = self.apply(fgraph, *args, **kwargs)
return ret
return self.apply(fgraph, *args, **kwargs)
def __call__(self, fgraph):
"""Optimize a `FunctionGraph`.
This is the same as ``self.optimize(fgraph)``.
"""
return self.optimize(fgraph)
"""Rewrite a `FunctionGraph`."""
return self.rewrite(fgraph)
def add_requirements(self, fgraph):
...
......@@ -141,12 +144,12 @@ class GraphRewriter(Rewriter):
file=stream,
)
@staticmethod
def print_profile(stream, prof, level=0):
@classmethod
def print_profile(cls, stream, prof, level=0):
if prof is not None:
raise NotImplementedError(
"The function print_profile must be overridden if the"
" optimizer returns profiling information."
"The function `print_profile` must be overridden when the"
" rewriter returns profiling information."
)
......
......@@ -44,10 +44,10 @@ def optimize_graph(
return_only_out = True
canonicalize_opt = optdb.query(RewriteDatabaseQuery(include=include, **kwargs))
_ = canonicalize_opt.optimize(fgraph)
_ = canonicalize_opt.rewrite(fgraph)
if custom_opt:
custom_opt.optimize(fgraph)
custom_opt.rewrite(fgraph)
if return_only_out:
return fgraph.outputs[0]
......@@ -79,7 +79,7 @@ def is_same_graph_with_merge(var1, var2, givens=None):
for to_replace, replace_by in givens.items():
fgraph.replace(to_replace, replace_by)
# Perform merge optimization.
MergeOptimizer().optimize(fgraph)
MergeOptimizer().rewrite(fgraph)
# When two variables perform the same computations, they will have the same
# owner in the optimized graph.
# We need to be careful with the special case where the owner is None,
......
......@@ -4152,7 +4152,7 @@ class Composite(ScalarOp):
# the fgraph to be set to the variable as we need to pickle
# them for the cache of c module to work.
fgraph = FunctionGraph(self.inputs, self.outputs)
MergeOptimizer().optimize(fgraph)
MergeOptimizer().rewrite(fgraph)
for node in fgraph.apply_nodes:
if not isinstance(node.op, ScalarOp):
raise ValueError(
......
......@@ -56,7 +56,7 @@ Graph Rewriting
<libdoc_graph_fgraphfeature>` to it. These features are "plugins" that are needed
for the :meth:`GraphRewriter.apply` method to do its job properly.
.. method:: optimize(fgraph)
.. method:: rewrite(fgraph)
This is the interface function called by Aesara. It calls
:meth:`GraphRewriter.apply` by default.
......@@ -159,7 +159,7 @@ Now, we test the optimization:
>>> e = aesara.graph.fg.FunctionGraph([x, y, z], [a])
>>> e
FunctionGraph(add(z, mul(true_div(mul(y, x), y), true_div(z, x))))
>>> simplify.optimize(e)
>>> simplify.rewrite(e)
>>> e
FunctionGraph(add(z, mul(x, true_div(z, x))))
......@@ -175,7 +175,7 @@ optimization you wrote. For example, consider the following:
>>> e = aesara.graph.fg.FunctionGraph([x, y, z], [a])
>>> e
FunctionGraph(true_div(mul(add(y, z), x), add(y, z)))
>>> simplify.optimize(e)
>>> simplify.rewrite(e)
>>> e
FunctionGraph(true_div(mul(add(y, z), x), add(y, z)))
......@@ -186,11 +186,11 @@ computation, using the :class:`MergeOptimizer` defined in
:mod:`aesara.graph.opt`.
>>> from aesara.graph.opt import MergeOptimizer
>>> MergeOptimizer().optimize(e) # doctest: +ELLIPSIS
>>> MergeOptimizer().rewrite(e) # doctest: +ELLIPSIS
(0, ..., None, None, {}, 1, 0)
>>> e
FunctionGraph(true_div(mul(*1 -> add(y, z), x), *1))
>>> simplify.optimize(e)
>>> simplify.rewrite(e)
>>> e
FunctionGraph(x)
......@@ -265,7 +265,7 @@ subset of them) and applies one or several local optimizers.
>>> e
FunctionGraph(add(z, mul(true_div(mul(y, x), y), true_div(z, x))))
>>> simplify = aesara.graph.opt.WalkingGraphRewriter(local_simplify)
>>> simplify.optimize(e)
>>> simplify.rewrite(e)
(<aesara.graph.opt.WalkingGraphRewriter object at 0x...>, 1, 5, 3, ..., ..., ...)
>>> e
FunctionGraph(add(z, mul(x, true_div(z, x))))
......
......@@ -151,7 +151,7 @@ def test_misc():
e = transpose_view(transpose_view(transpose_view(transpose_view(x))))
g = create_fgraph([x, y, z], [e])
assert g.consistent()
PatternOptimizer((transpose_view, (transpose_view, "x")), "x").optimize(g)
PatternOptimizer((transpose_view, (transpose_view, "x")), "x").rewrite(g)
assert str(g) == "FunctionGraph(x)"
new_e = add(x, y)
g.replace_validate(x, new_e)
......@@ -330,7 +330,7 @@ def test_long_destroyers_loop():
e = dot(dot(add_in_place(x, y), add_in_place(y, z)), add(z, x))
g = create_fgraph([x, y, z], [e])
assert g.consistent()
TopoSubstitutionNodeRewriter(add, add_in_place).optimize(g)
TopoSubstitutionNodeRewriter(add, add_in_place).rewrite(g)
assert g.consistent()
# we don't want to see that!
assert (
......@@ -366,7 +366,7 @@ def test_multi_destroyers_through_views():
g = create_fgraph([x, y, z], [e])
assert g.consistent()
fail = FailureWatch()
TopoSubstitutionNodeRewriter(add, add_in_place, fail).optimize(g)
TopoSubstitutionNodeRewriter(add, add_in_place, fail).rewrite(g)
assert g.consistent()
assert fail.failures == 1 # should have succeeded once and failed once
......@@ -388,7 +388,7 @@ def test_usage_loop():
g = create_fgraph([x, y, z], [dot(add_in_place(x, z), x)], False)
assert not g.consistent()
# replace add_in_place with add
TopoSubstitutionNodeRewriter(add_in_place, add).optimize(g)
TopoSubstitutionNodeRewriter(add_in_place, add).rewrite(g)
assert g.consistent()
......@@ -409,7 +409,7 @@ def test_usage_loop_insert_views():
g = create_fgraph([x, y, z], [e])
assert g.consistent()
fail = FailureWatch()
TopoSubstitutionNodeRewriter(sigmoid, transpose_view, fail).optimize(g)
TopoSubstitutionNodeRewriter(sigmoid, transpose_view, fail).rewrite(g)
assert g.consistent()
# it must keep one sigmoid in the long sigmoid chain
assert fail.failures == 1
......@@ -454,19 +454,19 @@ def test_multiple_inplace():
# try to work in-place on x/0 and y/1 (this should fail)
fail = FailureWatch()
TopoSubstitutionNodeRewriter(multiple, multiple_in_place_0_1, fail).optimize(g)
TopoSubstitutionNodeRewriter(multiple, multiple_in_place_0_1, fail).rewrite(g)
assert g.consistent()
assert fail.failures == 1
# try to work in-place on x/0 (this should fail)
fail = FailureWatch()
TopoSubstitutionNodeRewriter(multiple, multiple_in_place_0, fail).optimize(g)
TopoSubstitutionNodeRewriter(multiple, multiple_in_place_0, fail).rewrite(g)
assert g.consistent()
assert fail.failures == 1
# try to work in-place on y/1 (this should succeed)
fail = FailureWatch()
TopoSubstitutionNodeRewriter(multiple, multiple_in_place_1, fail).optimize(g)
TopoSubstitutionNodeRewriter(multiple, multiple_in_place_1, fail).rewrite(g)
assert g.consistent()
assert fail.failures == 0
......@@ -474,6 +474,6 @@ def test_multiple_inplace():
fail = FailureWatch()
TopoSubstitutionNodeRewriter(
multiple_in_place_1, multiple_in_place_0_1, fail
).optimize(g)
).rewrite(g)
assert g.consistent()
assert fail.failures == 1
差异被折叠。
......@@ -824,7 +824,7 @@ class TestScan:
assert scan_c is not scan_a
g = FunctionGraph([x, y, c], [2 * scan_a, 2 * scan_b, 2 * scan_c], clone=False)
MergeOptimizer().optimize(g)
MergeOptimizer().rewrite(g)
scan_a_out, scan_b_out, scan_c_out = g.outputs
assert scan_a_out is scan_b_out
......
......@@ -340,7 +340,7 @@ class TestLogSoftmax(utt.InferShapeTester):
)
fgraph = FunctionGraph([x], [new_g])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
assert softmax_grad_legacy in [n.op for n in fgraph.toposort()]
......@@ -647,7 +647,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
fgraph = FunctionGraph([x, one_of_n], [op(softmax_legacy(x), one_of_n)])
assert fgraph.outputs[0].owner.op == op
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
assert fgraph.outputs[0].owner.op == crossentropy_softmax_argmax_1hot_with_bias
def test_softmax_optimizations_w_bias(self):
......@@ -659,7 +659,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
fgraph = FunctionGraph([x, b, one_of_n], [op(softmax_legacy(x + b), one_of_n)])
assert fgraph.outputs[0].owner.op == op
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
assert len(fgraph.toposort()) == 1
assert fgraph.outputs[0].owner.op == crossentropy_softmax_argmax_1hot_with_bias
......@@ -676,7 +676,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
)
assert fgraph.outputs[0].owner.op == op
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
assert len(fgraph.toposort()) == 2
assert fgraph.outputs[0].owner.op == crossentropy_softmax_argmax_1hot_with_bias
......@@ -694,7 +694,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
ops_to_check=[crossentropy_softmax_1hot_with_bias_dx, softmax_legacy],
)
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = {node.op for node in fgraph.toposort()}
assert crossentropy_softmax_argmax_1hot_with_bias not in ops
......@@ -717,7 +717,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
for expr in expressions:
fgraph = FunctionGraph([x, y], [expr])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 4
......@@ -726,7 +726,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
# Also verify the gradient wrt x
fgraph = FunctionGraph([x, y], [grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 2
......@@ -744,14 +744,14 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
for expr in bias_expressions:
fgraph = FunctionGraph([x, b, y], [expr, x])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 2 # [big_op, sum]
assert crossentropy_softmax_argmax_1hot_with_bias in ops
fgraph = FunctionGraph([x, b, y], [grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 2
......@@ -770,7 +770,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
for expr in mean_expressions:
fgraph = FunctionGraph([x, y], [expr])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 6
......@@ -778,7 +778,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)]
fgraph = FunctionGraph([x, y], [grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 5
......@@ -798,7 +798,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
for expr in mean_bias_expressions:
fgraph = FunctionGraph([x, b, y], [expr])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 4
......@@ -806,7 +806,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)]
fgraph = FunctionGraph([x, b, y], [grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 5
......@@ -827,7 +827,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
for expr in expressions:
fgraph = FunctionGraph([x, y], [expr])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 5
......@@ -836,7 +836,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
# Also verify the gradient wrt x
fgraph = FunctionGraph([x, y], [grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 3
......@@ -888,7 +888,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
for expr in expressions:
fgraph = FunctionGraph([x, y, a], [expr])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
assert 5 <= len(fgraph.toposort()) <= 10
......@@ -898,7 +898,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
# Verify the gradient wrt x
fgraph = FunctionGraph([x, y, a], [grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
assert 3 <= len(fgraph.toposort()) <= 6
......@@ -911,7 +911,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
fgraph = FunctionGraph(
[x, y, a], [grad(expr, x, known_grads={expr: a * x.sum()})]
)
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
assert 6 <= len(fgraph.toposort()) <= 8
......@@ -927,7 +927,7 @@ def test_argmax_pushdown():
# test that the max_and_argmax is pushed down if the max is not used
out = max_and_argmax(sm(exp(tanh(sigmoid(x)))), axis=-1)[1]
fgraph = FunctionGraph([x], [out])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
# print 'AFTER'
# for node in fgraph.toposort():
......@@ -942,7 +942,7 @@ def test_argmax_pushdown():
assert hasattr(fgraph.outputs[0].tag, "trace")
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
# print 'AFTER'
# for node in fgraph.toposort():
......@@ -963,7 +963,7 @@ def test_argmax_pushdown_bias():
out = argmax(softmax_with_bias(x, b), axis=-1)
fgraph = FunctionGraph([x, b], [out])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
types_to_check = (DimShuffle, Elemwise, Argmax)
assert len(fgraph.toposort()) == 3
......@@ -977,7 +977,7 @@ def test_argmax_pushdown_bias():
out = max_and_argmax(softmax_with_bias(x, b), axis=-1)[0]
fgraph = FunctionGraph([x, b], [out])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
assert len(fgraph.toposort()) == 2
assert isinstance(fgraph.toposort()[0].op, SoftmaxWithBias)
......
......@@ -159,11 +159,11 @@ def ds(x, y):
def optimize(g, level="fast_run"):
if level == "fast_run":
_optimizer_fast_run.optimize(g)
_optimizer_fast_run.rewrite(g)
elif level == "specialize":
_optimizer_specialize.optimize(g)
_optimizer_specialize.rewrite(g)
elif level == "stabilize":
_optimizer_stabilize.optimize(g)
_optimizer_stabilize.rewrite(g)
else:
raise ValueError(level)
return g
......@@ -184,7 +184,7 @@ class TestDimshuffleLift:
assert (
str(g) == "FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{1,0}(x)))"
)
dimshuffle_lift.optimize(g)
dimshuffle_lift.rewrite(g)
assert str(g) == "FunctionGraph(x)"
# no need to check_stack_trace as graph is supposed to be empty
......@@ -196,7 +196,7 @@ class TestDimshuffleLift:
str(g)
== "FunctionGraph(InplaceDimShuffle{2,0,x,1}(InplaceDimShuffle{1,x,0}(x)))"
), str(g)
dimshuffle_lift.optimize(g)
dimshuffle_lift.rewrite(g)
assert str(g) == "FunctionGraph(InplaceDimShuffle{0,1,x,x}(x))", str(g)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(g, ops_to_check="all")
......@@ -209,7 +209,7 @@ class TestDimshuffleLift:
"FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{2,0,x,1}"
"(InplaceDimShuffle{0,x,1}(x))))"
), str(g)
dimshuffle_lift.optimize(g)
dimshuffle_lift.rewrite(g)
assert str(g) == "FunctionGraph(x)", str(g)
# no need to check_stack_trace as graph is supposed to be empty
......@@ -238,7 +238,7 @@ class TestDimshuffleLift:
"FunctionGraph(Elemwise{add,no_inplace}(Elemwise{add,no_inplace}"
"(DimShuffle{x,x,0}(x), DimShuffle{x,0,1}(y)), z))"
)
dimshuffle_lift.optimize(g)
dimshuffle_lift.rewrite(g)
assert str(g) in (opt_str_g_inplace, opt_str_g_noinplace), str(g)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(g, ops_to_check="all")
......@@ -277,7 +277,7 @@ class TestDimshuffleLift:
e = ds(x, (0, 1))
g = FunctionGraph([x], [e])
assert str(g) == "FunctionGraph(InplaceDimShuffle{0,1}(x))"
dimshuffle_lift.optimize(g)
dimshuffle_lift.rewrite(g)
assert str(g) == "FunctionGraph(x)"
# Check stacktrace was copied over correctly after opt was applied
assert hasattr(g.outputs[0].tag, "trace")
......@@ -294,7 +294,7 @@ class TestDimshuffleLift:
str(g)
== "FunctionGraph(InplaceDimShuffle{0,x}(x), InplaceDimShuffle{2,1,0}(y), InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))"
)
dimshuffle_lift.optimize(g)
dimshuffle_lift.rewrite(g)
assert (
str(g)
== "FunctionGraph(x, y, InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))"
......@@ -331,7 +331,7 @@ def test_local_useless_dimshuffle_in_reshape():
"Reshape{2}(InplaceDimShuffle{0}(col), Shape(col)))"
)
useless_dimshuffle_in_reshape = out2in(local_useless_dimshuffle_in_reshape)
useless_dimshuffle_in_reshape.optimize(g)
useless_dimshuffle_in_reshape.rewrite(g)
assert str(g) == (
"FunctionGraph(Reshape{1}(vector, Shape(vector)), "
"Reshape{2}(mat, Shape(mat)), "
......@@ -347,7 +347,7 @@ def test_local_useless_dimshuffle_in_reshape():
reshape_dimshuffle_mat2 = reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape)
h = FunctionGraph([mat], [reshape_dimshuffle_mat2])
str_h = str(h)
useless_dimshuffle_in_reshape.optimize(h)
useless_dimshuffle_in_reshape.rewrite(h)
assert str(h) == str_h
......@@ -1505,7 +1505,7 @@ class TestLocalCanonicalizeAlloc:
assert any(isinstance(node.op, Alloc) for node in g.toposort())
alloc_lift = out2in(local_alloc_sink_dimshuffle)
alloc_lift.optimize(g)
alloc_lift.rewrite(g)
if has_alloc:
assert any(isinstance(node.op, Alloc) for node in g.toposort())
......@@ -2849,8 +2849,8 @@ class TestLocalReshapeToDimshuffle:
"TensorConstant{[1 5 1 6 1 1]}))"
)
reshape_lift.optimize(g)
useless_reshape.optimize(g)
reshape_lift.rewrite(g)
useless_reshape.rewrite(g)
assert str(g) == (
"FunctionGraph(InplaceDimShuffle{x,0}"
"(<TensorType(float64, (None,))>), "
......@@ -2880,9 +2880,9 @@ def test_local_reshape_lift():
class TestLiftTransposeThroughDot:
def simple_optimize(self, g):
out2in(local_useless_elemwise).optimize(g)
out2in(local_lift_transpose_through_dot).optimize(g)
out2in(local_useless_elemwise).optimize(g)
out2in(local_useless_elemwise).rewrite(g)
out2in(local_lift_transpose_through_dot).rewrite(g)
out2in(local_useless_elemwise).rewrite(g)
return g
def test_matrix_matrix(self):
......@@ -3159,9 +3159,9 @@ def test_local_useless_alloc():
output = at.alloc(at.alloc(m, 1, y, 1, 1), x, y, z, w)
g = FunctionGraph([m, x, y, z, w], [output])
useless_alloc.optimize(g)
merge_alloc.optimize(g)
useless_alloc.optimize(g)
useless_alloc.rewrite(g)
merge_alloc.rewrite(g)
useless_alloc.rewrite(g)
topo = g.toposort()
assert len(topo) == 1
......@@ -3172,9 +3172,9 @@ def test_local_useless_alloc():
output = at.alloc(at.alloc(m, y, 1, 1), x, y, z, w)
g = FunctionGraph([m, x, y, z, w], [output])
useless_alloc.optimize(g)
merge_alloc.optimize(g)
useless_alloc.optimize(g)
useless_alloc.rewrite(g)
merge_alloc.rewrite(g)
useless_alloc.rewrite(g)
topo = g.toposort()
assert len(topo) == 1
......@@ -3186,9 +3186,9 @@ def test_local_useless_alloc():
output = at.alloc(at.alloc(m, y, 1, 1), x, y2, z, w)
g = FunctionGraph([m, x, y, y2, z, w], [output])
useless_alloc.optimize(g)
merge_alloc.optimize(g)
useless_alloc.optimize(g)
useless_alloc.rewrite(g)
merge_alloc.rewrite(g)
useless_alloc.rewrite(g)
topo = g.toposort()
assert len(topo) == 3
......
......@@ -150,11 +150,11 @@ def ds(x, y):
def optimize(g, level="fast_run"):
if level == "fast_run":
_optimizer_fast_run.optimize(g)
_optimizer_fast_run.rewrite(g)
elif level == "specialize":
_optimizer_specialize.optimize(g)
_optimizer_specialize.rewrite(g)
elif level == "stabilize":
_optimizer_stabilize.optimize(g)
_optimizer_stabilize.rewrite(g)
else:
raise ValueError(level)
return g
......@@ -189,19 +189,19 @@ class TestGreedyDistribute:
# 1. ((a/x + b/y) * x * y) --> a*y + b*x
e = (a / z + b / x) * x * z
g = FunctionGraph([a, b, c, d, x, y, z], [e])
mul_canonizer.optimize(g)
mul_canonizer.rewrite(g)
WalkingGraphRewriter(
SequentialNodeRewriter(local_greedy_distributor), order="out_to_in"
).optimize(g)
).rewrite(g)
assert str(pprint(g.outputs[0])) == "((a * x) + (b * z))"
# 2. ((a/x + b) * x) --> a + b*x
e = (a / x + b) * x
g = FunctionGraph([a, b, x], [e])
mul_canonizer.optimize(g)
mul_canonizer.rewrite(g)
WalkingGraphRewriter(
SequentialNodeRewriter(local_greedy_distributor), order="out_to_in"
).optimize(g)
).rewrite(g)
assert str(pprint(g.outputs[0])) == "(a + (b * x))"
def test_kording_bug(self):
......@@ -3054,7 +3054,7 @@ class TestLocalErfc:
WalkingGraphRewriter(
SequentialNodeRewriter(local_grad_log_erfc_neg), order="out_to_in"
).optimize(fg)
).rewrite(fg)
# Make sure that the graph hasn't been changed
assert fg.outputs[0] is no_match
......
......@@ -72,7 +72,7 @@ def test_merge_with_weird_eq():
x = at.constant(np.asarray(1), name="x")
y = at.constant(np.asarray(1), name="y")
g = FunctionGraph([x, y], [x + y])
MergeOptimizer().optimize(g)
MergeOptimizer().rewrite(g)
assert len(g.apply_nodes) == 1
node = list(g.apply_nodes)[0]
......@@ -84,7 +84,7 @@ def test_merge_with_weird_eq():
x = at.constant(np.ones(5), name="x")
y = at.constant(np.ones(5), name="y")
g = FunctionGraph([x, y], [x + y])
MergeOptimizer().optimize(g)
MergeOptimizer().rewrite(g)
assert len(g.apply_nodes) == 1
node = list(g.apply_nodes)[0]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论