提交 5dbfd046 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename OptimizationDatabase to RewriteDatabase

上级 ac213377
......@@ -19,8 +19,8 @@ from aesara.graph.opt import (
from aesara.graph.optdb import (
EquilibriumDB,
LocalGroupDB,
OptimizationDatabase,
OptimizationQuery,
RewriteDatabase,
SequenceDB,
TopoDB,
)
......@@ -288,7 +288,7 @@ class Mode:
A Linker decides which implementations to use (C or Python, for example)
and how to string them together to perform the computation.
db:
The ``OptimizationDatabase`` used by this ``Mode``. Note: This value
The ``RewriteDatabase`` used by this ``Mode``. Note: This value
is *not* part of a ``Mode`` instance's pickled state.
See Also
......@@ -303,7 +303,7 @@ class Mode:
self,
linker: Optional[Union[str, Linker]] = None,
optimizer: Union[str, OptimizationQuery] = "default",
db: OptimizationDatabase = None,
db: RewriteDatabase = None,
):
if linker is None:
linker = config.linker
......
......@@ -14,7 +14,7 @@ from aesara.utils import DefaultOrderedDict
OptimizersType = Union[aesara_opt.GraphRewriter, aesara_opt.NodeRewriter]
class OptimizationDatabase:
class RewriteDatabase:
r"""A class that represents a collection/database of optimizations.
These databases are used to logically organize collections of optimizers
......@@ -31,7 +31,7 @@ class OptimizationDatabase:
def register(
self,
name: str,
rewriter: Union["OptimizationDatabase", OptimizersType],
rewriter: Union["RewriteDatabase", OptimizersType],
*tags: str,
use_db_name_as_tag=True,
):
......@@ -59,7 +59,7 @@ class OptimizationDatabase:
if not isinstance(
rewriter,
(
OptimizationDatabase,
RewriteDatabase,
aesara_opt.GraphRewriter,
aesara_opt.NodeRewriter,
),
......@@ -75,12 +75,12 @@ class OptimizationDatabase:
rewriter.name = name
# This restriction is there because in many place we suppose that
# something in the OptimizationDatabase is there only once.
# something in the RewriteDatabase is there only once.
if rewriter.name in self.__db__:
raise ValueError(
f"Tried to register {rewriter.name} again under the new name {name}. "
"The same optimization cannot be registered multiple times in"
" an ``OptimizationDatabase``; use ProxyDB instead."
" an ``RewriteDatabase``; use ProxyDB instead."
)
self.__db__[name] = OrderedSet([rewriter])
self._names.add(name)
......@@ -121,7 +121,7 @@ class OptimizationDatabase:
remove = OrderedSet()
add = OrderedSet()
for obj in variables:
if isinstance(obj, OptimizationDatabase):
if isinstance(obj, RewriteDatabase):
def_sub_query = q
if q.extra_optimizations:
def_sub_query = copy.copy(q)
......@@ -288,7 +288,7 @@ class OptimizationQuery:
)
class EquilibriumDB(OptimizationDatabase):
class EquilibriumDB(RewriteDatabase):
"""A database of rewrites that should be applied until equilibrium is reached.
Canonicalize, Stabilize, and Specialize are all equilibrium rewriters.
......@@ -327,7 +327,7 @@ class EquilibriumDB(OptimizationDatabase):
def register(
self,
name: str,
rewriter: Union["OptimizationDatabase", OptimizersType],
rewriter: Union["RewriteDatabase", OptimizersType],
*tags: str,
final_rewriter: bool = False,
cleanup: bool = False,
......@@ -365,7 +365,7 @@ class EquilibriumDB(OptimizationDatabase):
)
class SequenceDB(OptimizationDatabase):
class SequenceDB(RewriteDatabase):
"""A sequence of potential rewrites.
Retrieve a sequence of rewrites as a `SequentialGraphRewriter` by calling
......@@ -497,7 +497,7 @@ class LocalGroupDB(SequenceDB):
return ret
class TopoDB(OptimizationDatabase):
class TopoDB(RewriteDatabase):
"""Generate a `GraphRewriter` of type `WalkingGraphRewriter`."""
def __init__(
......@@ -518,17 +518,17 @@ class TopoDB(OptimizationDatabase):
)
class ProxyDB(OptimizationDatabase):
"""A object that wraps an existing ``OptimizationDatabase``.
class ProxyDB(RewriteDatabase):
"""A object that wraps an existing ``RewriteDatabase``.
This is needed because we can't register the same ``OptimizationDatabase``
This is needed because we can't register the same ``RewriteDatabase``
multiple times in different positions in a ``SequentialDB``.
"""
def __init__(self, db):
if not isinstance(db, OptimizationDatabase):
raise TypeError("`db` must be an `OptimizationDatabase`.")
if not isinstance(db, RewriteDatabase):
raise TypeError("`db` must be an `RewriteDatabase`.")
self.db = db
......@@ -539,14 +539,19 @@ class ProxyDB(OptimizationDatabase):
DEPRECATED_NAMES = [
(
"DB",
"`DB` is deprecated; use `OptimizationDatabase` instead.",
OptimizationDatabase,
"`DB` is deprecated; use `RewriteDatabase` instead.",
RewriteDatabase,
),
(
"Query",
"`Query` is deprecated; use `OptimizationQuery` instead.",
OptimizationQuery,
),
(
"OptimizationDatabase",
"`OptimizationDatabase` is deprecated; use `RewriteDatabase` instead.",
RewriteDatabase,
),
]
......
......@@ -36,7 +36,7 @@ from aesara.graph.opt import (
in2out,
node_rewriter,
)
from aesara.graph.optdb import OptimizationDatabase, SequenceDB
from aesara.graph.optdb import RewriteDatabase, SequenceDB
from aesara.graph.utils import (
InconsistencyError,
MethodNotDefined,
......@@ -479,11 +479,11 @@ compile.optdb.register(
def register_useless(
node_rewriter: Union[OptimizationDatabase, NodeRewriter, str], *tags, **kwargs
node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags, **kwargs
):
if isinstance(node_rewriter, str):
def register(inner_rewriter: Union[OptimizationDatabase, Rewriter]):
def register(inner_rewriter: Union[RewriteDatabase, Rewriter]):
return register_useless(inner_rewriter, node_rewriter, *tags, **kwargs)
return register
......@@ -497,11 +497,11 @@ def register_useless(
def register_canonicalize(
node_rewriter: Union[OptimizationDatabase, NodeRewriter, str], *tags: str, **kwargs
node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags: str, **kwargs
):
if isinstance(node_rewriter, str):
def register(inner_rewriter: Union[OptimizationDatabase, Rewriter]):
def register(inner_rewriter: Union[RewriteDatabase, Rewriter]):
return register_canonicalize(inner_rewriter, node_rewriter, *tags, **kwargs)
return register
......@@ -514,11 +514,11 @@ def register_canonicalize(
def register_stabilize(
node_rewriter: Union[OptimizationDatabase, NodeRewriter, str], *tags: str, **kwargs
node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags: str, **kwargs
):
if isinstance(node_rewriter, str):
def register(inner_rewriter: Union[OptimizationDatabase, Rewriter]):
def register(inner_rewriter: Union[RewriteDatabase, Rewriter]):
return register_stabilize(inner_rewriter, node_rewriter, *tags, **kwargs)
return register
......@@ -531,11 +531,11 @@ def register_stabilize(
def register_specialize(
node_rewriter: Union[OptimizationDatabase, NodeRewriter, str], *tags: str, **kwargs
node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags: str, **kwargs
):
if isinstance(node_rewriter, str):
def register(inner_rewriter: Union[OptimizationDatabase, Rewriter]):
def register(inner_rewriter: Union[RewriteDatabase, Rewriter]):
return register_specialize(inner_rewriter, node_rewriter, *tags, **kwargs)
return register
......@@ -548,11 +548,11 @@ def register_specialize(
def register_uncanonicalize(
node_rewriter: Union[OptimizationDatabase, NodeRewriter, str], *tags: str, **kwargs
node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags: str, **kwargs
):
if isinstance(node_rewriter, str):
def register(inner_rewriter: Union[OptimizationDatabase, Rewriter]):
def register(inner_rewriter: Union[RewriteDatabase, Rewriter]):
return register_uncanonicalize(
inner_rewriter, node_rewriter, *tags, **kwargs
)
......@@ -567,11 +567,11 @@ def register_uncanonicalize(
def register_specialize_device(
node_rewriter: Union[OptimizationDatabase, Rewriter, str], *tags: str, **kwargs
node_rewriter: Union[RewriteDatabase, Rewriter, str], *tags: str, **kwargs
):
if isinstance(node_rewriter, str):
def register(inner_rewriter: Union[OptimizationDatabase, Rewriter]):
def register(inner_rewriter: Union[RewriteDatabase, Rewriter]):
return register_specialize_device(
inner_rewriter, node_rewriter, *tags, **kwargs
)
......
......@@ -583,30 +583,30 @@ Definition of :obj:`optdb`
:obj:`optdb` is an object which is an instance of
:class:`SequenceDB <optdb.SequenceDB>`,
itself a subclass of :class:`OptimizationDatabase <optdb.OptimizationDatabase>`.
There exist (for now) two types of :class:`OptimizationDatabase`, :class:`SequenceDB` and :class:`EquilibriumDB`.
When given an appropriate :class:`OptimizationQuery`, :class:`OptimizationDatabase` objects build an :class:`Optimizer` matching
itself a subclass of :class:`RewriteDatabase <optdb.RewriteDatabase>`.
There exist (for now) two types of :class:`RewriteDatabase`, :class:`SequenceDB` and :class:`EquilibriumDB`.
When given an appropriate :class:`OptimizationQuery`, :class:`RewriteDatabase` objects build an :class:`Optimizer` matching
the query.
A :class:`SequenceDB` contains :class:`Optimizer` or :class:`OptimizationDatabase` objects. Each of them
A :class:`SequenceDB` contains :class:`Optimizer` or :class:`RewriteDatabase` objects. Each of them
has a name, an arbitrary number of tags and an integer representing their order
in the sequence. When a :class:`OptimizationQuery` is applied to a :class:`SequenceDB`, all :class:`Optimizer`\s whose
tags match the query are inserted in proper order in a :class:`SequenceOptimizer`, which
is returned. If the :class:`SequenceDB` contains :class:`OptimizationDatabase`
is returned. If the :class:`SequenceDB` contains :class:`RewriteDatabase`
instances, the :class:`OptimizationQuery` will be passed to them as well and the
optimizers they return will be put in their places.
An :class:`EquilibriumDB` contains :class:`NodeRewriter` or :class:`OptimizationDatabase` objects. Each of them
An :class:`EquilibriumDB` contains :class:`NodeRewriter` or :class:`RewriteDatabase` objects. Each of them
has a name and an arbitrary number of tags. When a :class:`OptimizationQuery` is applied to
an :class:`EquilibriumDB`, all :class:`NodeRewriter`\s that match the query are
inserted into an :class:`EquilibriumGraphRewriter`, which is returned. If the
:class:`SequenceDB` contains :class:`OptimizationDatabase` instances, the
:class:`SequenceDB` contains :class:`RewriteDatabase` instances, the
:class:`OptimizationQuery` will be passed to them as well and the
:class:`NodeRewriter`\s they return will be put in their places
(note that as of yet no :class:`OptimizationDatabase` can produce :class:`NodeRewriter` objects, so this
(note that as of yet no :class:`RewriteDatabase` can produce :class:`NodeRewriter` objects, so this
is a moot point).
Aesara contains one principal :class:`OptimizationDatabase` object, :class:`optdb`, which
Aesara contains one principal :class:`RewriteDatabase` object, :class:`optdb`, which
contains all of Aesara's optimizers with proper tags. It is
recommended to insert new :class:`Optimizer`\s in it. As mentioned previously,
optdb is a :class:`SequenceDB`, so, at the top level, Aesara applies a sequence
......
......@@ -4,8 +4,8 @@ from aesara.graph import opt
from aesara.graph.optdb import (
EquilibriumDB,
LocalGroupDB,
OptimizationDatabase,
ProxyDB,
RewriteDatabase,
SequenceDB,
)
......@@ -19,7 +19,7 @@ class TestOpt(opt.GraphRewriter):
class TestDB:
def test_register(self):
db = OptimizationDatabase()
db = RewriteDatabase()
db.register("a", TestOpt())
db.register("b", TestOpt())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论