提交 89655604 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename LocalMetaOptimizer to MetaNodeRewriter

上级 4c458590
......@@ -17,7 +17,9 @@ from collections import UserList, defaultdict, deque
from collections.abc import Iterable
from functools import _compose_mro, partial, reduce # type: ignore
from itertools import chain
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union, cast
from typing import Callable, Dict
from typing import Iterable as IterableType
from typing import List, Optional, Sequence, Tuple, Union, cast
from typing_extensions import Literal
......@@ -57,7 +59,7 @@ FailureCallbackType = Callable[
class MetaNodeRewriterSkip(AssertionError):
"""This is an `AssertionError`, but instead of having the
`LocalMetaOptimizer` print the error, it just skip that
`MetaNodeRewriter` print the error, it just skip that
compilation.
"""
......@@ -943,9 +945,9 @@ def pre_constant_merge(fgraph, variables):
return [recursive_merge(v) for v in variables]
class LocalMetaOptimizer(NodeRewriter):
class MetaNodeRewriter(NodeRewriter):
r"""
Base class for meta-optimizers that try a set of `NodeRewriter`\s
Base class for meta-rewriters that try a set of `NodeRewriter`\s
to replace a node and choose the one that executes the fastest.
If the error `MetaNodeRewriterSkip` is raised during
......@@ -959,15 +961,15 @@ class LocalMetaOptimizer(NodeRewriter):
self.track_dict = defaultdict(lambda: [])
self.tag_dict = defaultdict(lambda: [])
self._tracks = []
self.optimizers = []
self.rewriters = []
def register(self, optimizer, tag_list):
self.optimizers.append(optimizer)
for c in optimizer.tracks():
self.track_dict[c].append(optimizer)
def register(self, rewriter: NodeRewriter, tag_list: IterableType[str]):
self.rewriters.append(rewriter)
for c in rewriter.tracks():
self.track_dict[c].append(rewriter)
self._tracks.append(c)
for tag in tag_list:
self.tag_dict[tag].append(optimizer)
self.tag_dict[tag].append(rewriter)
def tracks(self):
return self._tracks
......@@ -1000,19 +1002,19 @@ class LocalMetaOptimizer(NodeRewriter):
if missing:
if self.verbose > 0:
print(
f"{self.__class__.__name__} cannot meta-optimize {node}, "
f"{self.__class__.__name__} cannot meta-rewrite {node}, "
f"{len(missing)} of {int(node.nin)} input shapes unknown"
)
return
# now we can apply the different optimizations in turn,
# now we can apply the different rewrites in turn,
# compile the resulting subgraphs and time their execution
if self.verbose > 1:
print(
f"{self.__class__.__name__} meta-optimizing {node} ({len(self.get_opts(node))} choices):"
f"{self.__class__.__name__} meta-rewriting {node} ({len(self.get_rewrites(node))} choices):"
)
timings = []
for opt in self.get_opts(node):
outputs = opt.transform(fgraph, node, *args, **kwargs)
for node_rewriter in self.get_rewrites(node):
outputs = node_rewriter.transform(fgraph, node, *args, **kwargs)
if outputs:
try:
fn = aesara.function(
......@@ -1024,15 +1026,15 @@ class LocalMetaOptimizer(NodeRewriter):
continue
except Exception as e:
if self.verbose > 0:
print(f"* {opt}: exception", e)
print(f"* {node_rewriter}: exception", e)
continue
else:
if self.verbose > 1:
print(f"* {opt}: {timing:.5g} sec")
timings.append((timing, outputs, opt))
print(f"* {node_rewriter}: {timing:.5g} sec")
timings.append((timing, outputs, node_rewriter))
else:
if self.verbose > 0:
print(f"* {opt}: not applicable")
print(f"* {node_rewriter}: not applicable")
# finally, we choose the fastest one
if timings:
timings.sort()
......@@ -1049,8 +1051,8 @@ class LocalMetaOptimizer(NodeRewriter):
"""
raise NotImplementedError()
def get_opts(self, node):
"""Return the optimizations that apply to `node`.
def get_rewrites(self, node):
"""Return the rewrites that apply to `node`.
This uses ``self.track_dict[type(node.op)]`` by default.
"""
......@@ -3146,6 +3148,11 @@ DEPRECATED_NAMES = [
"`inplace_optimizer` is deprecated: use `graph_rewriter` instead.",
graph_rewriter,
),
(
"LocalMetaOptimizer",
"`LocalMetaOptimizer` is deprecated: use `MetaNodeRewriter` instead.",
MetaNodeRewriter,
),
]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论