提交 2929f425 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename FromFunctionLocalOptimizer to FromFunctionNodeRewriter

上级 a04ac4bc
...@@ -1064,7 +1064,7 @@ class MetaNodeRewriter(NodeRewriter): ...@@ -1064,7 +1064,7 @@ class MetaNodeRewriter(NodeRewriter):
return time.time() - start return time.time() - start
class FromFunctionLocalOptimizer(NodeRewriter): class FromFunctionNodeRewriter(NodeRewriter):
"""A `NodeRewriter` constructed from a function.""" """A `NodeRewriter` constructed from a function."""
def __init__(self, fn, tracks=None, requirements=()): def __init__(self, fn, tracks=None, requirements=()):
...@@ -1095,7 +1095,7 @@ class FromFunctionLocalOptimizer(NodeRewriter): ...@@ -1095,7 +1095,7 @@ class FromFunctionLocalOptimizer(NodeRewriter):
return getattr(self, "__name__", repr(self)) return getattr(self, "__name__", repr(self))
def __repr__(self): def __repr__(self):
return f"FromFunctionLocalOptimizer({repr(self.fn)}, {repr(self._tracks)}, {repr(self.requirements)})" return f"FromFunctionNodeRewriter({repr(self.fn)}, {repr(self._tracks)}, {repr(self.requirements)})"
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print(f"{' ' * level}{self.transform} id={id(self)}", file=stream) print(f"{' ' * level}{self.transform} id={id(self)}", file=stream)
...@@ -1106,7 +1106,7 @@ def node_rewriter( ...@@ -1106,7 +1106,7 @@ def node_rewriter(
inplace: bool = False, inplace: bool = False,
requirements: Optional[Tuple[type, ...]] = (), requirements: Optional[Tuple[type, ...]] = (),
): ):
r"""A decorator used to construct `FromFunctionLocalOptimizer` instances. r"""A decorator used to construct `FromFunctionNodeRewriter` instances.
Parameters Parameters
---------- ----------
...@@ -1145,7 +1145,7 @@ def node_rewriter( ...@@ -1145,7 +1145,7 @@ def node_rewriter(
req = tuple(requirements) + ( req = tuple(requirements) + (
lambda fgraph: fgraph.attach_feature(dh_handler()), lambda fgraph: fgraph.attach_feature(dh_handler()),
) )
rval = FromFunctionLocalOptimizer(f, tracks, req) rval = FromFunctionNodeRewriter(f, tracks, req)
rval.__name__ = f.__name__ rval.__name__ = f.__name__
return rval return rval
...@@ -3158,6 +3158,11 @@ DEPRECATED_NAMES = [ ...@@ -3158,6 +3158,11 @@ DEPRECATED_NAMES = [
"`SeqOptimizer` is deprecated: use `SequentialGraphRewriter` instead.", "`SeqOptimizer` is deprecated: use `SequentialGraphRewriter` instead.",
SequentialGraphRewriter, SequentialGraphRewriter,
), ),
(
"FromFunctionLocalOptimizer",
"`FromFunctionLocalOptimizer` is deprecated: use `FromFunctionNodeRewriter` instead.",
FromFunctionNodeRewriter,
),
] ]
......
...@@ -710,7 +710,7 @@ def test_node_rewriter_str(): ...@@ -710,7 +710,7 @@ def test_node_rewriter_str():
assert str(local_opt_1) == "local_opt_1" assert str(local_opt_1) == "local_opt_1"
res = repr(local_opt_1) res = repr(local_opt_1)
assert res.startswith("FromFunctionLocalOptimizer(") assert res.startswith("FromFunctionNodeRewriter(")
assert "Op1" in res assert "Op1" in res
assert "local_opt_1" in res assert "local_opt_1" in res
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论