提交 4c458590 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename FromFunctionOptimizer to FromFunctionGraphRewriter

上级 0ce6eceb
...@@ -13,7 +13,7 @@ from aesara.graph.basic import ( ...@@ -13,7 +13,7 @@ from aesara.graph.basic import (
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.type import Type from aesara.graph.type import Type
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import node_rewriter, optimizer from aesara.graph.opt import node_rewriter, graph_rewriter
from aesara.graph.opt_utils import optimize_graph from aesara.graph.opt_utils import optimize_graph
from aesara.graph.optdb import OptimizationQuery from aesara.graph.optdb import OptimizationQuery
......
...@@ -189,7 +189,7 @@ class NodeRewriter(Rewriter): ...@@ -189,7 +189,7 @@ class NodeRewriter(Rewriter):
print(f"{' ' * level}{self.__class__.__name__} id={id(self)}", file=stream) print(f"{' ' * level}{self.__class__.__name__} id={id(self)}", file=stream)
class FromFunctionOptimizer(GraphRewriter): class FromFunctionGraphRewriter(GraphRewriter):
"""A `GraphRewriter` constructed from a given function.""" """A `GraphRewriter` constructed from a given function."""
def __init__(self, fn, requirements=()): def __init__(self, fn, requirements=()):
...@@ -213,18 +213,18 @@ class FromFunctionOptimizer(GraphRewriter): ...@@ -213,18 +213,18 @@ class FromFunctionOptimizer(GraphRewriter):
return self.__name__ return self.__name__
def optimizer(f): def graph_rewriter(f):
"""Decorator for `FromFunctionOptimizer`.""" """Decorator for `FromFunctionGraphRewriter`."""
rval = FromFunctionOptimizer(f) rval = FromFunctionGraphRewriter(f)
rval.__name__ = f.__name__ rval.__name__ = f.__name__
return rval return rval
def inplace_optimizer(f): def inplace_graph_rewriter(f):
"""Decorator for `FromFunctionOptimizer` that also adds the `DestroyHandler` features.""" """Decorator for `FromFunctionGraphRewriter` that also adds the `DestroyHandler` features."""
dh_handler = dh.DestroyHandler dh_handler = dh.DestroyHandler
requirements = (lambda fgraph: fgraph.attach_feature(dh_handler()),) requirements = (lambda fgraph: fgraph.attach_feature(dh_handler()),)
rval = FromFunctionOptimizer(f, requirements) rval = FromFunctionGraphRewriter(f, requirements)
rval.__name__ = f.__name__ rval.__name__ = f.__name__
return rval return rval
...@@ -3131,6 +3131,21 @@ DEPRECATED_NAMES = [ ...@@ -3131,6 +3131,21 @@ DEPRECATED_NAMES = [
"`pre_greedy_local_optimizer` is deprecated: use `pre_greedy_node_rewriter` instead.", "`pre_greedy_local_optimizer` is deprecated: use `pre_greedy_node_rewriter` instead.",
pre_greedy_node_rewriter, pre_greedy_node_rewriter,
), ),
(
"FromFunctionOptimizer",
"`FromFunctionOptimizer` is deprecated: use `FromFunctionGraphRewriter` instead.",
FromFunctionGraphRewriter,
),
(
"optimizer",
"`optimizer` is deprecated: use `graph_rewriter` instead.",
graph_rewriter,
),
(
"inplace_optimizer",
"`inplace_optimizer` is deprecated: use `graph_rewriter` instead.",
graph_rewriter,
),
] ]
......
...@@ -18,7 +18,7 @@ from aesara.compile import optdb ...@@ -18,7 +18,7 @@ from aesara.compile import optdb
from aesara.gradient import DisconnectedType, grad_not_implemented from aesara.gradient import DisconnectedType, grad_not_implemented
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.opt import copy_stack_trace, node_rewriter, optimizer from aesara.graph.opt import copy_stack_trace, graph_rewriter, node_rewriter
from aesara.link.c.op import COp from aesara.link.c.op import COp
from aesara.raise_op import Assert from aesara.raise_op import Assert
from aesara.scalar import UnaryScalarOp from aesara.scalar import UnaryScalarOp
...@@ -1847,7 +1847,7 @@ crossentropy_categorical_1hot = CrossentropyCategorical1Hot() ...@@ -1847,7 +1847,7 @@ crossentropy_categorical_1hot = CrossentropyCategorical1Hot()
@register_stabilize("fast_compile") @register_stabilize("fast_compile")
@register_specialize("fast_compile") @register_specialize("fast_compile")
@optimizer @graph_rewriter
def crossentropy_to_crossentropy_with_softmax_with_bias(fgraph): def crossentropy_to_crossentropy_with_softmax_with_bias(fgraph):
def search_make_one_sub(): def search_make_one_sub():
for node in fgraph.toposort(): for node in fgraph.toposort():
...@@ -1874,7 +1874,7 @@ def crossentropy_to_crossentropy_with_softmax_with_bias(fgraph): ...@@ -1874,7 +1874,7 @@ def crossentropy_to_crossentropy_with_softmax_with_bias(fgraph):
return return
@optimizer @graph_rewriter
def crossentropy_to_crossentropy_with_softmax(fgraph): def crossentropy_to_crossentropy_with_softmax(fgraph):
""" """
This is a stabilization rewrite that is more general than This is a stabilization rewrite that is more general than
......
...@@ -944,7 +944,7 @@ This will output something like this: ...@@ -944,7 +944,7 @@ This will output something like this:
validate_time 6.43730163574e-05 validate_time 6.43730163574e-05
callback_time 0.000783205032349 callback_time 0.000783205032349
time_toposort 0.0035240650177 time_toposort 0.0035240650177
0.090089s - ('inplace_elemwise_optimizer', 'FromFunctionOptimizer', 30) - 0.019s 0.090089s - ('inplace_elemwise_optimizer', 'FromFunctionGraphRewriter', 30) - 0.019s
0.048993s - ('BlasOpt', 'SeqOptimizer', 8) - 0.000s 0.048993s - ('BlasOpt', 'SeqOptimizer', 8) - 0.000s
SeqOptimizer BlasOpt time 0.049s for 81/80 nodes before/after optimization SeqOptimizer BlasOpt time 0.049s for 81/80 nodes before/after optimization
0.000s for fgraph.validate() 0.000s for fgraph.validate()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论