提交 d5013456 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename NavigatorOptimizer to NodeProcessingGraphRewriter

上级 6302cef1
...@@ -14,7 +14,7 @@ from aesara.graph.opt import ( ...@@ -14,7 +14,7 @@ from aesara.graph.opt import (
CheckStackTraceOptimization, CheckStackTraceOptimization,
GraphRewriter, GraphRewriter,
MergeOptimizer, MergeOptimizer,
NavigatorOptimizer, NodeProcessingGraphRewriter,
) )
from aesara.graph.optdb import ( from aesara.graph.optdb import (
EquilibriumDB, EquilibriumDB,
...@@ -193,7 +193,7 @@ optdb.register( ...@@ -193,7 +193,7 @@ optdb.register(
local_useless = LocalGroupDB(apply_all_opts=True, profile=True) local_useless = LocalGroupDB(apply_all_opts=True, profile=True)
optdb.register( optdb.register(
"useless", "useless",
TopoDB(local_useless, failure_callback=NavigatorOptimizer.warn_inplace), TopoDB(local_useless, failure_callback=NodeProcessingGraphRewriter.warn_inplace),
"fast_run", "fast_run",
"fast_compile", "fast_compile",
position=0.6, position=0.6,
......
...@@ -48,7 +48,7 @@ _logger = logging.getLogger("aesara.graph.opt") ...@@ -48,7 +48,7 @@ _logger = logging.getLogger("aesara.graph.opt")
FailureCallbackType = Callable[ FailureCallbackType = Callable[
[ [
Exception, Exception,
"NavigatorOptimizer", "NodeProcessingGraphRewriter",
List[Tuple[Variable, None]], List[Tuple[Variable, None]],
"NodeRewriter", "NodeRewriter",
Apply, Apply,
...@@ -1210,7 +1210,7 @@ class SequentialNodeRewriter(NodeRewriter): ...@@ -1210,7 +1210,7 @@ class SequentialNodeRewriter(NodeRewriter):
Attributes Attributes
---------- ----------
reentrant : bool reentrant : bool
Some global optimizers, like `NavigatorOptimizer`, use this value to Some global optimizers, like `NodeProcessingGraphRewriter`, use this value to
determine if they should ignore new nodes. determine if they should ignore new nodes.
retains_inputs : bool retains_inputs : bool
States whether or not the inputs of a transformed node are transferred States whether or not the inputs of a transformed node are transferred
...@@ -1724,13 +1724,17 @@ class Updater(Feature): ...@@ -1724,13 +1724,17 @@ class Updater(Feature):
self.chin = None self.chin = None
class NavigatorOptimizer(GraphRewriter): class NodeProcessingGraphRewriter(GraphRewriter):
r"""An optimizer that applies a `NodeRewriter` with considerations for the new nodes it creates. r"""A rewriter that applies a `NodeRewriter` with considerations for the new nodes it creates.
The results of successful rewrites are considered for rewriting based on
the values of `NodeProcessingGraphRewriter.ignore_newtrees` and/or
`NodeRewriter.reentrant`.
This optimizer also allows the `NodeRewriter` to use a special ``"remove"`` value This rewriter accepts ``dict`` values from `NodeRewriter.transform`.
in the ``dict``\s returned by :meth:`NodeRewriter`. `Variable`\s mapped to this Entries in these ``dict``\s can be `Variable`\s and their new values.
value are removed from the `FunctionGraph`. It also accepts a special ``"remove"`` key. A sequence of `Variable`\s
mapped to the key ``"remove"`` are removed from the `FunctionGraph`.
""" """
...@@ -1759,7 +1763,9 @@ class NavigatorOptimizer(GraphRewriter): ...@@ -1759,7 +1763,9 @@ class NavigatorOptimizer(GraphRewriter):
""" """
if isinstance(exc, InconsistencyError): if isinstance(exc, InconsistencyError):
return return
return NavigatorOptimizer.warn(exc, nav, repl_pairs, node_rewriter, node) return NodeProcessingGraphRewriter.warn(
exc, nav, repl_pairs, node_rewriter, node
)
@staticmethod @staticmethod
def warn_ignore(exc, nav, repl_pairs, node_rewriter, node): def warn_ignore(exc, nav, repl_pairs, node_rewriter, node):
...@@ -1778,10 +1784,10 @@ class NavigatorOptimizer(GraphRewriter): ...@@ -1778,10 +1784,10 @@ class NavigatorOptimizer(GraphRewriter):
node_rewriter node_rewriter
A `NodeRewriter` to apply over a `FunctionGraph` (or ``None``). A `NodeRewriter` to apply over a `FunctionGraph` (or ``None``).
ignore_newtrees ignore_newtrees
- ``True``: new subgraphs returned by an optimization are not a - ``True``: new subgraphs returned by an `NodeRewriter` are not a
candidate for optimization. candidate for rewriting.
- ``False``: new subgraphs returned by an optimization is a - ``False``: new subgraphs returned by an `NodeRewriter` is a
candidate for optimization. candidate for rewriting.
- ``'auto'``: let the `node_rewriter` set this parameter via its - ``'auto'``: let the `node_rewriter` set this parameter via its
:attr:`reentrant` attribute. :attr:`reentrant` attribute.
failure_callback failure_callback
...@@ -1970,7 +1976,7 @@ class NavigatorOptimizer(GraphRewriter): ...@@ -1970,7 +1976,7 @@ class NavigatorOptimizer(GraphRewriter):
) )
class TopoOptimizer(NavigatorOptimizer): class TopoOptimizer(NodeProcessingGraphRewriter):
"""An optimizer that applies a single `NodeRewriter` to each node in topological order (or reverse).""" """An optimizer that applies a single `NodeRewriter` to each node in topological order (or reverse)."""
def __init__( def __init__(
...@@ -2116,7 +2122,7 @@ in2out = partial(topogroup_optimizer, "in_to_out") ...@@ -2116,7 +2122,7 @@ in2out = partial(topogroup_optimizer, "in_to_out")
out2in = partial(topogroup_optimizer, "out_to_in") out2in = partial(topogroup_optimizer, "out_to_in")
class OpKeyOptimizer(NavigatorOptimizer): class OpKeyOptimizer(NodeProcessingGraphRewriter):
r"""An optimizer that applies a `NodeRewriter` to specific `Op`\s. r"""An optimizer that applies a `NodeRewriter` to specific `Op`\s.
The `Op`\s are provided by a :meth:`NodeRewriter.op_key` method (either The `Op`\s are provided by a :meth:`NodeRewriter.op_key` method (either
...@@ -2200,7 +2206,7 @@ def merge_dict(d1, d2): ...@@ -2200,7 +2206,7 @@ def merge_dict(d1, d2):
return d return d
class EquilibriumOptimizer(NavigatorOptimizer): class EquilibriumOptimizer(NodeProcessingGraphRewriter):
"""An `Rewriter` that applies an optimization until a fixed-point/equilibrium is reached.""" """An `Rewriter` that applies an optimization until a fixed-point/equilibrium is reached."""
def __init__( def __init__(
...@@ -2222,11 +2228,11 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2222,11 +2228,11 @@ class EquilibriumOptimizer(NavigatorOptimizer):
The global optimizer will be run at the start of each iteration before The global optimizer will be run at the start of each iteration before
the node rewriter. the node rewriter.
failure_callback failure_callback
See :attr:`NavigatorOptimizer.failure_callback`. See :attr:`NodeProcessingGraphRewriter.failure_callback`.
ignore_newtrees ignore_newtrees
See :attr:`NavigatorOptimizer.ignore_newtrees`. See :attr:`NodeProcessingGraphRewriter.ignore_newtrees`.
tracks_on_change_inputs tracks_on_change_inputs
See :attr:`NavigatorOptimizer.tracks_on_change_inputs`. See :attr:`NodeProcessingGraphRewriter.tracks_on_change_inputs`.
max_use_ratio max_use_ratio
Each rewriter can be applied at most ``(size_of_graph * max_use_ratio)`` Each rewriter can be applied at most ``(size_of_graph * max_use_ratio)``
times. times.
...@@ -3188,6 +3194,11 @@ DEPRECATED_NAMES = [ ...@@ -3188,6 +3194,11 @@ DEPRECATED_NAMES = [
"`PatternSub` is deprecated: use `PatternNodeRewriter` instead.", "`PatternSub` is deprecated: use `PatternNodeRewriter` instead.",
PatternNodeRewriter, PatternNodeRewriter,
), ),
(
"NavigatorOptimizer",
"`NavigatorOptimizer` is deprecated: use `NodeProcessingGraphRewriter` instead.",
NodeProcessingGraphRewriter,
),
] ]
......
...@@ -346,7 +346,7 @@ class EquilibriumDB(OptimizationDatabase): ...@@ -346,7 +346,7 @@ class EquilibriumDB(OptimizationDatabase):
max_use_ratio=config.optdb__max_use_ratio, max_use_ratio=config.optdb__max_use_ratio,
ignore_newtrees=self.ignore_newtrees, ignore_newtrees=self.ignore_newtrees,
tracks_on_change_inputs=self.tracks_on_change_inputs, tracks_on_change_inputs=self.tracks_on_change_inputs,
failure_callback=aesara_opt.NavigatorOptimizer.warn_inplace, failure_callback=aesara_opt.NodeProcessingGraphRewriter.warn_inplace,
final_optimizers=final_opts, final_optimizers=final_opts,
cleanup_optimizers=cleanup_opts, cleanup_optimizers=cleanup_opts,
) )
......
...@@ -74,7 +74,7 @@ A local optimization is an object which defines the following methods: ...@@ -74,7 +74,7 @@ A local optimization is an object which defines the following methods:
This method takes a :class:`FunctionGraph` and an :class:`Apply` node and This method takes a :class:`FunctionGraph` and an :class:`Apply` node and
returns either ``False`` to signify that no changes are to be done or a returns either ``False`` to signify that no changes are to be done or a
list of :class:`Variable`\s which matches the length of the node's ``outputs`` list of :class:`Variable`\s which matches the length of the node's ``outputs``
list. When the :class:`NodeRewriter` is applied by a :class:`NavigatorOptimizer`, the outputs list. When the :class:`NodeRewriter` is applied by a :class:`NodeProcessingGraphRewriter`, the outputs
of the node passed as argument to the :class:`NodeRewriter` will be replaced by of the node passed as argument to the :class:`NodeRewriter` will be replaced by
the list returned. the list returned.
...@@ -89,7 +89,7 @@ For starters, let's define the following simplification: ...@@ -89,7 +89,7 @@ For starters, let's define the following simplification:
\frac{xy}{y} = x \frac{xy}{y} = x
We will implement it in three ways: using a global optimization, a We will implement it in three ways: using a global optimization, a
local optimization with a :class:`NavigatorOptimizer` and then using the :class:`PatternNodeRewriter` local optimization with a :class:`NodeProcessingGraphRewriter` and then using the :class:`PatternNodeRewriter`
facility. facility.
Global optimization Global optimization
...@@ -253,7 +253,7 @@ outputs are returned. This list must have the same length as ...@@ -253,7 +253,7 @@ outputs are returned. This list must have the same length as
you can put ``None`` in the returned list to remove it. you can put ``None`` in the returned list to remove it.
In order to apply the local optimizer we can use it in conjunction In order to apply the local optimizer we can use it in conjunction
with a :class:`NavigatorOptimizer`. Basically, a :class:`NavigatorOptimizer` is with a :class:`NodeProcessingGraphRewriter`. Basically, a :class:`NodeProcessingGraphRewriter` is
a global optimizer that loops through all nodes in the graph (or a well-defined a global optimizer that loops through all nodes in the graph (or a well-defined
subset of them) and applies one or several local optimizers. subset of them) and applies one or several local optimizers.
...@@ -315,7 +315,7 @@ Aesara defines some shortcuts to make :class:`NodeRewriter`\s: ...@@ -315,7 +315,7 @@ Aesara defines some shortcuts to make :class:`NodeRewriter`\s:
:class:`SubstitutionNodeRewriter`, :class:`RemovalNodeRewriter` and :class:`PatternNodeRewriter` produce local optimizers, which :class:`SubstitutionNodeRewriter`, :class:`RemovalNodeRewriter` and :class:`PatternNodeRewriter` produce local optimizers, which
means that everything we said previously about local optimizers means that everything we said previously about local optimizers
apply (e.g. they need to be wrapped in a :class:`NavigatorOptimizer`, etc.) apply (e.g. they need to be wrapped in a :class:`NodeProcessingGraphRewriter`, etc.)
When an optimization can be naturally expressed using :class:`SubstitutionNodeRewriter`, :class:`RemovalNodeRewriter` When an optimization can be naturally expressed using :class:`SubstitutionNodeRewriter`, :class:`RemovalNodeRewriter`
...@@ -702,7 +702,7 @@ Registering a :class:`NodeRewriter` ...@@ -702,7 +702,7 @@ Registering a :class:`NodeRewriter`
:class:`NodeRewriter`\s may be registered in two ways: :class:`NodeRewriter`\s may be registered in two ways:
* Wrap them in a :class:`NavigatorOptimizer` and insert them like a global optimizer * Wrap them in a :class:`NodeProcessingGraphRewriter` and insert them like a global optimizer
(see previous section). (see previous section).
* Put them in an :class:`EquilibriumDB`. * Put them in an :class:`EquilibriumDB`.
...@@ -795,8 +795,8 @@ under the assumption there are no inplace operations. ...@@ -795,8 +795,8 @@ under the assumption there are no inplace operations.
.. _navigator: .. _navigator:
:class:`NavigatorOptimizer` :class:`NodeProcessingGraphRewriter`
--------------------------- ------------------------------------
WRITEME WRITEME
......
...@@ -9,7 +9,7 @@ from aesara.graph.features import ReplaceValidate ...@@ -9,7 +9,7 @@ from aesara.graph.features import ReplaceValidate
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.opt import ( from aesara.graph.opt import (
NavigatorOptimizer, NodeProcessingGraphRewriter,
OpKeyOptimizer, OpKeyOptimizer,
PatternNodeRewriter, PatternNodeRewriter,
SubstitutionNodeRewriter, SubstitutionNodeRewriter,
...@@ -25,7 +25,7 @@ def PatternOptimizer(p1, p2, ign=True): ...@@ -25,7 +25,7 @@ def PatternOptimizer(p1, p2, ign=True):
def TopoSubstitutionNodeRewriter( def TopoSubstitutionNodeRewriter(
op1, op2, fail=NavigatorOptimizer.warn_ignore, ign=True op1, op2, fail=NodeProcessingGraphRewriter.warn_ignore, ign=True
): ):
return TopoOptimizer( return TopoOptimizer(
SubstitutionNodeRewriter(op1, op2), ignore_newtrees=ign, failure_callback=fail SubstitutionNodeRewriter(op1, op2), ignore_newtrees=ign, failure_callback=fail
......
...@@ -150,7 +150,7 @@ class TestPatternOptimizer: ...@@ -150,7 +150,7 @@ class TestPatternOptimizer:
def test_ambiguous(self): def test_ambiguous(self):
# this test should always work with TopoOptimizer and the # this test should always work with TopoOptimizer and the
# ignore_newtrees flag set to False. Behavior with ignore_newtrees # ignore_newtrees flag set to False. Behavior with ignore_newtrees
# = True or with other NavigatorOptimizers may differ. # = True or with other NodeProcessingGraphRewriters may differ.
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op1(op1(op1(op1(x))))) e = op1(op1(op1(op1(op1(x)))))
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论