提交 5dbfd046 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename OptimizationDatabase to RewriteDatabase

上级 ac213377
...@@ -19,8 +19,8 @@ from aesara.graph.opt import ( ...@@ -19,8 +19,8 @@ from aesara.graph.opt import (
from aesara.graph.optdb import ( from aesara.graph.optdb import (
EquilibriumDB, EquilibriumDB,
LocalGroupDB, LocalGroupDB,
OptimizationDatabase,
OptimizationQuery, OptimizationQuery,
RewriteDatabase,
SequenceDB, SequenceDB,
TopoDB, TopoDB,
) )
...@@ -288,7 +288,7 @@ class Mode: ...@@ -288,7 +288,7 @@ class Mode:
A Linker decides which implementations to use (C or Python, for example) A Linker decides which implementations to use (C or Python, for example)
and how to string them together to perform the computation. and how to string them together to perform the computation.
db: db:
The ``OptimizationDatabase`` used by this ``Mode``. Note: This value The ``RewriteDatabase`` used by this ``Mode``. Note: This value
is *not* part of a ``Mode`` instance's pickled state. is *not* part of a ``Mode`` instance's pickled state.
See Also See Also
...@@ -303,7 +303,7 @@ class Mode: ...@@ -303,7 +303,7 @@ class Mode:
self, self,
linker: Optional[Union[str, Linker]] = None, linker: Optional[Union[str, Linker]] = None,
optimizer: Union[str, OptimizationQuery] = "default", optimizer: Union[str, OptimizationQuery] = "default",
db: OptimizationDatabase = None, db: RewriteDatabase = None,
): ):
if linker is None: if linker is None:
linker = config.linker linker = config.linker
......
...@@ -14,7 +14,7 @@ from aesara.utils import DefaultOrderedDict ...@@ -14,7 +14,7 @@ from aesara.utils import DefaultOrderedDict
OptimizersType = Union[aesara_opt.GraphRewriter, aesara_opt.NodeRewriter] OptimizersType = Union[aesara_opt.GraphRewriter, aesara_opt.NodeRewriter]
class OptimizationDatabase: class RewriteDatabase:
r"""A class that represents a collection/database of optimizations. r"""A class that represents a collection/database of optimizations.
These databases are used to logically organize collections of optimizers These databases are used to logically organize collections of optimizers
...@@ -31,7 +31,7 @@ class OptimizationDatabase: ...@@ -31,7 +31,7 @@ class OptimizationDatabase:
def register( def register(
self, self,
name: str, name: str,
rewriter: Union["OptimizationDatabase", OptimizersType], rewriter: Union["RewriteDatabase", OptimizersType],
*tags: str, *tags: str,
use_db_name_as_tag=True, use_db_name_as_tag=True,
): ):
...@@ -59,7 +59,7 @@ class OptimizationDatabase: ...@@ -59,7 +59,7 @@ class OptimizationDatabase:
if not isinstance( if not isinstance(
rewriter, rewriter,
( (
OptimizationDatabase, RewriteDatabase,
aesara_opt.GraphRewriter, aesara_opt.GraphRewriter,
aesara_opt.NodeRewriter, aesara_opt.NodeRewriter,
), ),
...@@ -75,12 +75,12 @@ class OptimizationDatabase: ...@@ -75,12 +75,12 @@ class OptimizationDatabase:
rewriter.name = name rewriter.name = name
# This restriction is there because in many place we suppose that # This restriction is there because in many place we suppose that
# something in the OptimizationDatabase is there only once. # something in the RewriteDatabase is there only once.
if rewriter.name in self.__db__: if rewriter.name in self.__db__:
raise ValueError( raise ValueError(
f"Tried to register {rewriter.name} again under the new name {name}. " f"Tried to register {rewriter.name} again under the new name {name}. "
"The same optimization cannot be registered multiple times in" "The same optimization cannot be registered multiple times in"
" an ``OptimizationDatabase``; use ProxyDB instead." " an ``RewriteDatabase``; use ProxyDB instead."
) )
self.__db__[name] = OrderedSet([rewriter]) self.__db__[name] = OrderedSet([rewriter])
self._names.add(name) self._names.add(name)
...@@ -121,7 +121,7 @@ class OptimizationDatabase: ...@@ -121,7 +121,7 @@ class OptimizationDatabase:
remove = OrderedSet() remove = OrderedSet()
add = OrderedSet() add = OrderedSet()
for obj in variables: for obj in variables:
if isinstance(obj, OptimizationDatabase): if isinstance(obj, RewriteDatabase):
def_sub_query = q def_sub_query = q
if q.extra_optimizations: if q.extra_optimizations:
def_sub_query = copy.copy(q) def_sub_query = copy.copy(q)
...@@ -288,7 +288,7 @@ class OptimizationQuery: ...@@ -288,7 +288,7 @@ class OptimizationQuery:
) )
class EquilibriumDB(OptimizationDatabase): class EquilibriumDB(RewriteDatabase):
"""A database of rewrites that should be applied until equilibrium is reached. """A database of rewrites that should be applied until equilibrium is reached.
Canonicalize, Stabilize, and Specialize are all equilibrium rewriters. Canonicalize, Stabilize, and Specialize are all equilibrium rewriters.
...@@ -327,7 +327,7 @@ class EquilibriumDB(OptimizationDatabase): ...@@ -327,7 +327,7 @@ class EquilibriumDB(OptimizationDatabase):
def register( def register(
self, self,
name: str, name: str,
rewriter: Union["OptimizationDatabase", OptimizersType], rewriter: Union["RewriteDatabase", OptimizersType],
*tags: str, *tags: str,
final_rewriter: bool = False, final_rewriter: bool = False,
cleanup: bool = False, cleanup: bool = False,
...@@ -365,7 +365,7 @@ class EquilibriumDB(OptimizationDatabase): ...@@ -365,7 +365,7 @@ class EquilibriumDB(OptimizationDatabase):
) )
class SequenceDB(OptimizationDatabase): class SequenceDB(RewriteDatabase):
"""A sequence of potential rewrites. """A sequence of potential rewrites.
Retrieve a sequence of rewrites as a `SequentialGraphRewriter` by calling Retrieve a sequence of rewrites as a `SequentialGraphRewriter` by calling
...@@ -497,7 +497,7 @@ class LocalGroupDB(SequenceDB): ...@@ -497,7 +497,7 @@ class LocalGroupDB(SequenceDB):
return ret return ret
class TopoDB(OptimizationDatabase): class TopoDB(RewriteDatabase):
"""Generate a `GraphRewriter` of type `WalkingGraphRewriter`.""" """Generate a `GraphRewriter` of type `WalkingGraphRewriter`."""
def __init__( def __init__(
...@@ -518,17 +518,17 @@ class TopoDB(OptimizationDatabase): ...@@ -518,17 +518,17 @@ class TopoDB(OptimizationDatabase):
) )
class ProxyDB(OptimizationDatabase): class ProxyDB(RewriteDatabase):
"""A object that wraps an existing ``OptimizationDatabase``. """A object that wraps an existing ``RewriteDatabase``.
This is needed because we can't register the same ``OptimizationDatabase`` This is needed because we can't register the same ``RewriteDatabase``
multiple times in different positions in a ``SequentialDB``. multiple times in different positions in a ``SequentialDB``.
""" """
def __init__(self, db): def __init__(self, db):
if not isinstance(db, OptimizationDatabase): if not isinstance(db, RewriteDatabase):
raise TypeError("`db` must be an `OptimizationDatabase`.") raise TypeError("`db` must be an `RewriteDatabase`.")
self.db = db self.db = db
...@@ -539,14 +539,19 @@ class ProxyDB(OptimizationDatabase): ...@@ -539,14 +539,19 @@ class ProxyDB(OptimizationDatabase):
DEPRECATED_NAMES = [ DEPRECATED_NAMES = [
( (
"DB", "DB",
"`DB` is deprecated; use `OptimizationDatabase` instead.", "`DB` is deprecated; use `RewriteDatabase` instead.",
OptimizationDatabase, RewriteDatabase,
), ),
( (
"Query", "Query",
"`Query` is deprecated; use `OptimizationQuery` instead.", "`Query` is deprecated; use `OptimizationQuery` instead.",
OptimizationQuery, OptimizationQuery,
), ),
(
"OptimizationDatabase",
"`OptimizationDatabase` is deprecated; use `RewriteDatabase` instead.",
RewriteDatabase,
),
] ]
......
...@@ -36,7 +36,7 @@ from aesara.graph.opt import ( ...@@ -36,7 +36,7 @@ from aesara.graph.opt import (
in2out, in2out,
node_rewriter, node_rewriter,
) )
from aesara.graph.optdb import OptimizationDatabase, SequenceDB from aesara.graph.optdb import RewriteDatabase, SequenceDB
from aesara.graph.utils import ( from aesara.graph.utils import (
InconsistencyError, InconsistencyError,
MethodNotDefined, MethodNotDefined,
...@@ -479,11 +479,11 @@ compile.optdb.register( ...@@ -479,11 +479,11 @@ compile.optdb.register(
def register_useless( def register_useless(
node_rewriter: Union[OptimizationDatabase, NodeRewriter, str], *tags, **kwargs node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags, **kwargs
): ):
if isinstance(node_rewriter, str): if isinstance(node_rewriter, str):
def register(inner_rewriter: Union[OptimizationDatabase, Rewriter]): def register(inner_rewriter: Union[RewriteDatabase, Rewriter]):
return register_useless(inner_rewriter, node_rewriter, *tags, **kwargs) return register_useless(inner_rewriter, node_rewriter, *tags, **kwargs)
return register return register
...@@ -497,11 +497,11 @@ def register_useless( ...@@ -497,11 +497,11 @@ def register_useless(
def register_canonicalize( def register_canonicalize(
node_rewriter: Union[OptimizationDatabase, NodeRewriter, str], *tags: str, **kwargs node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags: str, **kwargs
): ):
if isinstance(node_rewriter, str): if isinstance(node_rewriter, str):
def register(inner_rewriter: Union[OptimizationDatabase, Rewriter]): def register(inner_rewriter: Union[RewriteDatabase, Rewriter]):
return register_canonicalize(inner_rewriter, node_rewriter, *tags, **kwargs) return register_canonicalize(inner_rewriter, node_rewriter, *tags, **kwargs)
return register return register
...@@ -514,11 +514,11 @@ def register_canonicalize( ...@@ -514,11 +514,11 @@ def register_canonicalize(
def register_stabilize( def register_stabilize(
node_rewriter: Union[OptimizationDatabase, NodeRewriter, str], *tags: str, **kwargs node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags: str, **kwargs
): ):
if isinstance(node_rewriter, str): if isinstance(node_rewriter, str):
def register(inner_rewriter: Union[OptimizationDatabase, Rewriter]): def register(inner_rewriter: Union[RewriteDatabase, Rewriter]):
return register_stabilize(inner_rewriter, node_rewriter, *tags, **kwargs) return register_stabilize(inner_rewriter, node_rewriter, *tags, **kwargs)
return register return register
...@@ -531,11 +531,11 @@ def register_stabilize( ...@@ -531,11 +531,11 @@ def register_stabilize(
def register_specialize( def register_specialize(
node_rewriter: Union[OptimizationDatabase, NodeRewriter, str], *tags: str, **kwargs node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags: str, **kwargs
): ):
if isinstance(node_rewriter, str): if isinstance(node_rewriter, str):
def register(inner_rewriter: Union[OptimizationDatabase, Rewriter]): def register(inner_rewriter: Union[RewriteDatabase, Rewriter]):
return register_specialize(inner_rewriter, node_rewriter, *tags, **kwargs) return register_specialize(inner_rewriter, node_rewriter, *tags, **kwargs)
return register return register
...@@ -548,11 +548,11 @@ def register_specialize( ...@@ -548,11 +548,11 @@ def register_specialize(
def register_uncanonicalize( def register_uncanonicalize(
node_rewriter: Union[OptimizationDatabase, NodeRewriter, str], *tags: str, **kwargs node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags: str, **kwargs
): ):
if isinstance(node_rewriter, str): if isinstance(node_rewriter, str):
def register(inner_rewriter: Union[OptimizationDatabase, Rewriter]): def register(inner_rewriter: Union[RewriteDatabase, Rewriter]):
return register_uncanonicalize( return register_uncanonicalize(
inner_rewriter, node_rewriter, *tags, **kwargs inner_rewriter, node_rewriter, *tags, **kwargs
) )
...@@ -567,11 +567,11 @@ def register_uncanonicalize( ...@@ -567,11 +567,11 @@ def register_uncanonicalize(
def register_specialize_device( def register_specialize_device(
node_rewriter: Union[OptimizationDatabase, Rewriter, str], *tags: str, **kwargs node_rewriter: Union[RewriteDatabase, Rewriter, str], *tags: str, **kwargs
): ):
if isinstance(node_rewriter, str): if isinstance(node_rewriter, str):
def register(inner_rewriter: Union[OptimizationDatabase, Rewriter]): def register(inner_rewriter: Union[RewriteDatabase, Rewriter]):
return register_specialize_device( return register_specialize_device(
inner_rewriter, node_rewriter, *tags, **kwargs inner_rewriter, node_rewriter, *tags, **kwargs
) )
......
...@@ -583,30 +583,30 @@ Definition of :obj:`optdb` ...@@ -583,30 +583,30 @@ Definition of :obj:`optdb`
:obj:`optdb` is an object which is an instance of :obj:`optdb` is an object which is an instance of
:class:`SequenceDB <optdb.SequenceDB>`, :class:`SequenceDB <optdb.SequenceDB>`,
itself a subclass of :class:`OptimizationDatabase <optdb.OptimizationDatabase>`. itself a subclass of :class:`RewriteDatabase <optdb.RewriteDatabase>`.
There exist (for now) two types of :class:`OptimizationDatabase`, :class:`SequenceDB` and :class:`EquilibriumDB`. There exist (for now) two types of :class:`RewriteDatabase`, :class:`SequenceDB` and :class:`EquilibriumDB`.
When given an appropriate :class:`OptimizationQuery`, :class:`OptimizationDatabase` objects build an :class:`Optimizer` matching When given an appropriate :class:`OptimizationQuery`, :class:`RewriteDatabase` objects build an :class:`Optimizer` matching
the query. the query.
A :class:`SequenceDB` contains :class:`Optimizer` or :class:`OptimizationDatabase` objects. Each of them A :class:`SequenceDB` contains :class:`Optimizer` or :class:`RewriteDatabase` objects. Each of them
has a name, an arbitrary number of tags and an integer representing their order has a name, an arbitrary number of tags and an integer representing their order
in the sequence. When a :class:`OptimizationQuery` is applied to a :class:`SequenceDB`, all :class:`Optimizer`\s whose in the sequence. When a :class:`OptimizationQuery` is applied to a :class:`SequenceDB`, all :class:`Optimizer`\s whose
tags match the query are inserted in proper order in a :class:`SequenceOptimizer`, which tags match the query are inserted in proper order in a :class:`SequenceOptimizer`, which
is returned. If the :class:`SequenceDB` contains :class:`OptimizationDatabase` is returned. If the :class:`SequenceDB` contains :class:`RewriteDatabase`
instances, the :class:`OptimizationQuery` will be passed to them as well and the instances, the :class:`OptimizationQuery` will be passed to them as well and the
optimizers they return will be put in their places. optimizers they return will be put in their places.
An :class:`EquilibriumDB` contains :class:`NodeRewriter` or :class:`OptimizationDatabase` objects. Each of them An :class:`EquilibriumDB` contains :class:`NodeRewriter` or :class:`RewriteDatabase` objects. Each of them
has a name and an arbitrary number of tags. When a :class:`OptimizationQuery` is applied to has a name and an arbitrary number of tags. When a :class:`OptimizationQuery` is applied to
an :class:`EquilibriumDB`, all :class:`NodeRewriter`\s that match the query are an :class:`EquilibriumDB`, all :class:`NodeRewriter`\s that match the query are
inserted into an :class:`EquilibriumGraphRewriter`, which is returned. If the inserted into an :class:`EquilibriumGraphRewriter`, which is returned. If the
:class:`SequenceDB` contains :class:`OptimizationDatabase` instances, the :class:`SequenceDB` contains :class:`RewriteDatabase` instances, the
:class:`OptimizationQuery` will be passed to them as well and the :class:`OptimizationQuery` will be passed to them as well and the
:class:`NodeRewriter`\s they return will be put in their places :class:`NodeRewriter`\s they return will be put in their places
(note that as of yet no :class:`OptimizationDatabase` can produce :class:`NodeRewriter` objects, so this (note that as of yet no :class:`RewriteDatabase` can produce :class:`NodeRewriter` objects, so this
is a moot point). is a moot point).
Aesara contains one principal :class:`OptimizationDatabase` object, :class:`optdb`, which Aesara contains one principal :class:`RewriteDatabase` object, :class:`optdb`, which
contains all of Aesara's optimizers with proper tags. It is contains all of Aesara's optimizers with proper tags. It is
recommended to insert new :class:`Optimizer`\s in it. As mentioned previously, recommended to insert new :class:`Optimizer`\s in it. As mentioned previously,
optdb is a :class:`SequenceDB`, so, at the top level, Aesara applies a sequence optdb is a :class:`SequenceDB`, so, at the top level, Aesara applies a sequence
......
...@@ -4,8 +4,8 @@ from aesara.graph import opt ...@@ -4,8 +4,8 @@ from aesara.graph import opt
from aesara.graph.optdb import ( from aesara.graph.optdb import (
EquilibriumDB, EquilibriumDB,
LocalGroupDB, LocalGroupDB,
OptimizationDatabase,
ProxyDB, ProxyDB,
RewriteDatabase,
SequenceDB, SequenceDB,
) )
...@@ -19,7 +19,7 @@ class TestOpt(opt.GraphRewriter): ...@@ -19,7 +19,7 @@ class TestOpt(opt.GraphRewriter):
class TestDB: class TestDB:
def test_register(self): def test_register(self):
db = OptimizationDatabase() db = RewriteDatabase()
db.register("a", TestOpt()) db.register("a", TestOpt())
db.register("b", TestOpt()) db.register("b", TestOpt())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论