提交 fbe8c876 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename Updater to DispatchingFeature and update related type hints and docstrings

上级 d5013456
...@@ -909,7 +909,9 @@ class FunctionGraph(MetaObject): ...@@ -909,7 +909,9 @@ class FunctionGraph(MetaObject):
for feature in self._features: for feature in self._features:
for attr in getattr(feature, "pickle_rm_attr", []): for attr in getattr(feature, "pickle_rm_attr", []):
del d[attr] del d[attr]
# The class Updater take fct as parameter and they are lambda function, so unpicklable.
# XXX: The `Feature` `DispatchingFeature` takes functions as parameter
# and they can be lambda functions, making them unpicklable.
# execute_callbacks_times have reference to optimizer, and they can't # execute_callbacks_times have reference to optimizer, and they can't
# be pickled as the decorators with parameters aren't pickable. # be pickled as the decorators with parameters aren't pickable.
......
...@@ -45,6 +45,12 @@ from aesara.utils import flatten ...@@ -45,6 +45,12 @@ from aesara.utils import flatten
_logger = logging.getLogger("aesara.graph.opt") _logger = logging.getLogger("aesara.graph.opt")
RemoveKeyType = Literal["remove"]
TransformOutputType = Union[
bool,
Sequence[Variable],
Dict[Union[Variable, Literal["remove"]], Union[Variable, Sequence[Variable]]],
]
FailureCallbackType = Callable[ FailureCallbackType = Callable[
[ [
Exception, Exception,
...@@ -158,11 +164,7 @@ class NodeRewriter(Rewriter): ...@@ -158,11 +164,7 @@ class NodeRewriter(Rewriter):
@abc.abstractmethod @abc.abstractmethod
def transform( def transform(
self, fgraph: FunctionGraph, node: Apply, *args, **kwargs self, fgraph: FunctionGraph, node: Apply, *args, **kwargs
) -> Union[ ) -> TransformOutputType:
bool,
Sequence[Variable],
Dict[Union[Variable, Literal["remove"]], Union[Variable, Sequence[Variable]]],
]:
r"""Rewrite the sub-graph given by `node`. r"""Rewrite the sub-graph given by `node`.
Subclasses should implement this function so that it returns one of the Subclasses should implement this function so that it returns one of the
...@@ -965,9 +967,13 @@ class MetaNodeRewriter(NodeRewriter): ...@@ -965,9 +967,13 @@ class MetaNodeRewriter(NodeRewriter):
def register(self, rewriter: NodeRewriter, tag_list: IterableType[str]): def register(self, rewriter: NodeRewriter, tag_list: IterableType[str]):
self.rewriters.append(rewriter) self.rewriters.append(rewriter)
for c in rewriter.tracks():
self.track_dict[c].append(rewriter) tracks = rewriter.tracks()
self._tracks.append(c) if tracks:
for c in tracks:
self.track_dict[c].append(rewriter)
self._tracks.append(c)
for tag in tag_list: for tag in tag_list:
self.tag_dict[tag].append(rewriter) self.tag_dict[tag].append(rewriter)
...@@ -1695,7 +1701,9 @@ class PatternNodeRewriter(NodeRewriter): ...@@ -1695,7 +1701,9 @@ class PatternNodeRewriter(NodeRewriter):
) )
class Updater(Feature): class DispatchingFeature(Feature):
"""A `Feature` consisting of user-defined functions implementing each `Feature` callback method."""
def __init__(self, importer, pruner, chin, name=None): def __init__(self, importer, pruner, chin, name=None):
self.importer = importer self.importer = importer
self.pruner = pruner self.pruner = pruner
...@@ -1703,7 +1711,7 @@ class Updater(Feature): ...@@ -1703,7 +1711,7 @@ class Updater(Feature):
self.name = name self.name = name
def __str__(self): def __str__(self):
return f"Updater{{{self.name}}}" return f"{type(self).__name__}{{{self.name}}}"
def on_import(self, fgraph, node, reason): def on_import(self, fgraph, node, reason):
if self.importer: if self.importer:
...@@ -1725,24 +1733,30 @@ class Updater(Feature): ...@@ -1725,24 +1733,30 @@ class Updater(Feature):
class NodeProcessingGraphRewriter(GraphRewriter): class NodeProcessingGraphRewriter(GraphRewriter):
r"""A rewriter that applies a `NodeRewriter` with considerations for the new nodes it creates. r"""A class providing a base implementation for applying `NodeRewriter.transform` results to a graph.
The results of successful rewrites are considered for rewriting based on This rewriter accepts the output of `NodeRewriter.transform`
the values of `NodeProcessingGraphRewriter.ignore_newtrees` and/or implementations and applies them to a `FunctionGraph`.
`NodeRewriter.reentrant`.
This rewriter accepts ``dict`` values from `NodeRewriter.transform`. It accepts a sequence of new output nodes or ``dict``s. Entries in
Entries in these ``dict``\s can be `Variable`\s and their new values. these ``dict``\s can be `Variable`\s and their new values. It also accepts
It also accepts a special ``"remove"`` key. A sequence of `Variable`\s a special ``"remove"`` key. A sequence of `Variable`\s mapped to the key
mapped to the key ``"remove"`` are removed from the `FunctionGraph`. ``"remove"`` are removed from the `FunctionGraph`.
It also adds some interface elements for simple reentrant/recursive
application of rewrites. The parameter `NodeRewriter.ignore_newtrees` is
intended to be used by subclasses, alongside the
`NodeRewriter.attach_updater` and `NodeRewriter.detach_updater` methods, to
determine whether or not sub-graphs created by rewrites are to have the
same rewrites applied to them.
""" """
@staticmethod @classmethod
def warn(exc, nav, repl_pairs, node_rewriter, node): def warn(cls, exc, nav, repl_pairs, node_rewriter, node):
"""A failure callback that prints a traceback.""" """A failure callback that prints a traceback."""
if config.on_opt_error != "ignore": if config.on_opt_error != "ignore":
_logger.error(f"Optimization failure due to: {node_rewriter}") _logger.error(f"Rewrite failure due to: {node_rewriter}")
_logger.error(f"node: {node}") _logger.error(f"node: {node}")
_logger.error("TRACEBACK:") _logger.error("TRACEBACK:")
_logger.error(traceback.format_exc()) _logger.error(traceback.format_exc())
...@@ -1753,8 +1767,8 @@ class NodeProcessingGraphRewriter(GraphRewriter): ...@@ -1753,8 +1767,8 @@ class NodeProcessingGraphRewriter(GraphRewriter):
# seriously wrong if such an exception is raised. # seriously wrong if such an exception is raised.
raise exc raise exc
@staticmethod @classmethod
def warn_inplace(exc, nav, repl_pairs, node_rewriter, node): def warn_inplace(cls, exc, nav, repl_pairs, node_rewriter, node):
r"""A failure callback that ignores `InconsistencyError`\s and prints a traceback. r"""A failure callback that ignores `InconsistencyError`\s and prints a traceback.
If the error occurred during replacement, `repl_pairs` is set; If the error occurred during replacement, `repl_pairs` is set;
...@@ -1763,12 +1777,10 @@ class NodeProcessingGraphRewriter(GraphRewriter): ...@@ -1763,12 +1777,10 @@ class NodeProcessingGraphRewriter(GraphRewriter):
""" """
if isinstance(exc, InconsistencyError): if isinstance(exc, InconsistencyError):
return return
return NodeProcessingGraphRewriter.warn( return cls.warn(exc, nav, repl_pairs, node_rewriter, node)
exc, nav, repl_pairs, node_rewriter, node
)
@staticmethod @classmethod
def warn_ignore(exc, nav, repl_pairs, node_rewriter, node): def warn_ignore(cls, exc, nav, repl_pairs, node_rewriter, node):
"""A failure callback that ignores all errors.""" """A failure callback that ignores all errors."""
def __init__( def __init__(
...@@ -1812,22 +1824,29 @@ class NodeProcessingGraphRewriter(GraphRewriter): ...@@ -1812,22 +1824,29 @@ class NodeProcessingGraphRewriter(GraphRewriter):
else: else:
self.ignore_newtrees = ignore_newtrees self.ignore_newtrees = ignore_newtrees
self.failure_callback = failure_callback self.failure_callback = failure_callback
super().__init__()
def attach_updater(self, fgraph, importer, pruner, chin=None, name=None): def attach_updater(
r"""Install `FunctionGraph` listeners to help the navigator deal with the ``ignore_trees``-related functionality. self,
fgraph: FunctionGraph,
importer: Optional[Callable],
pruner: Optional[Callable],
chin: Optional[Callable] = None,
name: Optional[str] = None,
) -> Optional[DispatchingFeature]:
r"""Install `FunctionGraph` listeners to help the navigator deal with the recursion-related functionality.
Parameters Parameters
---------- ----------
importer : importer
Function that will be called whenever optimizations add stuff Function to be called when a rewrite adds something to the graph.
to the graph. pruner
pruner : Function to be called when a rewrite removes something from the
Function to be called when optimizations remove stuff graph.
from the graph. chin
chin : Function to be called when a node's inputs change.
"on change input" called whenever a node's inputs change. name
name : Name of the `DispatchingFeature` to attach.
name of the ``Updater`` to attach.
Returns Returns
------- -------
...@@ -1841,27 +1860,29 @@ class NodeProcessingGraphRewriter(GraphRewriter): ...@@ -1841,27 +1860,29 @@ class NodeProcessingGraphRewriter(GraphRewriter):
if importer is None and pruner is None: if importer is None and pruner is None:
return None return None
u = Updater(importer, pruner, chin, name=name) u = DispatchingFeature(importer, pruner, chin, name=name)
fgraph.attach_feature(u) fgraph.attach_feature(u)
return u return u
def detach_updater(self, fgraph, u): def detach_updater(
"""Undo the work of ``attach_updater``. self, fgraph: FunctionGraph, updater: Optional[DispatchingFeature]
):
"""Undo the work of `attach_updater`.
Parameters Parameters
---------- ----------
fgraph fgraph
The `FunctionGraph`. The `FunctionGraph`.
u updater
A return-value of ``attach_updater``. The `DispatchingFeature` to remove.
Returns Returns
------- -------
None None
""" """
if u is not None: if updater is not None:
fgraph.remove_feature(u) fgraph.remove_feature(updater)
def process_node( def process_node(
self, self,
...@@ -1871,14 +1892,10 @@ class NodeProcessingGraphRewriter(GraphRewriter): ...@@ -1871,14 +1892,10 @@ class NodeProcessingGraphRewriter(GraphRewriter):
): ):
r"""Apply `node_rewriter` to `node`. r"""Apply `node_rewriter` to `node`.
The :meth:`node_rewriter.transform` method will return either ``False`` or a The :meth:`node_rewriter.transform` method will return either ``False``, a
list of `Variable`\s that are intended to replace :attr:`node.outputs`. list of `Variable`\s that are intended to replace :attr:`node.outputs`, or
a ``dict`` specifying replacements--or the key ``"remove"`` mapped to a
If the `fgraph` accepts the replacement, then the optimization is sequence of `Variable`\s to be removed.
successful, and this function returns ``True``.
If there are no replacement candidates or the `fgraph` rejects the
replacements, this function returns ``False``.
Parameters Parameters
---------- ----------
...@@ -1893,7 +1910,11 @@ class NodeProcessingGraphRewriter(GraphRewriter): ...@@ -1893,7 +1910,11 @@ class NodeProcessingGraphRewriter(GraphRewriter):
Returns Returns
------- -------
bool bool
``True`` iff the `node`'s outputs were replaced in the `fgraph`. If `fgraph` accepts the replacement, then the rewrite is
successful and this function returns ``True``. If there are no
replacement candidates, or the `fgraph` rejects the replacements,
this function returns ``False``.
""" """
node_rewriter = node_rewriter or self.node_rewriter node_rewriter = node_rewriter or self.node_rewriter
......
...@@ -1526,7 +1526,9 @@ class GemmOptimizer(GraphRewriter): ...@@ -1526,7 +1526,9 @@ class GemmOptimizer(GraphRewriter):
if new_node is not node: if new_node is not node:
nodelist.append(new_node) nodelist.append(new_node)
u = aesara.graph.opt.Updater(on_import, None, None, name="GemmOptimizer") u = aesara.graph.opt.DispatchingFeature(
on_import, None, None, name="GemmOptimizer"
)
fgraph.attach_feature(u) fgraph.attach_feature(u)
while did_something: while did_something:
nb_iter += 1 nb_iter += 1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论