提交 fbe8c876 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename Updater to DispatchingFeature and update related type hints and docstrings

上级 d5013456
......@@ -909,7 +909,9 @@ class FunctionGraph(MetaObject):
for feature in self._features:
for attr in getattr(feature, "pickle_rm_attr", []):
del d[attr]
# The class Updater take fct as parameter and they are lambda function, so unpicklable.
# XXX: The `Feature` `DispatchingFeature` takes functions as parameter
# and they can be lambda functions, making them unpicklable.
# execute_callbacks_times have reference to optimizer, and they can't
# be pickled as the decorators with parameters aren't pickable.
......
......@@ -45,6 +45,12 @@ from aesara.utils import flatten
_logger = logging.getLogger("aesara.graph.opt")
RemoveKeyType = Literal["remove"]
TransformOutputType = Union[
bool,
Sequence[Variable],
Dict[Union[Variable, Literal["remove"]], Union[Variable, Sequence[Variable]]],
]
FailureCallbackType = Callable[
[
Exception,
......@@ -158,11 +164,7 @@ class NodeRewriter(Rewriter):
@abc.abstractmethod
def transform(
self, fgraph: FunctionGraph, node: Apply, *args, **kwargs
) -> Union[
bool,
Sequence[Variable],
Dict[Union[Variable, Literal["remove"]], Union[Variable, Sequence[Variable]]],
]:
) -> TransformOutputType:
r"""Rewrite the sub-graph given by `node`.
Subclasses should implement this function so that it returns one of the
......@@ -965,9 +967,13 @@ class MetaNodeRewriter(NodeRewriter):
def register(self, rewriter: NodeRewriter, tag_list: IterableType[str]):
self.rewriters.append(rewriter)
for c in rewriter.tracks():
tracks = rewriter.tracks()
if tracks:
for c in tracks:
self.track_dict[c].append(rewriter)
self._tracks.append(c)
for tag in tag_list:
self.tag_dict[tag].append(rewriter)
......@@ -1695,7 +1701,9 @@ class PatternNodeRewriter(NodeRewriter):
)
class Updater(Feature):
class DispatchingFeature(Feature):
"""A `Feature` consisting of user-defined functions implementing each `Feature` callback method."""
def __init__(self, importer, pruner, chin, name=None):
self.importer = importer
self.pruner = pruner
......@@ -1703,7 +1711,7 @@ class Updater(Feature):
self.name = name
def __str__(self):
return f"Updater{{{self.name}}}"
return f"{type(self).__name__}{{{self.name}}}"
def on_import(self, fgraph, node, reason):
if self.importer:
......@@ -1725,24 +1733,30 @@ class Updater(Feature):
class NodeProcessingGraphRewriter(GraphRewriter):
r"""A rewriter that applies a `NodeRewriter` with considerations for the new nodes it creates.
r"""A class providing a base implementation for applying `NodeRewriter.transform` results to a graph.
The results of successful rewrites are considered for rewriting based on
the values of `NodeProcessingGraphRewriter.ignore_newtrees` and/or
`NodeRewriter.reentrant`.
This rewriter accepts the output of `NodeRewriter.transform`
implementations and applies them to a `FunctionGraph`.
This rewriter accepts ``dict`` values from `NodeRewriter.transform`.
Entries in these ``dict``\s can be `Variable`\s and their new values.
It also accepts a special ``"remove"`` key. A sequence of `Variable`\s
mapped to the key ``"remove"`` are removed from the `FunctionGraph`.
It accepts a sequence of new output nodes or ``dict``s. Entries in
these ``dict``\s can be `Variable`\s and their new values. It also accepts
a special ``"remove"`` key. A sequence of `Variable`\s mapped to the key
``"remove"`` are removed from the `FunctionGraph`.
It also adds some interface elements for simple reentrant/recursive
application of rewrites. The parameter `NodeRewriter.ignore_newtrees` is
intended to be used by subclasses, alongside the
`NodeRewriter.attach_updater` and `NodeRewriter.detach_updater` methods, to
determine whether or not sub-graphs created by rewrites are to have the
same rewrites applied to them.
"""
@staticmethod
def warn(exc, nav, repl_pairs, node_rewriter, node):
@classmethod
def warn(cls, exc, nav, repl_pairs, node_rewriter, node):
"""A failure callback that prints a traceback."""
if config.on_opt_error != "ignore":
_logger.error(f"Optimization failure due to: {node_rewriter}")
_logger.error(f"Rewrite failure due to: {node_rewriter}")
_logger.error(f"node: {node}")
_logger.error("TRACEBACK:")
_logger.error(traceback.format_exc())
......@@ -1753,8 +1767,8 @@ class NodeProcessingGraphRewriter(GraphRewriter):
# seriously wrong if such an exception is raised.
raise exc
@staticmethod
def warn_inplace(exc, nav, repl_pairs, node_rewriter, node):
@classmethod
def warn_inplace(cls, exc, nav, repl_pairs, node_rewriter, node):
r"""A failure callback that ignores `InconsistencyError`\s and prints a traceback.
If the error occurred during replacement, `repl_pairs` is set;
......@@ -1763,12 +1777,10 @@ class NodeProcessingGraphRewriter(GraphRewriter):
"""
if isinstance(exc, InconsistencyError):
return
return NodeProcessingGraphRewriter.warn(
exc, nav, repl_pairs, node_rewriter, node
)
return cls.warn(exc, nav, repl_pairs, node_rewriter, node)
@staticmethod
def warn_ignore(exc, nav, repl_pairs, node_rewriter, node):
@classmethod
def warn_ignore(cls, exc, nav, repl_pairs, node_rewriter, node):
"""A failure callback that ignores all errors."""
def __init__(
......@@ -1812,22 +1824,29 @@ class NodeProcessingGraphRewriter(GraphRewriter):
else:
self.ignore_newtrees = ignore_newtrees
self.failure_callback = failure_callback
super().__init__()
def attach_updater(self, fgraph, importer, pruner, chin=None, name=None):
r"""Install `FunctionGraph` listeners to help the navigator deal with the ``ignore_trees``-related functionality.
def attach_updater(
self,
fgraph: FunctionGraph,
importer: Optional[Callable],
pruner: Optional[Callable],
chin: Optional[Callable] = None,
name: Optional[str] = None,
) -> Optional[DispatchingFeature]:
r"""Install `FunctionGraph` listeners to help the navigator deal with the recursion-related functionality.
Parameters
----------
importer :
Function that will be called whenever optimizations add stuff
to the graph.
pruner :
Function to be called when optimizations remove stuff
from the graph.
chin :
"on change input" called whenever a node's inputs change.
name :
name of the ``Updater`` to attach.
importer
Function to be called when a rewrite adds something to the graph.
pruner
Function to be called when a rewrite removes something from the
graph.
chin
Function to be called when a node's inputs change.
name
Name of the `DispatchingFeature` to attach.
Returns
-------
......@@ -1841,27 +1860,29 @@ class NodeProcessingGraphRewriter(GraphRewriter):
if importer is None and pruner is None:
return None
u = Updater(importer, pruner, chin, name=name)
u = DispatchingFeature(importer, pruner, chin, name=name)
fgraph.attach_feature(u)
return u
def detach_updater(self, fgraph, u):
"""Undo the work of ``attach_updater``.
def detach_updater(
self, fgraph: FunctionGraph, updater: Optional[DispatchingFeature]
):
"""Undo the work of `attach_updater`.
Parameters
----------
fgraph
The `FunctionGraph`.
u
A return-value of ``attach_updater``.
updater
The `DispatchingFeature` to remove.
Returns
-------
None
"""
if u is not None:
fgraph.remove_feature(u)
if updater is not None:
fgraph.remove_feature(updater)
def process_node(
self,
......@@ -1871,14 +1892,10 @@ class NodeProcessingGraphRewriter(GraphRewriter):
):
r"""Apply `node_rewriter` to `node`.
The :meth:`node_rewriter.transform` method will return either ``False`` or a
list of `Variable`\s that are intended to replace :attr:`node.outputs`.
If the `fgraph` accepts the replacement, then the optimization is
successful, and this function returns ``True``.
If there are no replacement candidates or the `fgraph` rejects the
replacements, this function returns ``False``.
The :meth:`node_rewriter.transform` method will return either ``False``, a
list of `Variable`\s that are intended to replace :attr:`node.outputs`, or
a ``dict`` specifying replacements--or the key ``"remove"`` mapped to a
sequence of `Variable`\s to be removed.
Parameters
----------
......@@ -1893,7 +1910,11 @@ class NodeProcessingGraphRewriter(GraphRewriter):
Returns
-------
bool
``True`` iff the `node`'s outputs were replaced in the `fgraph`.
If `fgraph` accepts the replacement, then the rewrite is
successful and this function returns ``True``. If there are no
replacement candidates, or the `fgraph` rejects the replacements,
this function returns ``False``.
"""
node_rewriter = node_rewriter or self.node_rewriter
......
......@@ -1526,7 +1526,9 @@ class GemmOptimizer(GraphRewriter):
if new_node is not node:
nodelist.append(new_node)
u = aesara.graph.opt.Updater(on_import, None, None, name="GemmOptimizer")
u = aesara.graph.opt.DispatchingFeature(
on_import, None, None, name="GemmOptimizer"
)
fgraph.attach_feature(u)
while did_something:
nb_iter += 1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论