提交 2d0057d4 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename OpSub to SubstitutionNodeRewriter

上级 66974655
...@@ -1396,7 +1396,7 @@ class SequentialNodeRewriter(NodeRewriter): ...@@ -1396,7 +1396,7 @@ class SequentialNodeRewriter(NodeRewriter):
opt.add_requirements(fgraph) opt.add_requirements(fgraph)
class OpSub(NodeRewriter): class SubstitutionNodeRewriter(NodeRewriter):
""" """
Replaces the application of a certain `Op` by the application of Replaces the application of a certain `Op` by the application of
...@@ -1411,12 +1411,12 @@ class OpSub(NodeRewriter): ...@@ -1411,12 +1411,12 @@ class OpSub(NodeRewriter):
Examples Examples
-------- --------
OpSub(add, sub) ==> SubstitutionNodeRewriter(add, sub) ==>
add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x)) add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x))
""" """
# an OpSub does not apply to the nodes it produces # a SubstitutionNodeRewriter does not apply to the nodes it produces
reentrant = False reentrant = False
# all the inputs of the original node are transferred to the outputs # all the inputs of the original node are transferred to the outputs
retains_inputs = True retains_inputs = True
...@@ -3173,6 +3173,11 @@ DEPRECATED_NAMES = [ ...@@ -3173,6 +3173,11 @@ DEPRECATED_NAMES = [
"`LocalOptGroup` is deprecated: use `SequentialNodeRewriter` instead.", "`LocalOptGroup` is deprecated: use `SequentialNodeRewriter` instead.",
SequentialNodeRewriter, SequentialNodeRewriter,
), ),
(
"OpSub",
"`OpSub` is deprecated: use `SubstitutionNodeRewriter` instead.",
SubstitutionNodeRewriter,
),
] ]
......
...@@ -270,12 +270,12 @@ FunctionGraph(add(z, mul(true_div(mul(y, x), y), true_div(z, x)))) ...@@ -270,12 +270,12 @@ FunctionGraph(add(z, mul(true_div(mul(y, x), y), true_div(z, x))))
>>> e >>> e
FunctionGraph(add(z, mul(x, true_div(z, x)))) FunctionGraph(add(z, mul(x, true_div(z, x))))
:class:`OpSub`, :class:`OpRemove`, :class:`PatternSub` :class:`SubstitutionNodeRewriter`, :class:`OpRemove`, :class:`PatternSub`
++++++++++++++++++++++++++++++++++++++++++++++++++++++ +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Aesara defines some shortcuts to make :class:`NodeRewriter`\s: Aesara defines some shortcuts to make :class:`NodeRewriter`\s:
.. function:: OpSub(op1, op2) .. function:: SubstitutionNodeRewriter(op1, op2)
Replaces all uses of ``op1`` by ``op2``. In other Replaces all uses of ``op1`` by ``op2``. In other
words, the outputs of all :class:`Apply` nodes using ``op1`` by the outputs words, the outputs of all :class:`Apply` nodes using ``op1`` by the outputs
...@@ -296,11 +296,11 @@ Aesara defines some shortcuts to make :class:`NodeRewriter`\s: ...@@ -296,11 +296,11 @@ Aesara defines some shortcuts to make :class:`NodeRewriter`\s:
.. code:: .. code::
from aesara.scalar import identity from aesara.scalar import identity
from aesara.graph.opt import OpSub, OpRemove, PatternSub from aesara.graph.opt import SubstitutionNodeRewriter, OpRemove, PatternSub
# Replacing `add` by `mul` (this is not recommended for primarily # Replacing `add` by `mul` (this is not recommended for primarily
# mathematical reasons): # mathematical reasons):
add_to_mul = OpSub(add, mul) add_to_mul = SubstitutionNodeRewriter(add, mul)
# Removing `identity` # Removing `identity`
remove_identity = OpRemove(identity) remove_identity = OpRemove(identity)
...@@ -313,12 +313,12 @@ Aesara defines some shortcuts to make :class:`NodeRewriter`\s: ...@@ -313,12 +313,12 @@ Aesara defines some shortcuts to make :class:`NodeRewriter`\s:
.. note:: .. note::
:class:`OpSub`, :class:`OpRemove` and :class:`PatternSub` produce local optimizers, which :class:`SubstitutionNodeRewriter`, :class:`OpRemove` and :class:`PatternSub` produce local optimizers, which
means that everything we said previously about local optimizers means that everything we said previously about local optimizers
applies (e.g. they need to be wrapped in a :class:`NavigatorOptimizer`, etc.) applies (e.g. they need to be wrapped in a :class:`NavigatorOptimizer`, etc.)
When an optimization can be naturally expressed using :class:`OpSub`, :class:`OpRemove` When an optimization can be naturally expressed using :class:`SubstitutionNodeRewriter`, :class:`OpRemove`
or :class:`PatternSub`, it is highly recommended to use them. or :class:`PatternSub`, it is highly recommended to use them.
.. _unification: .. _unification:
......
...@@ -11,8 +11,8 @@ from aesara.graph.op import Op ...@@ -11,8 +11,8 @@ from aesara.graph.op import Op
from aesara.graph.opt import ( from aesara.graph.opt import (
NavigatorOptimizer, NavigatorOptimizer,
OpKeyOptimizer, OpKeyOptimizer,
OpSub,
PatternSub, PatternSub,
SubstitutionNodeRewriter,
TopoOptimizer, TopoOptimizer,
) )
from aesara.graph.type import Type from aesara.graph.type import Type
...@@ -24,8 +24,12 @@ def PatternOptimizer(p1, p2, ign=True): ...@@ -24,8 +24,12 @@ def PatternOptimizer(p1, p2, ign=True):
return OpKeyOptimizer(PatternSub(p1, p2), ignore_newtrees=ign) return OpKeyOptimizer(PatternSub(p1, p2), ignore_newtrees=ign)
def OpSubOptimizer(op1, op2, fail=NavigatorOptimizer.warn_ignore, ign=True): def TopoSubstitutionNodeRewriter(
return TopoOptimizer(OpSub(op1, op2), ignore_newtrees=ign, failure_callback=fail) op1, op2, fail=NavigatorOptimizer.warn_ignore, ign=True
):
return TopoOptimizer(
SubstitutionNodeRewriter(op1, op2), ignore_newtrees=ign, failure_callback=fail
)
def as_variable(x): def as_variable(x):
...@@ -127,7 +131,7 @@ def create_fgraph(inputs, outputs, validate=True): ...@@ -127,7 +131,7 @@ def create_fgraph(inputs, outputs, validate=True):
class FailureWatch: class FailureWatch:
# when passed to OpSubOptimizer or PatternOptimizer, counts the # when passed to SubstitutionNodeRewriter or PatternOptimizer, counts the
# number of failures # number of failures
def __init__(self): def __init__(self):
self.failures = 0 self.failures = 0
...@@ -326,7 +330,7 @@ def test_long_destroyers_loop(): ...@@ -326,7 +330,7 @@ def test_long_destroyers_loop():
e = dot(dot(add_in_place(x, y), add_in_place(y, z)), add(z, x)) e = dot(dot(add_in_place(x, y), add_in_place(y, z)), add(z, x))
g = create_fgraph([x, y, z], [e]) g = create_fgraph([x, y, z], [e])
assert g.consistent() assert g.consistent()
OpSubOptimizer(add, add_in_place).optimize(g) TopoSubstitutionNodeRewriter(add, add_in_place).optimize(g)
assert g.consistent() assert g.consistent()
# we don't want to see that! # we don't want to see that!
assert ( assert (
...@@ -362,7 +366,7 @@ def test_multi_destroyers_through_views(): ...@@ -362,7 +366,7 @@ def test_multi_destroyers_through_views():
g = create_fgraph([x, y, z], [e]) g = create_fgraph([x, y, z], [e])
assert g.consistent() assert g.consistent()
fail = FailureWatch() fail = FailureWatch()
OpSubOptimizer(add, add_in_place, fail).optimize(g) TopoSubstitutionNodeRewriter(add, add_in_place, fail).optimize(g)
assert g.consistent() assert g.consistent()
assert fail.failures == 1 # should have succeeded once and failed once assert fail.failures == 1 # should have succeeded once and failed once
...@@ -384,7 +388,7 @@ def test_usage_loop(): ...@@ -384,7 +388,7 @@ def test_usage_loop():
g = create_fgraph([x, y, z], [dot(add_in_place(x, z), x)], False) g = create_fgraph([x, y, z], [dot(add_in_place(x, z), x)], False)
assert not g.consistent() assert not g.consistent()
# replace add_in_place with add # replace add_in_place with add
OpSubOptimizer(add_in_place, add).optimize(g) TopoSubstitutionNodeRewriter(add_in_place, add).optimize(g)
assert g.consistent() assert g.consistent()
...@@ -405,7 +409,7 @@ def test_usage_loop_insert_views(): ...@@ -405,7 +409,7 @@ def test_usage_loop_insert_views():
g = create_fgraph([x, y, z], [e]) g = create_fgraph([x, y, z], [e])
assert g.consistent() assert g.consistent()
fail = FailureWatch() fail = FailureWatch()
OpSubOptimizer(sigmoid, transpose_view, fail).optimize(g) TopoSubstitutionNodeRewriter(sigmoid, transpose_view, fail).optimize(g)
assert g.consistent() assert g.consistent()
# it must keep one sigmoid in the long sigmoid chain # it must keep one sigmoid in the long sigmoid chain
assert fail.failures == 1 assert fail.failures == 1
...@@ -450,24 +454,26 @@ def test_multiple_inplace(): ...@@ -450,24 +454,26 @@ def test_multiple_inplace():
# try to work in-place on x/0 and y/1 (this should fail) # try to work in-place on x/0 and y/1 (this should fail)
fail = FailureWatch() fail = FailureWatch()
OpSubOptimizer(multiple, multiple_in_place_0_1, fail).optimize(g) TopoSubstitutionNodeRewriter(multiple, multiple_in_place_0_1, fail).optimize(g)
assert g.consistent() assert g.consistent()
assert fail.failures == 1 assert fail.failures == 1
# try to work in-place on x/0 (this should fail) # try to work in-place on x/0 (this should fail)
fail = FailureWatch() fail = FailureWatch()
OpSubOptimizer(multiple, multiple_in_place_0, fail).optimize(g) TopoSubstitutionNodeRewriter(multiple, multiple_in_place_0, fail).optimize(g)
assert g.consistent() assert g.consistent()
assert fail.failures == 1 assert fail.failures == 1
# try to work in-place on y/1 (this should succeed) # try to work in-place on y/1 (this should succeed)
fail = FailureWatch() fail = FailureWatch()
OpSubOptimizer(multiple, multiple_in_place_1, fail).optimize(g) TopoSubstitutionNodeRewriter(multiple, multiple_in_place_1, fail).optimize(g)
assert g.consistent() assert g.consistent()
assert fail.failures == 0 assert fail.failures == 0
# try to work in-place on x/0 and y/1 (this should still fail) # try to work in-place on x/0 and y/1 (this should still fail)
fail = FailureWatch() fail = FailureWatch()
OpSubOptimizer(multiple_in_place_1, multiple_in_place_0_1, fail).optimize(g) TopoSubstitutionNodeRewriter(
multiple_in_place_1, multiple_in_place_0_1, fail
).optimize(g)
assert g.consistent() assert g.consistent()
assert fail.failures == 1 assert fail.failures == 1
...@@ -9,10 +9,10 @@ from aesara.graph.opt import ( ...@@ -9,10 +9,10 @@ from aesara.graph.opt import (
EquilibriumOptimizer, EquilibriumOptimizer,
MergeOptimizer, MergeOptimizer,
OpKeyOptimizer, OpKeyOptimizer,
OpSub,
OpToRewriterTracker, OpToRewriterTracker,
PatternSub, PatternSub,
SequentialNodeRewriter, SequentialNodeRewriter,
SubstitutionNodeRewriter,
TopoOptimizer, TopoOptimizer,
in2out, in2out,
logging, logging,
...@@ -223,23 +223,23 @@ class TestPatternOptimizer: ...@@ -223,23 +223,23 @@ class TestPatternOptimizer:
assert str_g == "FunctionGraph(Op4(z, y))" assert str_g == "FunctionGraph(Op4(z, y))"
def OpSubOptimizer(op1, op2): def KeyedSubstitutionNodeRewriter(op1, op2):
return OpKeyOptimizer(OpSub(op1, op2)) return OpKeyOptimizer(SubstitutionNodeRewriter(op1, op2))
class TestOpSubOptimizer: class TestSubstitutionNodeRewriter:
def test_straightforward(self): def test_straightforward(self):
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op1(op1(op1(op1(x))))) e = op1(op1(op1(op1(op1(x)))))
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
OpSubOptimizer(op1, op2).optimize(g) KeyedSubstitutionNodeRewriter(op1, op2).optimize(g)
assert str(g) == "FunctionGraph(Op2(Op2(Op2(Op2(Op2(x))))))" assert str(g) == "FunctionGraph(Op2(Op2(Op2(Op2(Op2(x))))))"
def test_straightforward_2(self): def test_straightforward_2(self):
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op2(x), op3(y), op4(z)) e = op1(op2(x), op3(y), op4(z))
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
OpSubOptimizer(op3, op4).optimize(g) KeyedSubstitutionNodeRewriter(op3, op4).optimize(g)
assert str(g) == "FunctionGraph(Op1(Op2(x), Op4(y), Op4(z)))" assert str(g) == "FunctionGraph(Op1(Op2(x), Op4(y), Op4(z)))"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论