提交 2d0057d4 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename OpSub to SubstitutionNodeRewriter

上级 66974655
......@@ -1396,7 +1396,7 @@ class SequentialNodeRewriter(NodeRewriter):
opt.add_requirements(fgraph)
class OpSub(NodeRewriter):
class SubstitutionNodeRewriter(NodeRewriter):
"""
Replaces the application of a certain `Op` by the application of
......@@ -1411,12 +1411,12 @@ class OpSub(NodeRewriter):
Examples
--------
OpSub(add, sub) ==>
SubstitutionNodeRewriter(add, sub) ==>
add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x))
"""
# an OpSub does not apply to the nodes it produces
# an SubstitutionNodeRewriter does not apply to the nodes it produces
reentrant = False
# all the inputs of the original node are transferred to the outputs
retains_inputs = True
......@@ -3173,6 +3173,11 @@ DEPRECATED_NAMES = [
"`LocalOptGroup` is deprecated: use `SequentialNodeRewriter` instead.",
SequentialNodeRewriter,
),
(
"OpSub",
"`OpSub` is deprecated: use `SubstitutionNodeRewriter` instead.",
SubstitutionNodeRewriter,
),
]
......
......@@ -270,12 +270,12 @@ FunctionGraph(add(z, mul(true_div(mul(y, x), y), true_div(z, x))))
>>> e
FunctionGraph(add(z, mul(x, true_div(z, x))))
:class:`OpSub`, :class:`OpRemove`, :class:`PatternSub`
++++++++++++++++++++++++++++++++++++++++++++++++++++++
:class:`SubstitutionNodeRewriter`, :class:`OpRemove`, :class:`PatternSub`
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Aesara defines some shortcuts to make :class:`NodeRewriter`\s:
.. function:: OpSub(op1, op2)
.. function:: SubstitutionNodeRewriter(op1, op2)
Replaces all uses of ``op1`` by ``op2``. In other
words, the outputs of all :class:`Apply` nodes using ``op1`` by the outputs
......@@ -296,11 +296,11 @@ Aesara defines some shortcuts to make :class:`NodeRewriter`\s:
.. code::
from aesara.scalar import identity
from aesara.graph.opt import OpSub, OpRemove, PatternSub
from aesara.graph.opt import SubstitutionNodeRewriter, OpRemove, PatternSub
# Replacing `add` by `mul` (this is not recommended for primarily
# mathematical reasons):
add_to_mul = OpSub(add, mul)
add_to_mul = SubstitutionNodeRewriter(add, mul)
# Removing `identity`
remove_identity = OpRemove(identity)
......@@ -313,12 +313,12 @@ Aesara defines some shortcuts to make :class:`NodeRewriter`\s:
.. note::
:class:`OpSub`, :class:`OpRemove` and :class:`PatternSub` produce local optimizers, which
:class:`SubstitutionNodeRewriter`, :class:`OpRemove` and :class:`PatternSub` produce local optimizers, which
means that everything we said previously about local optimizers
apply (e.g. they need to be wrapped in a :class:`NavigatorOptimizer`, etc.)
When an optimization can be naturally expressed using :class:`OpSub`, :class:`OpRemove`
When an optimization can be naturally expressed using :class:`SubstitutionNodeRewriter`, :class:`OpRemove`
or :class:`PatternSub`, it is highly recommended to use them.
.. _unification:
......
......@@ -11,8 +11,8 @@ from aesara.graph.op import Op
from aesara.graph.opt import (
NavigatorOptimizer,
OpKeyOptimizer,
OpSub,
PatternSub,
SubstitutionNodeRewriter,
TopoOptimizer,
)
from aesara.graph.type import Type
......@@ -24,8 +24,12 @@ def PatternOptimizer(p1, p2, ign=True):
return OpKeyOptimizer(PatternSub(p1, p2), ignore_newtrees=ign)
def OpSubOptimizer(op1, op2, fail=NavigatorOptimizer.warn_ignore, ign=True):
return TopoOptimizer(OpSub(op1, op2), ignore_newtrees=ign, failure_callback=fail)
def TopoSubstitutionNodeRewriter(
op1, op2, fail=NavigatorOptimizer.warn_ignore, ign=True
):
return TopoOptimizer(
SubstitutionNodeRewriter(op1, op2), ignore_newtrees=ign, failure_callback=fail
)
def as_variable(x):
......@@ -127,7 +131,7 @@ def create_fgraph(inputs, outputs, validate=True):
class FailureWatch:
# when passed to OpSubOptimizer or PatternOptimizer, counts the
# when passed to SubstitutionNodeRewriter or PatternOptimizer, counts the
# number of failures
def __init__(self):
self.failures = 0
......@@ -326,7 +330,7 @@ def test_long_destroyers_loop():
e = dot(dot(add_in_place(x, y), add_in_place(y, z)), add(z, x))
g = create_fgraph([x, y, z], [e])
assert g.consistent()
OpSubOptimizer(add, add_in_place).optimize(g)
TopoSubstitutionNodeRewriter(add, add_in_place).optimize(g)
assert g.consistent()
# we don't want to see that!
assert (
......@@ -362,7 +366,7 @@ def test_multi_destroyers_through_views():
g = create_fgraph([x, y, z], [e])
assert g.consistent()
fail = FailureWatch()
OpSubOptimizer(add, add_in_place, fail).optimize(g)
TopoSubstitutionNodeRewriter(add, add_in_place, fail).optimize(g)
assert g.consistent()
assert fail.failures == 1 # should have succeeded once and failed once
......@@ -384,7 +388,7 @@ def test_usage_loop():
g = create_fgraph([x, y, z], [dot(add_in_place(x, z), x)], False)
assert not g.consistent()
# replace add_in_place with add
OpSubOptimizer(add_in_place, add).optimize(g)
TopoSubstitutionNodeRewriter(add_in_place, add).optimize(g)
assert g.consistent()
......@@ -405,7 +409,7 @@ def test_usage_loop_insert_views():
g = create_fgraph([x, y, z], [e])
assert g.consistent()
fail = FailureWatch()
OpSubOptimizer(sigmoid, transpose_view, fail).optimize(g)
TopoSubstitutionNodeRewriter(sigmoid, transpose_view, fail).optimize(g)
assert g.consistent()
# it must keep one sigmoid in the long sigmoid chain
assert fail.failures == 1
......@@ -450,24 +454,26 @@ def test_multiple_inplace():
# try to work in-place on x/0 and y/1 (this should fail)
fail = FailureWatch()
OpSubOptimizer(multiple, multiple_in_place_0_1, fail).optimize(g)
TopoSubstitutionNodeRewriter(multiple, multiple_in_place_0_1, fail).optimize(g)
assert g.consistent()
assert fail.failures == 1
# try to work in-place on x/0 (this should fail)
fail = FailureWatch()
OpSubOptimizer(multiple, multiple_in_place_0, fail).optimize(g)
TopoSubstitutionNodeRewriter(multiple, multiple_in_place_0, fail).optimize(g)
assert g.consistent()
assert fail.failures == 1
# try to work in-place on y/1 (this should succeed)
fail = FailureWatch()
OpSubOptimizer(multiple, multiple_in_place_1, fail).optimize(g)
TopoSubstitutionNodeRewriter(multiple, multiple_in_place_1, fail).optimize(g)
assert g.consistent()
assert fail.failures == 0
# try to work in-place on x/0 and y/1 (this should still fail)
fail = FailureWatch()
OpSubOptimizer(multiple_in_place_1, multiple_in_place_0_1, fail).optimize(g)
TopoSubstitutionNodeRewriter(
multiple_in_place_1, multiple_in_place_0_1, fail
).optimize(g)
assert g.consistent()
assert fail.failures == 1
......@@ -9,10 +9,10 @@ from aesara.graph.opt import (
EquilibriumOptimizer,
MergeOptimizer,
OpKeyOptimizer,
OpSub,
OpToRewriterTracker,
PatternSub,
SequentialNodeRewriter,
SubstitutionNodeRewriter,
TopoOptimizer,
in2out,
logging,
......@@ -223,23 +223,23 @@ class TestPatternOptimizer:
assert str_g == "FunctionGraph(Op4(z, y))"
def OpSubOptimizer(op1, op2):
return OpKeyOptimizer(OpSub(op1, op2))
def KeyedSubstitutionNodeRewriter(op1, op2):
return OpKeyOptimizer(SubstitutionNodeRewriter(op1, op2))
class TestOpSubOptimizer:
class TestSubstitutionNodeRewriter:
def test_straightforward(self):
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op1(op1(op1(op1(x)))))
g = FunctionGraph([x, y, z], [e])
OpSubOptimizer(op1, op2).optimize(g)
KeyedSubstitutionNodeRewriter(op1, op2).optimize(g)
assert str(g) == "FunctionGraph(Op2(Op2(Op2(Op2(Op2(x))))))"
def test_straightforward_2(self):
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op2(x), op3(y), op4(z))
g = FunctionGraph([x, y, z], [e])
OpSubOptimizer(op3, op4).optimize(g)
KeyedSubstitutionNodeRewriter(op3, op4).optimize(g)
assert str(g) == "FunctionGraph(Op1(Op2(x), Op4(y), Op4(z)))"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论