提交 214ef4cf authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename GlobalOptimizer to GraphRewriter

上级 7ce7b0c2
......@@ -12,7 +12,7 @@ from aesara.configdefaults import config
from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.opt import (
CheckStackTraceOptimization,
GlobalOptimizer,
GraphRewriter,
MergeOptimizer,
NavigatorOptimizer,
)
......@@ -106,13 +106,13 @@ predefined_optimizers = {
def register_optimizer(name, opt):
"""Add a `GlobalOptimizer` which can be referred to by `name` in `Mode`."""
"""Add a `GraphRewriter` which can be referred to by `name` in `Mode`."""
if name in predefined_optimizers:
raise ValueError(f"Optimizer name already taken: {name}")
predefined_optimizers[name] = opt
class AddDestroyHandler(GlobalOptimizer):
class AddDestroyHandler(GraphRewriter):
"""
This optimizer performs two important functions:
......@@ -145,7 +145,7 @@ class AddDestroyHandler(GlobalOptimizer):
fgraph.attach_feature(DestroyHandler())
class AddFeatureOptimizer(GlobalOptimizer):
class AddFeatureOptimizer(GraphRewriter):
"""
This optimizer adds a provided feature to the function graph.
"""
......@@ -161,7 +161,7 @@ class AddFeatureOptimizer(GlobalOptimizer):
pass
class PrintCurrentFunctionGraph(GlobalOptimizer):
class PrintCurrentFunctionGraph(GraphRewriter):
"""
This optimizer is for debugging.
......
......@@ -83,7 +83,7 @@ class Rewriter(abc.ABC):
return id(self)
class GlobalOptimizer(Rewriter):
class GraphRewriter(Rewriter):
"""A optimizer that can be applied to a `FunctionGraph` in order to transform it.
It can represent an optimization or, in general, any kind of transformation
......@@ -96,7 +96,7 @@ class GlobalOptimizer(Rewriter):
"""Apply the optimization to a `FunctionGraph`.
It may use all the methods defined by the `FunctionGraph`. If the
`GlobalOptimizer` needs to use a certain tool, such as an
`GraphRewriter` needs to use a certain tool, such as an
`InstanceFinder`, it can do so in its `add_requirements` method.
"""
......@@ -185,8 +185,8 @@ class LocalOptimizer(Rewriter):
print(f"{' ' * level}{self.__class__.__name__} id={id(self)}", file=stream)
class FromFunctionOptimizer(GlobalOptimizer):
"""A `GlobalOptimizer` constructed from a given function."""
class FromFunctionOptimizer(GraphRewriter):
"""A `GraphRewriter` constructed from a given function."""
def __init__(self, fn, requirements=()):
self.fn = fn
......@@ -225,8 +225,8 @@ def inplace_optimizer(f):
return rval
class SeqOptimizer(GlobalOptimizer, UserList):
"""A `GlobalOptimizer` that applies a list of optimizers sequentially."""
class SeqOptimizer(GraphRewriter, UserList):
"""A `GraphRewriter` that applies a list of optimizers sequentially."""
@staticmethod
def warn(exc, self, optimizer):
......@@ -258,7 +258,7 @@ class SeqOptimizer(GlobalOptimizer, UserList):
self.failure_callback = failure_callback
def apply(self, fgraph):
"""Applies each `GlobalOptimizer` in ``self.data`` to `fgraph`."""
"""Applies each `GraphRewriter` in ``self.data`` to `fgraph`."""
l = []
if fgraph.profile:
validate_before = fgraph.profile.validate_time
......@@ -670,7 +670,7 @@ class MergeFeature(Feature):
self.noinput_nodes.add(node)
class MergeOptimizer(GlobalOptimizer):
class MergeOptimizer(GraphRewriter):
r"""Merges parts of the graph that are identical and redundant.
The basic principle is that if two `Apply`\s have `Op`\s that compare equal, and
......@@ -1718,7 +1718,7 @@ class Updater(Feature):
self.chin = None
class NavigatorOptimizer(GlobalOptimizer):
class NavigatorOptimizer(GraphRewriter):
r"""An optimizer that applies a `LocalOptimizer` with considerations for the new nodes it creates.
......@@ -2578,7 +2578,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
+ list(opt.final_optimizers)
+ list(opt.cleanup_optimizers)
)
if o.print_profile.__code__ is not GlobalOptimizer.print_profile.__code__
if o.print_profile.__code__ is not GraphRewriter.print_profile.__code__
]
if not gf_opts:
return
......@@ -3043,7 +3043,7 @@ class CheckStackTraceFeature(Feature):
)
class CheckStackTraceOptimization(GlobalOptimizer):
class CheckStackTraceOptimization(GraphRewriter):
"""Optimizer that serves to add `CheckStackTraceOptimization` as a feature."""
def add_requirements(self, fgraph):
......@@ -3060,6 +3060,11 @@ DEPRECATED_NAMES = [
"`LocalMetaOptimizerSkipAssertionError` is deprecated: use `MetaNodeRewriterSkip` instead.",
MetaNodeRewriterSkip,
),
(
"GlobalOptimizer",
"`GlobalOptimizer` is deprecated: use `GraphRewriter` instead.",
GraphRewriter,
),
]
......
......@@ -11,14 +11,14 @@ from aesara.misc.ordered_set import OrderedSet
from aesara.utils import DefaultOrderedDict
OptimizersType = Union[aesara_opt.GlobalOptimizer, aesara_opt.LocalOptimizer]
OptimizersType = Union[aesara_opt.GraphRewriter, aesara_opt.LocalOptimizer]
class OptimizationDatabase:
"""A class that represents a collection/database of optimizations.
r"""A class that represents a collection/database of optimizations.
These databases are used to logically organize collections of optimizers
(i.e. ``GlobalOptimizer``s and ``LocalOptimizer``).
(i.e. `GraphRewriter`\s and `LocalOptimizer`).
"""
def __init__(self):
......@@ -61,7 +61,7 @@ class OptimizationDatabase:
optimizer,
(
OptimizationDatabase,
aesara_opt.GlobalOptimizer,
aesara_opt.GraphRewriter,
aesara_opt.LocalOptimizer,
),
):
......@@ -311,7 +311,7 @@ class EquilibriumDB(OptimizationDatabase):
Notes
-----
We can use `LocalOptimizer` and `GlobalOptimizer` since `EquilibriumOptimizer`
We can use `LocalOptimizer` and `GraphRewriter` since `EquilibriumOptimizer`
supports both.
It is probably not a good idea to have ignore_newtrees=False and
......@@ -506,7 +506,7 @@ class LocalGroupDB(SequenceDB):
class TopoDB(OptimizationDatabase):
"""Generate a `GlobalOptimizer` of type TopoOptimizer."""
"""Generate a `GraphRewriter` of type TopoOptimizer."""
def __init__(
self, db, order="in_to_out", ignore_newtrees=False, failure_callback=None
......
......@@ -22,7 +22,7 @@ from aesara.compile import optdb
from aesara.configdefaults import config
from aesara.graph.basic import Apply, Variable, clone_replace, is_in_ancestors
from aesara.graph.op import _NoPythonOp
from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer
from aesara.graph.opt import GraphRewriter, in2out, local_optimizer
from aesara.graph.type import HasDataType, HasShape
from aesara.tensor.shape import Reshape, Shape, SpecifyShape, Unbroadcast
......@@ -583,7 +583,7 @@ def cond_merge_ifs_false(fgraph, node):
return op(*old_ins, return_list=True)
class CondMerge(GlobalOptimizer):
class CondMerge(GraphRewriter):
"""Graph Optimizer that merges different cond ops"""
def add_requirements(self, fgraph):
......
......@@ -28,7 +28,7 @@ from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.features import ReplaceValidate
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import compute_test_value
from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer
from aesara.graph.opt import GraphRewriter, in2out, local_optimizer
from aesara.graph.optdb import EquilibriumDB, SequenceDB
from aesara.graph.type import HasShape
from aesara.graph.utils import InconsistencyError
......@@ -919,7 +919,7 @@ def push_out_add_scan(fgraph, node):
return False
class ScanInplaceOptimizer(GlobalOptimizer):
class ScanInplaceOptimizer(GraphRewriter):
"""Make `Scan`s perform in-place.
This optimization attempts to make `Scan` compute its recurrent outputs inplace
......@@ -1658,7 +1658,7 @@ def save_mem_new_scan(fgraph, node):
return False
class ScanMerge(GlobalOptimizer):
class ScanMerge(GraphRewriter):
r"""Graph optimizer that merges different scan ops.
This optimization attempts to fuse distinct `Scan` `Op`s into a single `Scan` `Op`
......
......@@ -27,7 +27,7 @@ from aesara.graph.features import AlreadyThere, Feature, ReplaceValidate
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import compute_test_value, get_test_value
from aesara.graph.opt import (
GlobalOptimizer,
GraphRewriter,
OpRemove,
check_chain,
copy_stack_trace,
......@@ -162,7 +162,7 @@ def broadcast_like(value, template, fgraph, dtype=None):
return rval
class InplaceElemwiseOptimizer(GlobalOptimizer):
class InplaceElemwiseOptimizer(GraphRewriter):
r"""
This is parameterized so that it works for `Elemwise` `Op`\s.
"""
......@@ -1443,7 +1443,7 @@ class ShapeFeature(Feature):
return type(self)()
class ShapeOptimizer(GlobalOptimizer):
class ShapeOptimizer(GraphRewriter):
"""Optimizer that adds `ShapeFeature` as a feature."""
def add_requirements(self, fgraph):
......@@ -1453,7 +1453,7 @@ class ShapeOptimizer(GlobalOptimizer):
pass
class UnShapeOptimizer(GlobalOptimizer):
class UnShapeOptimizer(GraphRewriter):
"""Optimizer that removes `ShapeFeature` as a feature."""
def apply(self, fgraph):
......@@ -3085,7 +3085,7 @@ def elemwise_max_input_fct(node):
local_elemwise_fusion = local_elemwise_fusion_op(Elemwise, elemwise_max_input_fct)
class FusionOptimizer(GlobalOptimizer):
class FusionOptimizer(GraphRewriter):
"""Graph optimizer that simply runs local fusion operations.
TODO: This is basically a `EquilibriumOptimizer`; we should just use that.
......
......@@ -147,7 +147,7 @@ from aesara.graph.features import ReplacementDidNotRemoveError, ReplaceValidate
from aesara.graph.op import Op
from aesara.graph.opt import (
EquilibriumOptimizer,
GlobalOptimizer,
GraphRewriter,
copy_stack_trace,
in2out,
local_optimizer,
......@@ -1496,7 +1496,7 @@ def _gemm_from_node2(fgraph, node):
return None, t1 - t0, 0, 0
class GemmOptimizer(GlobalOptimizer):
class GemmOptimizer(GraphRewriter):
"""Graph optimizer for inserting Gemm operations."""
def __init__(self):
......
......@@ -39,10 +39,10 @@ we want to define are local.
.. optimizer:
Global optimization
-------------------
Graph Rewriting
---------------
.. class:: GlobalOptimizer
.. class:: GraphRewriter
.. method:: apply(fgraph)
......@@ -54,12 +54,12 @@ Global optimization
This method takes a :class:`FunctionGraph` object and adds :ref:`features
<libdoc_graph_fgraphfeature>` to it. These features are "plugins" that are needed
for the :meth:`GlobalOptimizer.apply` method to do its job properly.
for the :meth:`GraphRewriter.apply` method to do its job properly.
.. method:: optimize(fgraph)
This is the interface function called by Aesara. It calls
:meth:`GlobalOptimizer.apply` by default.
:meth:`GraphRewriter.apply` by default.
Local optimization
......@@ -101,10 +101,10 @@ simplification described above:
.. testcode::
import aesara
from aesara.graph.opt import GlobalOptimizer
from aesara.graph.opt import GraphRewriter
from aesara.graph.features import ReplaceValidate
class Simplify(GlobalOptimizer):
class Simplify(GraphRewriter):
def add_requirements(self, fgraph):
fgraph.attach_feature(ReplaceValidate())
......@@ -136,7 +136,7 @@ another while respecting certain validation constraints. As an
exercise, try to rewrite :class:`Simplify` using :class:`NodeFinder`. (Hint: you
want to use the method it publishes instead of the call to toposort)
Then, in :meth:`GlobalOptimizer.apply` we do the actual job of simplification. We start by
Then, in :meth:`GraphRewriter.apply` we do the actual job of simplification. We start by
iterating through the graph in topological order. For each node
encountered, we check if it's a ``div`` node. If not, we have nothing
to do here. If so, we put in ``x``, ``y`` and ``z`` the numerator,
......
......@@ -10,7 +10,7 @@ from aesara.graph.optdb import (
)
class TestOpt(opt.GlobalOptimizer):
class TestOpt(opt.GraphRewriter):
name = "blah"
def apply(self, fgraph):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论