提交 214ef4cf authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename GlobalOptimizer to GraphRewriter

上级 7ce7b0c2
...@@ -12,7 +12,7 @@ from aesara.configdefaults import config ...@@ -12,7 +12,7 @@ from aesara.configdefaults import config
from aesara.graph.destroyhandler import DestroyHandler from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.opt import ( from aesara.graph.opt import (
CheckStackTraceOptimization, CheckStackTraceOptimization,
GlobalOptimizer, GraphRewriter,
MergeOptimizer, MergeOptimizer,
NavigatorOptimizer, NavigatorOptimizer,
) )
...@@ -106,13 +106,13 @@ predefined_optimizers = { ...@@ -106,13 +106,13 @@ predefined_optimizers = {
def register_optimizer(name, opt): def register_optimizer(name, opt):
"""Add a `GlobalOptimizer` which can be referred to by `name` in `Mode`.""" """Add a `GraphRewriter` which can be referred to by `name` in `Mode`."""
if name in predefined_optimizers: if name in predefined_optimizers:
raise ValueError(f"Optimizer name already taken: {name}") raise ValueError(f"Optimizer name already taken: {name}")
predefined_optimizers[name] = opt predefined_optimizers[name] = opt
class AddDestroyHandler(GlobalOptimizer): class AddDestroyHandler(GraphRewriter):
""" """
This optimizer performs two important functions: This optimizer performs two important functions:
...@@ -145,7 +145,7 @@ class AddDestroyHandler(GlobalOptimizer): ...@@ -145,7 +145,7 @@ class AddDestroyHandler(GlobalOptimizer):
fgraph.attach_feature(DestroyHandler()) fgraph.attach_feature(DestroyHandler())
class AddFeatureOptimizer(GlobalOptimizer): class AddFeatureOptimizer(GraphRewriter):
""" """
This optimizer adds a provided feature to the function graph. This optimizer adds a provided feature to the function graph.
""" """
...@@ -161,7 +161,7 @@ class AddFeatureOptimizer(GlobalOptimizer): ...@@ -161,7 +161,7 @@ class AddFeatureOptimizer(GlobalOptimizer):
pass pass
class PrintCurrentFunctionGraph(GlobalOptimizer): class PrintCurrentFunctionGraph(GraphRewriter):
""" """
This optimizer is for debugging. This optimizer is for debugging.
......
...@@ -83,7 +83,7 @@ class Rewriter(abc.ABC): ...@@ -83,7 +83,7 @@ class Rewriter(abc.ABC):
return id(self) return id(self)
class GlobalOptimizer(Rewriter): class GraphRewriter(Rewriter):
"""A optimizer that can be applied to a `FunctionGraph` in order to transform it. """A optimizer that can be applied to a `FunctionGraph` in order to transform it.
It can represent an optimization or, in general, any kind of transformation It can represent an optimization or, in general, any kind of transformation
...@@ -96,7 +96,7 @@ class GlobalOptimizer(Rewriter): ...@@ -96,7 +96,7 @@ class GlobalOptimizer(Rewriter):
"""Apply the optimization to a `FunctionGraph`. """Apply the optimization to a `FunctionGraph`.
It may use all the methods defined by the `FunctionGraph`. If the It may use all the methods defined by the `FunctionGraph`. If the
`GlobalOptimizer` needs to use a certain tool, such as an `GraphRewriter` needs to use a certain tool, such as an
`InstanceFinder`, it can do so in its `add_requirements` method. `InstanceFinder`, it can do so in its `add_requirements` method.
""" """
...@@ -185,8 +185,8 @@ class LocalOptimizer(Rewriter): ...@@ -185,8 +185,8 @@ class LocalOptimizer(Rewriter):
print(f"{' ' * level}{self.__class__.__name__} id={id(self)}", file=stream) print(f"{' ' * level}{self.__class__.__name__} id={id(self)}", file=stream)
class FromFunctionOptimizer(GlobalOptimizer): class FromFunctionOptimizer(GraphRewriter):
"""A `GlobalOptimizer` constructed from a given function.""" """A `GraphRewriter` constructed from a given function."""
def __init__(self, fn, requirements=()): def __init__(self, fn, requirements=()):
self.fn = fn self.fn = fn
...@@ -225,8 +225,8 @@ def inplace_optimizer(f): ...@@ -225,8 +225,8 @@ def inplace_optimizer(f):
return rval return rval
class SeqOptimizer(GlobalOptimizer, UserList): class SeqOptimizer(GraphRewriter, UserList):
"""A `GlobalOptimizer` that applies a list of optimizers sequentially.""" """A `GraphRewriter` that applies a list of optimizers sequentially."""
@staticmethod @staticmethod
def warn(exc, self, optimizer): def warn(exc, self, optimizer):
...@@ -258,7 +258,7 @@ class SeqOptimizer(GlobalOptimizer, UserList): ...@@ -258,7 +258,7 @@ class SeqOptimizer(GlobalOptimizer, UserList):
self.failure_callback = failure_callback self.failure_callback = failure_callback
def apply(self, fgraph): def apply(self, fgraph):
"""Applies each `GlobalOptimizer` in ``self.data`` to `fgraph`.""" """Applies each `GraphRewriter` in ``self.data`` to `fgraph`."""
l = [] l = []
if fgraph.profile: if fgraph.profile:
validate_before = fgraph.profile.validate_time validate_before = fgraph.profile.validate_time
...@@ -670,7 +670,7 @@ class MergeFeature(Feature): ...@@ -670,7 +670,7 @@ class MergeFeature(Feature):
self.noinput_nodes.add(node) self.noinput_nodes.add(node)
class MergeOptimizer(GlobalOptimizer): class MergeOptimizer(GraphRewriter):
r"""Merges parts of the graph that are identical and redundant. r"""Merges parts of the graph that are identical and redundant.
The basic principle is that if two `Apply`\s have `Op`\s that compare equal, and The basic principle is that if two `Apply`\s have `Op`\s that compare equal, and
...@@ -1718,7 +1718,7 @@ class Updater(Feature): ...@@ -1718,7 +1718,7 @@ class Updater(Feature):
self.chin = None self.chin = None
class NavigatorOptimizer(GlobalOptimizer): class NavigatorOptimizer(GraphRewriter):
r"""An optimizer that applies a `LocalOptimizer` with considerations for the new nodes it creates. r"""An optimizer that applies a `LocalOptimizer` with considerations for the new nodes it creates.
...@@ -2578,7 +2578,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2578,7 +2578,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
+ list(opt.final_optimizers) + list(opt.final_optimizers)
+ list(opt.cleanup_optimizers) + list(opt.cleanup_optimizers)
) )
if o.print_profile.__code__ is not GlobalOptimizer.print_profile.__code__ if o.print_profile.__code__ is not GraphRewriter.print_profile.__code__
] ]
if not gf_opts: if not gf_opts:
return return
...@@ -3043,7 +3043,7 @@ class CheckStackTraceFeature(Feature): ...@@ -3043,7 +3043,7 @@ class CheckStackTraceFeature(Feature):
) )
class CheckStackTraceOptimization(GlobalOptimizer): class CheckStackTraceOptimization(GraphRewriter):
"""Optimizer that serves to add `CheckStackTraceOptimization` as a feature.""" """Optimizer that serves to add `CheckStackTraceOptimization` as a feature."""
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
...@@ -3060,6 +3060,11 @@ DEPRECATED_NAMES = [ ...@@ -3060,6 +3060,11 @@ DEPRECATED_NAMES = [
"`LocalMetaOptimizerSkipAssertionError` is deprecated: use `MetaNodeRewriterSkip` instead.", "`LocalMetaOptimizerSkipAssertionError` is deprecated: use `MetaNodeRewriterSkip` instead.",
MetaNodeRewriterSkip, MetaNodeRewriterSkip,
), ),
(
"GlobalOptimizer",
"`GlobalOptimizer` is deprecated: use `GraphRewriter` instead.",
GraphRewriter,
),
] ]
......
...@@ -11,14 +11,14 @@ from aesara.misc.ordered_set import OrderedSet ...@@ -11,14 +11,14 @@ from aesara.misc.ordered_set import OrderedSet
from aesara.utils import DefaultOrderedDict from aesara.utils import DefaultOrderedDict
OptimizersType = Union[aesara_opt.GlobalOptimizer, aesara_opt.LocalOptimizer] OptimizersType = Union[aesara_opt.GraphRewriter, aesara_opt.LocalOptimizer]
class OptimizationDatabase: class OptimizationDatabase:
"""A class that represents a collection/database of optimizations. r"""A class that represents a collection/database of optimizations.
These databases are used to logically organize collections of optimizers These databases are used to logically organize collections of optimizers
(i.e. ``GlobalOptimizer``s and ``LocalOptimizer``). (i.e. `GraphRewriter`\s and `LocalOptimizer`).
""" """
def __init__(self): def __init__(self):
...@@ -61,7 +61,7 @@ class OptimizationDatabase: ...@@ -61,7 +61,7 @@ class OptimizationDatabase:
optimizer, optimizer,
( (
OptimizationDatabase, OptimizationDatabase,
aesara_opt.GlobalOptimizer, aesara_opt.GraphRewriter,
aesara_opt.LocalOptimizer, aesara_opt.LocalOptimizer,
), ),
): ):
...@@ -311,7 +311,7 @@ class EquilibriumDB(OptimizationDatabase): ...@@ -311,7 +311,7 @@ class EquilibriumDB(OptimizationDatabase):
Notes Notes
----- -----
We can use `LocalOptimizer` and `GlobalOptimizer` since `EquilibriumOptimizer` We can use `LocalOptimizer` and `GraphRewriter` since `EquilibriumOptimizer`
supports both. supports both.
It is probably not a good idea to have ignore_newtrees=False and It is probably not a good idea to have ignore_newtrees=False and
...@@ -506,7 +506,7 @@ class LocalGroupDB(SequenceDB): ...@@ -506,7 +506,7 @@ class LocalGroupDB(SequenceDB):
class TopoDB(OptimizationDatabase): class TopoDB(OptimizationDatabase):
"""Generate a `GlobalOptimizer` of type TopoOptimizer.""" """Generate a `GraphRewriter` of type TopoOptimizer."""
def __init__( def __init__(
self, db, order="in_to_out", ignore_newtrees=False, failure_callback=None self, db, order="in_to_out", ignore_newtrees=False, failure_callback=None
......
...@@ -22,7 +22,7 @@ from aesara.compile import optdb ...@@ -22,7 +22,7 @@ from aesara.compile import optdb
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply, Variable, clone_replace, is_in_ancestors from aesara.graph.basic import Apply, Variable, clone_replace, is_in_ancestors
from aesara.graph.op import _NoPythonOp from aesara.graph.op import _NoPythonOp
from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer from aesara.graph.opt import GraphRewriter, in2out, local_optimizer
from aesara.graph.type import HasDataType, HasShape from aesara.graph.type import HasDataType, HasShape
from aesara.tensor.shape import Reshape, Shape, SpecifyShape, Unbroadcast from aesara.tensor.shape import Reshape, Shape, SpecifyShape, Unbroadcast
...@@ -583,7 +583,7 @@ def cond_merge_ifs_false(fgraph, node): ...@@ -583,7 +583,7 @@ def cond_merge_ifs_false(fgraph, node):
return op(*old_ins, return_list=True) return op(*old_ins, return_list=True)
class CondMerge(GlobalOptimizer): class CondMerge(GraphRewriter):
"""Graph Optimizer that merges different cond ops""" """Graph Optimizer that merges different cond ops"""
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
......
...@@ -28,7 +28,7 @@ from aesara.graph.destroyhandler import DestroyHandler ...@@ -28,7 +28,7 @@ from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.features import ReplaceValidate from aesara.graph.features import ReplaceValidate
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.op import compute_test_value from aesara.graph.op import compute_test_value
from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer from aesara.graph.opt import GraphRewriter, in2out, local_optimizer
from aesara.graph.optdb import EquilibriumDB, SequenceDB from aesara.graph.optdb import EquilibriumDB, SequenceDB
from aesara.graph.type import HasShape from aesara.graph.type import HasShape
from aesara.graph.utils import InconsistencyError from aesara.graph.utils import InconsistencyError
...@@ -919,7 +919,7 @@ def push_out_add_scan(fgraph, node): ...@@ -919,7 +919,7 @@ def push_out_add_scan(fgraph, node):
return False return False
class ScanInplaceOptimizer(GlobalOptimizer): class ScanInplaceOptimizer(GraphRewriter):
"""Make `Scan`s perform in-place. """Make `Scan`s perform in-place.
This optimization attempts to make `Scan` compute its recurrent outputs inplace This optimization attempts to make `Scan` compute its recurrent outputs inplace
...@@ -1658,7 +1658,7 @@ def save_mem_new_scan(fgraph, node): ...@@ -1658,7 +1658,7 @@ def save_mem_new_scan(fgraph, node):
return False return False
class ScanMerge(GlobalOptimizer): class ScanMerge(GraphRewriter):
r"""Graph optimizer that merges different scan ops. r"""Graph optimizer that merges different scan ops.
This optimization attempts to fuse distinct `Scan` `Op`s into a single `Scan` `Op` This optimization attempts to fuse distinct `Scan` `Op`s into a single `Scan` `Op`
......
...@@ -27,7 +27,7 @@ from aesara.graph.features import AlreadyThere, Feature, ReplaceValidate ...@@ -27,7 +27,7 @@ from aesara.graph.features import AlreadyThere, Feature, ReplaceValidate
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.op import compute_test_value, get_test_value from aesara.graph.op import compute_test_value, get_test_value
from aesara.graph.opt import ( from aesara.graph.opt import (
GlobalOptimizer, GraphRewriter,
OpRemove, OpRemove,
check_chain, check_chain,
copy_stack_trace, copy_stack_trace,
...@@ -162,7 +162,7 @@ def broadcast_like(value, template, fgraph, dtype=None): ...@@ -162,7 +162,7 @@ def broadcast_like(value, template, fgraph, dtype=None):
return rval return rval
class InplaceElemwiseOptimizer(GlobalOptimizer): class InplaceElemwiseOptimizer(GraphRewriter):
r""" r"""
This is parameterized so that it works for `Elemwise` `Op`\s. This is parameterized so that it works for `Elemwise` `Op`\s.
""" """
...@@ -1443,7 +1443,7 @@ class ShapeFeature(Feature): ...@@ -1443,7 +1443,7 @@ class ShapeFeature(Feature):
return type(self)() return type(self)()
class ShapeOptimizer(GlobalOptimizer): class ShapeOptimizer(GraphRewriter):
"""Optimizer that adds `ShapeFeature` as a feature.""" """Optimizer that adds `ShapeFeature` as a feature."""
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
...@@ -1453,7 +1453,7 @@ class ShapeOptimizer(GlobalOptimizer): ...@@ -1453,7 +1453,7 @@ class ShapeOptimizer(GlobalOptimizer):
pass pass
class UnShapeOptimizer(GlobalOptimizer): class UnShapeOptimizer(GraphRewriter):
"""Optimizer that removes `ShapeFeature` as a feature.""" """Optimizer that removes `ShapeFeature` as a feature."""
def apply(self, fgraph): def apply(self, fgraph):
...@@ -3085,7 +3085,7 @@ def elemwise_max_input_fct(node): ...@@ -3085,7 +3085,7 @@ def elemwise_max_input_fct(node):
local_elemwise_fusion = local_elemwise_fusion_op(Elemwise, elemwise_max_input_fct) local_elemwise_fusion = local_elemwise_fusion_op(Elemwise, elemwise_max_input_fct)
class FusionOptimizer(GlobalOptimizer): class FusionOptimizer(GraphRewriter):
"""Graph optimizer that simply runs local fusion operations. """Graph optimizer that simply runs local fusion operations.
TODO: This is basically a `EquilibriumOptimizer`; we should just use that. TODO: This is basically a `EquilibriumOptimizer`; we should just use that.
......
...@@ -147,7 +147,7 @@ from aesara.graph.features import ReplacementDidNotRemoveError, ReplaceValidate ...@@ -147,7 +147,7 @@ from aesara.graph.features import ReplacementDidNotRemoveError, ReplaceValidate
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.opt import ( from aesara.graph.opt import (
EquilibriumOptimizer, EquilibriumOptimizer,
GlobalOptimizer, GraphRewriter,
copy_stack_trace, copy_stack_trace,
in2out, in2out,
local_optimizer, local_optimizer,
...@@ -1496,7 +1496,7 @@ def _gemm_from_node2(fgraph, node): ...@@ -1496,7 +1496,7 @@ def _gemm_from_node2(fgraph, node):
return None, t1 - t0, 0, 0 return None, t1 - t0, 0, 0
class GemmOptimizer(GlobalOptimizer): class GemmOptimizer(GraphRewriter):
"""Graph optimizer for inserting Gemm operations.""" """Graph optimizer for inserting Gemm operations."""
def __init__(self): def __init__(self):
......
...@@ -39,10 +39,10 @@ we want to define are local. ...@@ -39,10 +39,10 @@ we want to define are local.
.. optimizer: .. optimizer:
Global optimization Graph Rewriting
------------------- ---------------
.. class:: GlobalOptimizer .. class:: GraphRewriter
.. method:: apply(fgraph) .. method:: apply(fgraph)
...@@ -54,12 +54,12 @@ Global optimization ...@@ -54,12 +54,12 @@ Global optimization
This method takes a :class:`FunctionGraph` object and adds :ref:`features This method takes a :class:`FunctionGraph` object and adds :ref:`features
<libdoc_graph_fgraphfeature>` to it. These features are "plugins" that are needed <libdoc_graph_fgraphfeature>` to it. These features are "plugins" that are needed
for the :meth:`GlobalOptimizer.apply` method to do its job properly. for the :meth:`GraphRewriter.apply` method to do its job properly.
.. method:: optimize(fgraph) .. method:: optimize(fgraph)
This is the interface function called by Aesara. It calls This is the interface function called by Aesara. It calls
:meth:`GlobalOptimizer.apply` by default. :meth:`GraphRewriter.apply` by default.
Local optimization Local optimization
...@@ -101,10 +101,10 @@ simplification described above: ...@@ -101,10 +101,10 @@ simplification described above:
.. testcode:: .. testcode::
import aesara import aesara
from aesara.graph.opt import GlobalOptimizer from aesara.graph.opt import GraphRewriter
from aesara.graph.features import ReplaceValidate from aesara.graph.features import ReplaceValidate
class Simplify(GlobalOptimizer): class Simplify(GraphRewriter):
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
fgraph.attach_feature(ReplaceValidate()) fgraph.attach_feature(ReplaceValidate())
...@@ -136,7 +136,7 @@ another while respecting certain validation constraints. As an ...@@ -136,7 +136,7 @@ another while respecting certain validation constraints. As an
exercise, try to rewrite :class:`Simplify` using :class:`NodeFinder`. (Hint: you exercise, try to rewrite :class:`Simplify` using :class:`NodeFinder`. (Hint: you
want to use the method it publishes instead of the call to toposort) want to use the method it publishes instead of the call to toposort)
Then, in :meth:`GlobalOptimizer.apply` we do the actual job of simplification. We start by Then, in :meth:`GraphRewriter.apply` we do the actual job of simplification. We start by
iterating through the graph in topological order. For each node iterating through the graph in topological order. For each node
encountered, we check if it's a ``div`` node. If not, we have nothing encountered, we check if it's a ``div`` node. If not, we have nothing
to do here. If so, we put in ``x``, ``y`` and ``z`` the numerator, to do here. If so, we put in ``x``, ``y`` and ``z`` the numerator,
......
...@@ -10,7 +10,7 @@ from aesara.graph.optdb import ( ...@@ -10,7 +10,7 @@ from aesara.graph.optdb import (
) )
class TestOpt(opt.GlobalOptimizer): class TestOpt(opt.GraphRewriter):
name = "blah" name = "blah"
def apply(self, fgraph): def apply(self, fgraph):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论