提交 d5013456 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename NavigatorOptimizer to NodeProcessingGraphRewriter

上级 6302cef1
......@@ -14,7 +14,7 @@ from aesara.graph.opt import (
CheckStackTraceOptimization,
GraphRewriter,
MergeOptimizer,
NavigatorOptimizer,
NodeProcessingGraphRewriter,
)
from aesara.graph.optdb import (
EquilibriumDB,
......@@ -193,7 +193,7 @@ optdb.register(
local_useless = LocalGroupDB(apply_all_opts=True, profile=True)
optdb.register(
"useless",
TopoDB(local_useless, failure_callback=NavigatorOptimizer.warn_inplace),
TopoDB(local_useless, failure_callback=NodeProcessingGraphRewriter.warn_inplace),
"fast_run",
"fast_compile",
position=0.6,
......
......@@ -48,7 +48,7 @@ _logger = logging.getLogger("aesara.graph.opt")
FailureCallbackType = Callable[
[
Exception,
"NavigatorOptimizer",
"NodeProcessingGraphRewriter",
List[Tuple[Variable, None]],
"NodeRewriter",
Apply,
......@@ -1210,7 +1210,7 @@ class SequentialNodeRewriter(NodeRewriter):
Attributes
----------
reentrant : bool
Some global optimizers, like `NavigatorOptimizer`, use this value to
Some global optimizers, like `NodeProcessingGraphRewriter`, use this value to
determine if they should ignore new nodes.
retains_inputs : bool
States whether or not the inputs of a transformed node are transferred
......@@ -1724,13 +1724,17 @@ class Updater(Feature):
self.chin = None
class NavigatorOptimizer(GraphRewriter):
r"""An optimizer that applies a `NodeRewriter` with considerations for the new nodes it creates.
class NodeProcessingGraphRewriter(GraphRewriter):
r"""A rewriter that applies a `NodeRewriter` with considerations for the new nodes it creates.
The results of successful rewrites are considered for rewriting based on
the values of `NodeProcessingGraphRewriter.ignore_newtrees` and/or
`NodeRewriter.reentrant`.
This optimizer also allows the `NodeRewriter` to use a special ``"remove"`` value
in the ``dict``\s returned by :meth:`NodeRewriter`. `Variable`\s mapped to this
value are removed from the `FunctionGraph`.
This rewriter accepts ``dict`` values from `NodeRewriter.transform`.
Entries in these ``dict``\s can be `Variable`\s and their new values.
It also accepts a special ``"remove"`` key. A sequence of `Variable`\s
mapped to the key ``"remove"`` are removed from the `FunctionGraph`.
"""
......@@ -1759,7 +1763,9 @@ class NavigatorOptimizer(GraphRewriter):
"""
if isinstance(exc, InconsistencyError):
return
return NavigatorOptimizer.warn(exc, nav, repl_pairs, node_rewriter, node)
return NodeProcessingGraphRewriter.warn(
exc, nav, repl_pairs, node_rewriter, node
)
@staticmethod
def warn_ignore(exc, nav, repl_pairs, node_rewriter, node):
......@@ -1778,10 +1784,10 @@ class NavigatorOptimizer(GraphRewriter):
node_rewriter
A `NodeRewriter` to apply over a `FunctionGraph` (or ``None``).
ignore_newtrees
- ``True``: new subgraphs returned by an optimization are not a
candidate for optimization.
- ``False``: new subgraphs returned by an optimization is a
candidate for optimization.
- ``True``: new subgraphs returned by an `NodeRewriter` are not a
candidate for rewriting.
- ``False``: new subgraphs returned by an `NodeRewriter` is a
candidate for rewriting.
- ``'auto'``: let the `node_rewriter` set this parameter via its
:attr:`reentrant` attribute.
failure_callback
......@@ -1970,7 +1976,7 @@ class NavigatorOptimizer(GraphRewriter):
)
class TopoOptimizer(NavigatorOptimizer):
class TopoOptimizer(NodeProcessingGraphRewriter):
"""An optimizer that applies a single `NodeRewriter` to each node in topological order (or reverse)."""
def __init__(
......@@ -2116,7 +2122,7 @@ in2out = partial(topogroup_optimizer, "in_to_out")
out2in = partial(topogroup_optimizer, "out_to_in")
class OpKeyOptimizer(NavigatorOptimizer):
class OpKeyOptimizer(NodeProcessingGraphRewriter):
r"""An optimizer that applies a `NodeRewriter` to specific `Op`\s.
The `Op`\s are provided by a :meth:`NodeRewriter.op_key` method (either
......@@ -2200,7 +2206,7 @@ def merge_dict(d1, d2):
return d
class EquilibriumOptimizer(NavigatorOptimizer):
class EquilibriumOptimizer(NodeProcessingGraphRewriter):
"""An `Rewriter` that applies an optimization until a fixed-point/equilibrium is reached."""
def __init__(
......@@ -2222,11 +2228,11 @@ class EquilibriumOptimizer(NavigatorOptimizer):
The global optimizer will be run at the start of each iteration before
the node rewriter.
failure_callback
See :attr:`NavigatorOptimizer.failure_callback`.
See :attr:`NodeProcessingGraphRewriter.failure_callback`.
ignore_newtrees
See :attr:`NavigatorOptimizer.ignore_newtrees`.
See :attr:`NodeProcessingGraphRewriter.ignore_newtrees`.
tracks_on_change_inputs
See :attr:`NavigatorOptimizer.tracks_on_change_inputs`.
See :attr:`NodeProcessingGraphRewriter.tracks_on_change_inputs`.
max_use_ratio
Each rewriter can be applied at most ``(size_of_graph * max_use_ratio)``
times.
......@@ -3188,6 +3194,11 @@ DEPRECATED_NAMES = [
"`PatternSub` is deprecated: use `PatternNodeRewriter` instead.",
PatternNodeRewriter,
),
(
"NavigatorOptimizer",
"`NavigatorOptimizer` is deprecated: use `NodeProcessingGraphRewriter` instead.",
NodeProcessingGraphRewriter,
),
]
......
......@@ -346,7 +346,7 @@ class EquilibriumDB(OptimizationDatabase):
max_use_ratio=config.optdb__max_use_ratio,
ignore_newtrees=self.ignore_newtrees,
tracks_on_change_inputs=self.tracks_on_change_inputs,
failure_callback=aesara_opt.NavigatorOptimizer.warn_inplace,
failure_callback=aesara_opt.NodeProcessingGraphRewriter.warn_inplace,
final_optimizers=final_opts,
cleanup_optimizers=cleanup_opts,
)
......
......@@ -74,7 +74,7 @@ A local optimization is an object which defines the following methods:
This method takes a :class:`FunctionGraph` and an :class:`Apply` node and
returns either ``False`` to signify that no changes are to be done or a
list of :class:`Variable`\s which matches the length of the node's ``outputs``
list. When the :class:`NodeRewriter` is applied by a :class:`NavigatorOptimizer`, the outputs
list. When the :class:`NodeRewriter` is applied by a :class:`NodeProcessingGraphRewriter`, the outputs
of the node passed as argument to the :class:`NodeRewriter` will be replaced by
the list returned.
......@@ -89,7 +89,7 @@ For starters, let's define the following simplification:
\frac{xy}{y} = x
We will implement it in three ways: using a global optimization, a
local optimization with a :class:`NavigatorOptimizer` and then using the :class:`PatternNodeRewriter`
local optimization with a :class:`NodeProcessingGraphRewriter` and then using the :class:`PatternNodeRewriter`
facility.
Global optimization
......@@ -253,7 +253,7 @@ outputs are returned. This list must have the same length as
you can put ``None`` in the returned list to remove it.
In order to apply the local optimizer we can use it in conjunction
with a :class:`NavigatorOptimizer`. Basically, a :class:`NavigatorOptimizer` is
with a :class:`NodeProcessingGraphRewriter`. Basically, a :class:`NodeProcessingGraphRewriter` is
a global optimizer that loops through all nodes in the graph (or a well-defined
subset of them) and applies one or several local optimizers.
......@@ -315,7 +315,7 @@ Aesara defines some shortcuts to make :class:`NodeRewriter`\s:
:class:`SubstitutionNodeRewriter`, :class:`RemovalNodeRewriter` and :class:`PatternNodeRewriter` produce local optimizers, which
means that everything we said previously about local optimizers
apply (e.g. they need to be wrapped in a :class:`NavigatorOptimizer`, etc.)
apply (e.g. they need to be wrapped in a :class:`NodeProcessingGraphRewriter`, etc.)
When an optimization can be naturally expressed using :class:`SubstitutionNodeRewriter`, :class:`RemovalNodeRewriter`
......@@ -702,7 +702,7 @@ Registering a :class:`NodeRewriter`
:class:`NodeRewriter`\s may be registered in two ways:
* Wrap them in a :class:`NavigatorOptimizer` and insert them like a global optimizer
* Wrap them in a :class:`NodeProcessingGraphRewriter` and insert them like a global optimizer
(see previous section).
* Put them in an :class:`EquilibriumDB`.
......@@ -795,8 +795,8 @@ under the assumption there are no inplace operations.
.. _navigator:
:class:`NavigatorOptimizer`
---------------------------
:class:`NodeProcessingGraphRewriter`
------------------------------------
WRITEME
......
......@@ -9,7 +9,7 @@ from aesara.graph.features import ReplaceValidate
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.opt import (
NavigatorOptimizer,
NodeProcessingGraphRewriter,
OpKeyOptimizer,
PatternNodeRewriter,
SubstitutionNodeRewriter,
......@@ -25,7 +25,7 @@ def PatternOptimizer(p1, p2, ign=True):
def TopoSubstitutionNodeRewriter(
op1, op2, fail=NavigatorOptimizer.warn_ignore, ign=True
op1, op2, fail=NodeProcessingGraphRewriter.warn_ignore, ign=True
):
return TopoOptimizer(
SubstitutionNodeRewriter(op1, op2), ignore_newtrees=ign, failure_callback=fail
......
......@@ -150,7 +150,7 @@ class TestPatternOptimizer:
def test_ambiguous(self):
# this test should always work with TopoOptimizer and the
# ignore_newtrees flag set to False. Behavior with ignore_newtrees
# = True or with other NavigatorOptimizers may differ.
# = True or with other NodeProcessingGraphRewriters may differ.
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op1(op1(op1(op1(x)))))
g = FunctionGraph([x, y, z], [e])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论