提交 89655604 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename LocalMetaOptimizer to MetaNodeRewriter

上级 4c458590
...@@ -17,7 +17,9 @@ from collections import UserList, defaultdict, deque ...@@ -17,7 +17,9 @@ from collections import UserList, defaultdict, deque
from collections.abc import Iterable from collections.abc import Iterable
from functools import _compose_mro, partial, reduce # type: ignore from functools import _compose_mro, partial, reduce # type: ignore
from itertools import chain from itertools import chain
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union, cast from typing import Callable, Dict
from typing import Iterable as IterableType
from typing import List, Optional, Sequence, Tuple, Union, cast
from typing_extensions import Literal from typing_extensions import Literal
...@@ -57,7 +59,7 @@ FailureCallbackType = Callable[ ...@@ -57,7 +59,7 @@ FailureCallbackType = Callable[
class MetaNodeRewriterSkip(AssertionError): class MetaNodeRewriterSkip(AssertionError):
"""This is an `AssertionError`, but instead of having the """This is an `AssertionError`, but instead of having the
`LocalMetaOptimizer` print the error, it just skip that `MetaNodeRewriter` print the error, it just skip that
compilation. compilation.
""" """
...@@ -943,9 +945,9 @@ def pre_constant_merge(fgraph, variables): ...@@ -943,9 +945,9 @@ def pre_constant_merge(fgraph, variables):
return [recursive_merge(v) for v in variables] return [recursive_merge(v) for v in variables]
class LocalMetaOptimizer(NodeRewriter): class MetaNodeRewriter(NodeRewriter):
r""" r"""
Base class for meta-optimizers that try a set of `NodeRewriter`\s Base class for meta-rewriters that try a set of `NodeRewriter`\s
to replace a node and choose the one that executes the fastest. to replace a node and choose the one that executes the fastest.
If the error `MetaNodeRewriterSkip` is raised during If the error `MetaNodeRewriterSkip` is raised during
...@@ -959,15 +961,15 @@ class LocalMetaOptimizer(NodeRewriter): ...@@ -959,15 +961,15 @@ class LocalMetaOptimizer(NodeRewriter):
self.track_dict = defaultdict(lambda: []) self.track_dict = defaultdict(lambda: [])
self.tag_dict = defaultdict(lambda: []) self.tag_dict = defaultdict(lambda: [])
self._tracks = [] self._tracks = []
self.optimizers = [] self.rewriters = []
def register(self, optimizer, tag_list): def register(self, rewriter: NodeRewriter, tag_list: IterableType[str]):
self.optimizers.append(optimizer) self.rewriters.append(rewriter)
for c in optimizer.tracks(): for c in rewriter.tracks():
self.track_dict[c].append(optimizer) self.track_dict[c].append(rewriter)
self._tracks.append(c) self._tracks.append(c)
for tag in tag_list: for tag in tag_list:
self.tag_dict[tag].append(optimizer) self.tag_dict[tag].append(rewriter)
def tracks(self): def tracks(self):
return self._tracks return self._tracks
...@@ -1000,19 +1002,19 @@ class LocalMetaOptimizer(NodeRewriter): ...@@ -1000,19 +1002,19 @@ class LocalMetaOptimizer(NodeRewriter):
if missing: if missing:
if self.verbose > 0: if self.verbose > 0:
print( print(
f"{self.__class__.__name__} cannot meta-optimize {node}, " f"{self.__class__.__name__} cannot meta-rewrite {node}, "
f"{len(missing)} of {int(node.nin)} input shapes unknown" f"{len(missing)} of {int(node.nin)} input shapes unknown"
) )
return return
# now we can apply the different optimizations in turn, # now we can apply the different rewrites in turn,
# compile the resulting subgraphs and time their execution # compile the resulting subgraphs and time their execution
if self.verbose > 1: if self.verbose > 1:
print( print(
f"{self.__class__.__name__} meta-optimizing {node} ({len(self.get_opts(node))} choices):" f"{self.__class__.__name__} meta-rewriting {node} ({len(self.get_rewrites(node))} choices):"
) )
timings = [] timings = []
for opt in self.get_opts(node): for node_rewriter in self.get_rewrites(node):
outputs = opt.transform(fgraph, node, *args, **kwargs) outputs = node_rewriter.transform(fgraph, node, *args, **kwargs)
if outputs: if outputs:
try: try:
fn = aesara.function( fn = aesara.function(
...@@ -1024,15 +1026,15 @@ class LocalMetaOptimizer(NodeRewriter): ...@@ -1024,15 +1026,15 @@ class LocalMetaOptimizer(NodeRewriter):
continue continue
except Exception as e: except Exception as e:
if self.verbose > 0: if self.verbose > 0:
print(f"* {opt}: exception", e) print(f"* {node_rewriter}: exception", e)
continue continue
else: else:
if self.verbose > 1: if self.verbose > 1:
print(f"* {opt}: {timing:.5g} sec") print(f"* {node_rewriter}: {timing:.5g} sec")
timings.append((timing, outputs, opt)) timings.append((timing, outputs, node_rewriter))
else: else:
if self.verbose > 0: if self.verbose > 0:
print(f"* {opt}: not applicable") print(f"* {node_rewriter}: not applicable")
# finally, we choose the fastest one # finally, we choose the fastest one
if timings: if timings:
timings.sort() timings.sort()
...@@ -1049,8 +1051,8 @@ class LocalMetaOptimizer(NodeRewriter): ...@@ -1049,8 +1051,8 @@ class LocalMetaOptimizer(NodeRewriter):
""" """
raise NotImplementedError() raise NotImplementedError()
def get_opts(self, node): def get_rewrites(self, node):
"""Return the optimizations that apply to `node`. """Return the rewrites that apply to `node`.
This uses ``self.track_dict[type(node.op)]`` by default. This uses ``self.track_dict[type(node.op)]`` by default.
""" """
...@@ -3146,6 +3148,11 @@ DEPRECATED_NAMES = [ ...@@ -3146,6 +3148,11 @@ DEPRECATED_NAMES = [
"`inplace_optimizer` is deprecated: use `graph_rewriter` instead.", "`inplace_optimizer` is deprecated: use `graph_rewriter` instead.",
graph_rewriter, graph_rewriter,
), ),
(
"LocalMetaOptimizer",
"`LocalMetaOptimizer` is deprecated: use `MetaNodeRewriter` instead.",
MetaNodeRewriter,
),
] ]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论