提交 6302cef1 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename PatternSub to PatternNodeRewriter

上级 371c5447
...@@ -1479,7 +1479,7 @@ class RemovalNodeRewriter(NodeRewriter): ...@@ -1479,7 +1479,7 @@ class RemovalNodeRewriter(NodeRewriter):
) )
class PatternSub(NodeRewriter): class PatternNodeRewriter(NodeRewriter):
"""Replace all occurrences of an input pattern with an output pattern. """Replace all occurrences of an input pattern with an output pattern.
The input and output patterns have the following syntax: The input and output patterns have the following syntax:
...@@ -1517,7 +1517,7 @@ class PatternSub(NodeRewriter): ...@@ -1517,7 +1517,7 @@ class PatternSub(NodeRewriter):
trying to match and returns True or False according to an trying to match and returns True or False according to an
arbitrary criterion. arbitrary criterion.
The constructor creates a `PatternSub` that replaces occurrences of The constructor creates a `PatternNodeRewriter` that replaces occurrences of
`in_pattern` by occurrences of `out_pattern`. `in_pattern` by occurrences of `out_pattern`.
Parameters Parameters
...@@ -1548,11 +1548,11 @@ class PatternSub(NodeRewriter): ...@@ -1548,11 +1548,11 @@ class PatternSub(NodeRewriter):
Examples Examples
-------- --------
PatternSub((add, 'x', 'y'), (add, 'y', 'x')) PatternNodeRewriter((add, 'x', 'y'), (add, 'y', 'x'))
PatternSub((multiply, 'x', 'x'), (square, 'x')) PatternNodeRewriter((multiply, 'x', 'x'), (square, 'x'))
PatternSub((subtract, (add, 'x', 'y'), 'y'), 'x') PatternNodeRewriter((subtract, (add, 'x', 'y'), 'y'), 'x')
PatternSub((power, 'x', Constant(double, 2.0)), (square, 'x')) PatternNodeRewriter((power, 'x', Constant(double, 2.0)), (square, 'x'))
PatternSub((boggle, {'pattern': 'x', PatternNodeRewriter((boggle, {'pattern': 'x',
'constraint': lambda expr: expr.type == scrabble}), 'constraint': lambda expr: expr.type == scrabble}),
(scrabble, 'x')) (scrabble, 'x'))
...@@ -3183,6 +3183,11 @@ DEPRECATED_NAMES = [ ...@@ -3183,6 +3183,11 @@ DEPRECATED_NAMES = [
"`OpRemove` is deprecated: use `RemovalNodeRewriter` instead.", "`OpRemove` is deprecated: use `RemovalNodeRewriter` instead.",
RemovalNodeRewriter, RemovalNodeRewriter,
), ),
(
"PatternSub",
"`PatternSub` is deprecated: use `PatternNodeRewriter` instead.",
PatternNodeRewriter,
),
] ]
......
...@@ -4,7 +4,7 @@ import aesara ...@@ -4,7 +4,7 @@ import aesara
import aesara.scalar as aes import aesara.scalar as aes
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
from aesara.graph.opt import PatternSub, TopoOptimizer, node_rewriter from aesara.graph.opt import PatternNodeRewriter, TopoOptimizer, node_rewriter
from aesara.link.c.op import COp, _NoPythonCOp from aesara.link.c.op import COp, _NoPythonCOp
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.sparse import basic as sparse from aesara.sparse import basic as sparse
...@@ -928,7 +928,7 @@ usmm_csc_dense_inplace = UsmmCscDense(inplace=True) ...@@ -928,7 +928,7 @@ usmm_csc_dense_inplace = UsmmCscDense(inplace=True)
# This is tested in tests/test_basic.py:UsmmTests # This is tested in tests/test_basic.py:UsmmTests
local_usmm = PatternSub( local_usmm = PatternNodeRewriter(
( (
sub, sub,
"z", "z",
......
...@@ -11,7 +11,7 @@ import aesara.scalar.math as aes_math ...@@ -11,7 +11,7 @@ import aesara.scalar.math as aes_math
from aesara.graph.basic import Constant, Variable from aesara.graph.basic import Constant, Variable
from aesara.graph.opt import ( from aesara.graph.opt import (
NodeRewriter, NodeRewriter,
PatternSub, PatternNodeRewriter,
SequentialNodeRewriter, SequentialNodeRewriter,
copy_stack_trace, copy_stack_trace,
in2out, in2out,
...@@ -2512,7 +2512,7 @@ get_clients_at_depth1 = partial(get_clients_at_depth, depth=1) ...@@ -2512,7 +2512,7 @@ get_clients_at_depth1 = partial(get_clients_at_depth, depth=1)
get_clients_at_depth2 = partial(get_clients_at_depth, depth=2) get_clients_at_depth2 = partial(get_clients_at_depth, depth=2)
# 1+erf(x)=>erfc(-x) # 1+erf(x)=>erfc(-x)
local_one_plus_erf = PatternSub( local_one_plus_erf = PatternNodeRewriter(
(add, 1, (erf, "x")), (add, 1, (erf, "x")),
(erfc, (neg, "x")), (erfc, (neg, "x")),
allow_multiple_clients=True, allow_multiple_clients=True,
...@@ -2527,7 +2527,7 @@ register_specialize(local_one_plus_erf) ...@@ -2527,7 +2527,7 @@ register_specialize(local_one_plus_erf)
# Only one of the two rewrites below is needed if a canonicalization is added # Only one of the two rewrites below is needed if a canonicalization is added
# for sub(x, y) -> add(x, -y) or a specialization for add(x, -y) -> sub(x, y) # for sub(x, y) -> add(x, -y) or a specialization for add(x, -y) -> sub(x, y)
# 1-erf(x)=>erfc(x) # 1-erf(x)=>erfc(x)
local_one_minus_erf = PatternSub( local_one_minus_erf = PatternNodeRewriter(
(sub, 1, (erf, "x")), (sub, 1, (erf, "x")),
(erfc, "x"), (erfc, "x"),
allow_multiple_clients=True, allow_multiple_clients=True,
...@@ -2539,7 +2539,7 @@ register_canonicalize(local_one_minus_erf) ...@@ -2539,7 +2539,7 @@ register_canonicalize(local_one_minus_erf)
register_stabilize(local_one_minus_erf) register_stabilize(local_one_minus_erf)
register_specialize(local_one_minus_erf) register_specialize(local_one_minus_erf)
local_one_minus_erf2 = PatternSub( local_one_minus_erf2 = PatternNodeRewriter(
(add, 1, (neg, (erf, "x"))), (add, 1, (neg, (erf, "x"))),
(erfc, "x"), (erfc, "x"),
allow_multiple_clients=True, allow_multiple_clients=True,
...@@ -2554,7 +2554,7 @@ register_specialize(local_one_minus_erf2) ...@@ -2554,7 +2554,7 @@ register_specialize(local_one_minus_erf2)
# (-1)+erf(x) => -erfc(x) # (-1)+erf(x) => -erfc(x)
# There is no need for erf(x)+(-1) nor erf(x) - 1, as the canonicalize will # There is no need for erf(x)+(-1) nor erf(x) - 1, as the canonicalize will
# convert those to the matched pattern # convert those to the matched pattern
local_erf_minus_one = PatternSub( local_erf_minus_one = PatternNodeRewriter(
(add, -1, (erf, "x")), (add, -1, (erf, "x")),
(neg, (erfc, "x")), (neg, (erfc, "x")),
allow_multiple_clients=True, allow_multiple_clients=True,
...@@ -2569,7 +2569,7 @@ register_specialize(local_erf_minus_one) ...@@ -2569,7 +2569,7 @@ register_specialize(local_erf_minus_one)
# Only one of the two rewrites below is needed if a canonicalization is added # Only one of the two rewrites below is needed if a canonicalization is added
# for sub(x, y) -> add(x, -y) or a specialization for add(x, -y) -> sub(x, y) # for sub(x, y) -> add(x, -y) or a specialization for add(x, -y) -> sub(x, y)
# 1-erfc(x) => erf(x) # 1-erfc(x) => erf(x)
local_one_minus_erfc = PatternSub( local_one_minus_erfc = PatternNodeRewriter(
(sub, 1, (erfc, "x")), (sub, 1, (erfc, "x")),
(erf, "x"), (erf, "x"),
allow_multiple_clients=True, allow_multiple_clients=True,
...@@ -2581,7 +2581,7 @@ register_canonicalize(local_one_minus_erfc) ...@@ -2581,7 +2581,7 @@ register_canonicalize(local_one_minus_erfc)
register_stabilize(local_one_minus_erfc) register_stabilize(local_one_minus_erfc)
register_specialize(local_one_minus_erfc) register_specialize(local_one_minus_erfc)
local_one_minus_erfc2 = PatternSub( local_one_minus_erfc2 = PatternNodeRewriter(
(add, 1, (neg, (erfc, "x"))), (add, 1, (neg, (erfc, "x"))),
(erf, "x"), (erf, "x"),
allow_multiple_clients=True, allow_multiple_clients=True,
...@@ -2594,7 +2594,7 @@ register_stabilize(local_one_minus_erfc2) ...@@ -2594,7 +2594,7 @@ register_stabilize(local_one_minus_erfc2)
register_specialize(local_one_minus_erfc2) register_specialize(local_one_minus_erfc2)
# (-1)+erfc(-x)=>erf(x) # (-1)+erfc(-x)=>erf(x)
local_erf_neg_minus_one = PatternSub( local_erf_neg_minus_one = PatternNodeRewriter(
(add, -1, (erfc, (neg, "x"))), (add, -1, (erfc, (neg, "x"))),
(erf, "x"), (erf, "x"),
allow_multiple_clients=True, allow_multiple_clients=True,
...@@ -2914,7 +2914,7 @@ def _is_1(expr): ...@@ -2914,7 +2914,7 @@ def _is_1(expr):
return False return False
logsigm_to_softplus = PatternSub( logsigm_to_softplus = PatternNodeRewriter(
(log, (sigmoid, "x")), (log, (sigmoid, "x")),
(neg, (softplus, (neg, "x"))), (neg, (softplus, (neg, "x"))),
allow_multiple_clients=True, allow_multiple_clients=True,
...@@ -2923,7 +2923,7 @@ logsigm_to_softplus = PatternSub( ...@@ -2923,7 +2923,7 @@ logsigm_to_softplus = PatternSub(
tracks=[sigmoid], tracks=[sigmoid],
get_nodes=get_clients_at_depth1, get_nodes=get_clients_at_depth1,
) )
log1msigm_to_softplus = PatternSub( log1msigm_to_softplus = PatternNodeRewriter(
(log, (sub, dict(pattern="y", constraint=_is_1), (sigmoid, "x"))), (log, (sub, dict(pattern="y", constraint=_is_1), (sigmoid, "x"))),
(neg, (softplus, "x")), (neg, (softplus, "x")),
allow_multiple_clients=True, allow_multiple_clients=True,
...@@ -2932,13 +2932,13 @@ log1msigm_to_softplus = PatternSub( ...@@ -2932,13 +2932,13 @@ log1msigm_to_softplus = PatternSub(
tracks=[sigmoid], tracks=[sigmoid],
get_nodes=get_clients_at_depth2, get_nodes=get_clients_at_depth2,
) )
log1pexp_to_softplus = PatternSub( log1pexp_to_softplus = PatternNodeRewriter(
(log1p, (exp, "x")), (log1p, (exp, "x")),
(softplus, "x"), (softplus, "x"),
values_eq_approx=values_eq_approx_remove_inf, values_eq_approx=values_eq_approx_remove_inf,
allow_multiple_clients=True, allow_multiple_clients=True,
) )
log1p_neg_sigmoid = PatternSub( log1p_neg_sigmoid = PatternNodeRewriter(
(log1p, (neg, (sigmoid, "x"))), (log1p, (neg, (sigmoid, "x"))),
(neg, (softplus, "x")), (neg, (softplus, "x")),
values_eq_approx=values_eq_approx_remove_inf, values_eq_approx=values_eq_approx_remove_inf,
...@@ -3511,7 +3511,7 @@ def local_reciprocal_1_plus_exp(fgraph, node): ...@@ -3511,7 +3511,7 @@ def local_reciprocal_1_plus_exp(fgraph, node):
# 1 - sigmoid(x) -> sigmoid(-x) # 1 - sigmoid(x) -> sigmoid(-x)
local_1msigmoid = PatternSub( local_1msigmoid = PatternNodeRewriter(
(sub, dict(pattern="y", constraint=_is_1), (sigmoid, "x")), (sub, dict(pattern="y", constraint=_is_1), (sigmoid, "x")),
(sigmoid, (neg, "x")), (sigmoid, (neg, "x")),
tracks=[sigmoid], tracks=[sigmoid],
...@@ -3522,7 +3522,7 @@ register_stabilize(local_1msigmoid) ...@@ -3522,7 +3522,7 @@ register_stabilize(local_1msigmoid)
register_specialize(local_1msigmoid) register_specialize(local_1msigmoid)
log1pmexp_to_log1mexp = PatternSub( log1pmexp_to_log1mexp = PatternNodeRewriter(
(log1p, (neg, (exp, "x"))), (log1p, (neg, (exp, "x"))),
(log1mexp, "x"), (log1mexp, "x"),
allow_multiple_clients=True, allow_multiple_clients=True,
...@@ -3532,7 +3532,7 @@ register_stabilize(log1pmexp_to_log1mexp, name="log1pmexp_to_log1mexp") ...@@ -3532,7 +3532,7 @@ register_stabilize(log1pmexp_to_log1mexp, name="log1pmexp_to_log1mexp")
# log(sigmoid(x) / (1 - sigmoid(x))) -> x # log(sigmoid(x) / (1 - sigmoid(x))) -> x
# i.e logit(sigmoid(x)) -> x # i.e logit(sigmoid(x)) -> x
local_logit_sigmoid = PatternSub( local_logit_sigmoid = PatternNodeRewriter(
(log, (true_div, (sigmoid, "x"), (sub, 1, (sigmoid, "x")))), (log, (true_div, (sigmoid, "x"), (sub, 1, (sigmoid, "x")))),
"x", "x",
tracks=[sigmoid], tracks=[sigmoid],
...@@ -3546,7 +3546,7 @@ register_specialize(local_logit_sigmoid) ...@@ -3546,7 +3546,7 @@ register_specialize(local_logit_sigmoid)
# sigmoid(log(x / (1-x)) -> x # sigmoid(log(x / (1-x)) -> x
# i.e., sigmoid(logit(x)) -> x # i.e., sigmoid(logit(x)) -> x
local_sigmoid_logit = PatternSub( local_sigmoid_logit = PatternNodeRewriter(
(sigmoid, (log, (true_div, "x", (sub, 1, "x")))), (sigmoid, (log, (true_div, "x", (sub, 1, "x")))),
"x", "x",
allow_multiple_clients=True, allow_multiple_clients=True,
......
...@@ -89,7 +89,7 @@ For starters, let's define the following simplification: ...@@ -89,7 +89,7 @@ For starters, let's define the following simplification:
\frac{xy}{y} = x \frac{xy}{y} = x
We will implement it in three ways: using a global optimization, a We will implement it in three ways: using a global optimization, a
local optimization with a :class:`NavigatorOptimizer` and then using the :class:`PatternSub` local optimization with a :class:`NavigatorOptimizer` and then using the :class:`PatternNodeRewriter`
facility. facility.
Global optimization Global optimization
...@@ -270,8 +270,8 @@ FunctionGraph(add(z, mul(true_div(mul(y, x), y), true_div(z, x)))) ...@@ -270,8 +270,8 @@ FunctionGraph(add(z, mul(true_div(mul(y, x), y), true_div(z, x))))
>>> e >>> e
FunctionGraph(add(z, mul(x, true_div(z, x)))) FunctionGraph(add(z, mul(x, true_div(z, x))))
:class:`SubstitutionNodeRewriter`, :class:`RemovalNodeRewriter`, :class:`PatternSub` :class:`SubstitutionNodeRewriter`, :class:`RemovalNodeRewriter`, :class:`PatternNodeRewriter`
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Aesara defines some shortcuts to make :class:`NodeRewriter`\s: Aesara defines some shortcuts to make :class:`NodeRewriter`\s:
...@@ -288,15 +288,15 @@ Aesara defines some shortcuts to make :class:`NodeRewriter`\s: ...@@ -288,15 +288,15 @@ Aesara defines some shortcuts to make :class:`NodeRewriter`\s:
outputs as it has inputs. The first output becomes the first input, outputs as it has inputs. The first output becomes the first input,
the second output becomes the second input, and so on. the second output becomes the second input, and so on.
.. function:: PatternSub(pattern1, pattern2) .. function:: PatternNodeRewriter(pattern1, pattern2)
Replaces all occurrences of the first pattern by the second pattern. Replaces all occurrences of the first pattern by the second pattern.
See :class:`PatternSub`. See :class:`PatternNodeRewriter`.
.. code:: .. code::
from aesara.scalar import identity from aesara.scalar import identity
from aesara.graph.opt import SubstitutionNodeRewriter, RemovalNodeRewriter, PatternSub from aesara.graph.opt import SubstitutionNodeRewriter, RemovalNodeRewriter, PatternNodeRewriter
# Replacing `add` by `mul` (this is not recommended for primarily # Replacing `add` by `mul` (this is not recommended for primarily
# mathematical reasons): # mathematical reasons):
...@@ -308,25 +308,25 @@ Aesara defines some shortcuts to make :class:`NodeRewriter`\s: ...@@ -308,25 +308,25 @@ Aesara defines some shortcuts to make :class:`NodeRewriter`\s:
# The "simplify" operation we've been defining in the past few # The "simplify" operation we've been defining in the past few
# sections. Note that we need two patterns to account for the # sections. Note that we need two patterns to account for the
# permutations of the arguments to `mul`. # permutations of the arguments to `mul`.
local_simplify_1 = PatternSub((true_div, (mul, 'x', 'y'), 'y'), 'x') local_simplify_1 = PatternNodeRewriter((true_div, (mul, 'x', 'y'), 'y'), 'x')
local_simplify_2 = PatternSub((true_div, (mul, 'x', 'y'), 'x'), 'y') local_simplify_2 = PatternNodeRewriter((true_div, (mul, 'x', 'y'), 'x'), 'y')
.. note:: .. note::
:class:`SubstitutionNodeRewriter`, :class:`RemovalNodeRewriter` and :class:`PatternSub` produce local optimizers, which :class:`SubstitutionNodeRewriter`, :class:`RemovalNodeRewriter` and :class:`PatternNodeRewriter` produce local optimizers, which
means that everything we said previously about local optimizers means that everything we said previously about local optimizers
apply (e.g. they need to be wrapped in a :class:`NavigatorOptimizer`, etc.) apply (e.g. they need to be wrapped in a :class:`NavigatorOptimizer`, etc.)
When an optimization can be naturally expressed using :class:`SubstitutionNodeRewriter`, :class:`RemovalNodeRewriter` When an optimization can be naturally expressed using :class:`SubstitutionNodeRewriter`, :class:`RemovalNodeRewriter`
or :class:`PatternSub`, it is highly recommended to use them. or :class:`PatternNodeRewriter`, it is highly recommended to use them.
.. _unification: .. _unification:
Unification and reification Unification and reification
=========================== ===========================
The :class:`PatternSub` class uses `unification and reification The :class:`PatternNodeRewriter` class uses `unification and reification
<https://en.wikipedia.org/wiki/Unification_(computer_science)>`_ to implement a <https://en.wikipedia.org/wiki/Unification_(computer_science)>`_ to implement a
more succinct and reusable form of "pattern matching and replacement". more succinct and reusable form of "pattern matching and replacement".
In general, *use of the unification and reification tools is preferable when In general, *use of the unification and reification tools is preferable when
...@@ -345,7 +345,7 @@ In order to use :func:`unify` and :func:`reify` with Aesara graphs, we need an i ...@@ -345,7 +345,7 @@ In order to use :func:`unify` and :func:`reify` with Aesara graphs, we need an i
structure that will allow us to represent Aesara graphs that contain :class:`var`\s, because structure that will allow us to represent Aesara graphs that contain :class:`var`\s, because
Aesara :class:`Op`\s and :class:`Apply` nodes will not accept these foreign objects as inputs. Aesara :class:`Op`\s and :class:`Apply` nodes will not accept these foreign objects as inputs.
:class:`PatternSub` uses Python ``tuple``\s to effectively represent :class:`Apply` nodes and :class:`PatternNodeRewriter` uses Python ``tuple``\s to effectively represent :class:`Apply` nodes and
``str``\s to represent logic variables (i.e. :class:`var`\s in the :mod:`unification` library). ``str``\s to represent logic variables (i.e. :class:`var`\s in the :mod:`unification` library).
Behind the scenes, these ``tuple``\s are converted to a ``tuple`` subclass called :class:`ExpressionTuple`\s, Behind the scenes, these ``tuple``\s are converted to a ``tuple`` subclass called :class:`ExpressionTuple`\s,
which behave just like normal ``tuple``\s except for some special caching features that allow for easy which behave just like normal ``tuple``\s except for some special caching features that allow for easy
......
...@@ -13,7 +13,7 @@ from aesara.compile.io import In, Out ...@@ -13,7 +13,7 @@ from aesara.compile.io import In, Out
from aesara.compile.mode import Mode, get_default_mode from aesara.compile.mode import Mode, get_default_mode
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Constant from aesara.graph.basic import Constant
from aesara.graph.opt import OpKeyOptimizer, PatternSub from aesara.graph.opt import OpKeyOptimizer, PatternNodeRewriter
from aesara.graph.utils import MissingInputError from aesara.graph.utils import MissingInputError
from aesara.link.vm import VMLinker from aesara.link.vm import VMLinker
from aesara.tensor.math import dot from aesara.tensor.math import dot
...@@ -35,7 +35,7 @@ from aesara.utils import exc_message ...@@ -35,7 +35,7 @@ from aesara.utils import exc_message
def PatternOptimizer(p1, p2, ign=True): def PatternOptimizer(p1, p2, ign=True):
return OpKeyOptimizer(PatternSub(p1, p2), ignore_newtrees=ign) return OpKeyOptimizer(PatternNodeRewriter(p1, p2), ignore_newtrees=ign)
class TestFunction: class TestFunction:
......
...@@ -11,7 +11,7 @@ from aesara.graph.op import Op ...@@ -11,7 +11,7 @@ from aesara.graph.op import Op
from aesara.graph.opt import ( from aesara.graph.opt import (
NavigatorOptimizer, NavigatorOptimizer,
OpKeyOptimizer, OpKeyOptimizer,
PatternSub, PatternNodeRewriter,
SubstitutionNodeRewriter, SubstitutionNodeRewriter,
TopoOptimizer, TopoOptimizer,
) )
...@@ -21,7 +21,7 @@ from tests.unittest_tools import assertFailure_fast ...@@ -21,7 +21,7 @@ from tests.unittest_tools import assertFailure_fast
def PatternOptimizer(p1, p2, ign=True): def PatternOptimizer(p1, p2, ign=True):
return OpKeyOptimizer(PatternSub(p1, p2), ignore_newtrees=ign) return OpKeyOptimizer(PatternNodeRewriter(p1, p2), ignore_newtrees=ign)
def TopoSubstitutionNodeRewriter( def TopoSubstitutionNodeRewriter(
......
...@@ -10,7 +10,7 @@ from aesara.graph.opt import ( ...@@ -10,7 +10,7 @@ from aesara.graph.opt import (
MergeOptimizer, MergeOptimizer,
OpKeyOptimizer, OpKeyOptimizer,
OpToRewriterTracker, OpToRewriterTracker,
PatternSub, PatternNodeRewriter,
SequentialNodeRewriter, SequentialNodeRewriter,
SubstitutionNodeRewriter, SubstitutionNodeRewriter,
TopoOptimizer, TopoOptimizer,
...@@ -51,11 +51,11 @@ class AssertNoChanges(Feature): ...@@ -51,11 +51,11 @@ class AssertNoChanges(Feature):
def PatternOptimizer(p1, p2, ign=False): def PatternOptimizer(p1, p2, ign=False):
return OpKeyOptimizer(PatternSub(p1, p2), ignore_newtrees=ign) return OpKeyOptimizer(PatternNodeRewriter(p1, p2), ignore_newtrees=ign)
def TopoPatternOptimizer(p1, p2, ign=True): def TopoPatternOptimizer(p1, p2, ign=True):
return TopoOptimizer(PatternSub(p1, p2), ignore_newtrees=ign) return TopoOptimizer(PatternNodeRewriter(p1, p2), ignore_newtrees=ign)
class TestPatternOptimizer: class TestPatternOptimizer:
...@@ -448,9 +448,9 @@ class TestEquilibrium: ...@@ -448,9 +448,9 @@ class TestEquilibrium:
# print g # print g
opt = EquilibriumOptimizer( opt = EquilibriumOptimizer(
[ [
PatternSub((op1, "x", "y"), (op2, "x", "y")), PatternNodeRewriter((op1, "x", "y"), (op2, "x", "y")),
PatternSub((op4, "x", "y"), (op1, "x", "y")), PatternNodeRewriter((op4, "x", "y"), (op1, "x", "y")),
PatternSub((op3, (op2, "x", "y")), (op4, "x", "y")), PatternNodeRewriter((op3, (op2, "x", "y")), (op4, "x", "y")),
], ],
max_use_ratio=10, max_use_ratio=10,
) )
...@@ -465,11 +465,11 @@ class TestEquilibrium: ...@@ -465,11 +465,11 @@ class TestEquilibrium:
# print g # print g
opt = EquilibriumOptimizer( opt = EquilibriumOptimizer(
[ [
PatternSub((op1, (op2, "x", "y")), (op4, "x", "y")), PatternNodeRewriter((op1, (op2, "x", "y")), (op4, "x", "y")),
PatternSub((op3, "x", "y"), (op4, "x", "y")), PatternNodeRewriter((op3, "x", "y"), (op4, "x", "y")),
PatternSub((op4, "x", "y"), (op5, "x", "y")), PatternNodeRewriter((op4, "x", "y"), (op5, "x", "y")),
PatternSub((op5, "x", "y"), (op6, "x", "y")), PatternNodeRewriter((op5, "x", "y"), (op6, "x", "y")),
PatternSub((op6, "x", "y"), (op2, "x", "y")), PatternNodeRewriter((op6, "x", "y"), (op2, "x", "y")),
], ],
max_use_ratio=10, max_use_ratio=10,
) )
...@@ -490,9 +490,9 @@ class TestEquilibrium: ...@@ -490,9 +490,9 @@ class TestEquilibrium:
try: try:
opt = EquilibriumOptimizer( opt = EquilibriumOptimizer(
[ [
PatternSub((op1, "x", "y"), (op2, "x", "y")), PatternNodeRewriter((op1, "x", "y"), (op2, "x", "y")),
PatternSub((op4, "x", "y"), (op1, "x", "y")), PatternNodeRewriter((op4, "x", "y"), (op1, "x", "y")),
PatternSub((op3, (op2, "x", "y")), (op4, "x", "y")), PatternNodeRewriter((op3, (op2, "x", "y")), (op4, "x", "y")),
], ],
max_use_ratio=1.0 / len(g.apply_nodes), max_use_ratio=1.0 / len(g.apply_nodes),
) # each opt can only be applied once ) # each opt can only be applied once
...@@ -595,14 +595,14 @@ def test_pre_greedy_node_rewriter(): ...@@ -595,14 +595,14 @@ def test_pre_greedy_node_rewriter():
@pytest.mark.parametrize("tracks", [True, False]) @pytest.mark.parametrize("tracks", [True, False])
@pytest.mark.parametrize("out_pattern", [(op2, "x"), "x", 1.0]) @pytest.mark.parametrize("out_pattern", [(op2, "x"), "x", 1.0])
def test_patternsub_values_eq_approx(out_pattern, tracks): def test_patternsub_values_eq_approx(out_pattern, tracks):
# PatternSub would fail when `values_eq_approx` and `get_nodes` were specified # PatternNodeRewriter would fail when `values_eq_approx` and `get_nodes` were specified
x = MyVariable("x") x = MyVariable("x")
e = op1(x) e = op1(x)
fg = FunctionGraph([x], [e], clone=False) fg = FunctionGraph([x], [e], clone=False)
opt = EquilibriumOptimizer( opt = EquilibriumOptimizer(
[ [
PatternSub( PatternNodeRewriter(
(op1, "x"), (op1, "x"),
out_pattern, out_pattern,
tracks=[op1] if tracks else (), tracks=[op1] if tracks else (),
...@@ -628,14 +628,14 @@ def test_patternsub_values_eq_approx(out_pattern, tracks): ...@@ -628,14 +628,14 @@ def test_patternsub_values_eq_approx(out_pattern, tracks):
@pytest.mark.parametrize("out_pattern", [(op1, "x"), "x"]) @pytest.mark.parametrize("out_pattern", [(op1, "x"), "x"])
def test_patternsub_invalid_dtype(out_pattern): def test_patternsub_invalid_dtype(out_pattern):
# PatternSub would wrongly return output of different dtype as the original node # PatternNodeRewriter would wrongly return output of different dtype as the original node
x = MyVariable("x") x = MyVariable("x")
e = op_cast_type2(x) e = op_cast_type2(x)
fg = FunctionGraph([x], [e]) fg = FunctionGraph([x], [e])
opt = EquilibriumOptimizer( opt = EquilibriumOptimizer(
[ [
PatternSub( PatternNodeRewriter(
(op_cast_type2, "x"), (op_cast_type2, "x"),
out_pattern, out_pattern,
) )
...@@ -647,8 +647,8 @@ def test_patternsub_invalid_dtype(out_pattern): ...@@ -647,8 +647,8 @@ def test_patternsub_invalid_dtype(out_pattern):
def test_patternsub_different_output_lengths(): def test_patternsub_different_output_lengths():
# Test that PatternSub won't replace nodes with different numbers of outputs # Test that PatternNodeRewriter won't replace nodes with different numbers of outputs
ps = PatternSub( ps = PatternNodeRewriter(
(op1, "x"), (op1, "x"),
("x"), ("x"),
name="ps", name="ps",
......
...@@ -4312,7 +4312,7 @@ class TestSigmoidOpts: ...@@ -4312,7 +4312,7 @@ class TestSigmoidOpts:
# tests exp_over_1_plus_exp # tests exp_over_1_plus_exp
f = aesara.function([x], 1 - exp(x) / (1 + exp(x)), mode=m) f = aesara.function([x], 1 - exp(x) / (1 + exp(x)), mode=m)
# FIXME: PatternSub does not copy stack trace # FIXME: PatternNodeRewriter does not copy stack trace
# (see https://github.com/Theano/Theano/issues/4581) # (see https://github.com/Theano/Theano/issues/4581)
# assert check_stack_trace(f, ops_to_check=[neg, sigmoid]) # assert check_stack_trace(f, ops_to_check=[neg, sigmoid])
assert [node.op for node in f.maker.fgraph.toposort()] == [neg, sigmoid] assert [node.op for node in f.maker.fgraph.toposort()] == [neg, sigmoid]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论