提交 92a3b2a6 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename OptimizationQuery to RewriteDatabaseQuery

上级 5dbfd046
...@@ -19,8 +19,8 @@ from aesara.graph.opt import ( ...@@ -19,8 +19,8 @@ from aesara.graph.opt import (
from aesara.graph.optdb import ( from aesara.graph.optdb import (
EquilibriumDB, EquilibriumDB,
LocalGroupDB, LocalGroupDB,
OptimizationQuery,
RewriteDatabase, RewriteDatabase,
RewriteDatabaseQuery,
SequenceDB, SequenceDB,
TopoDB, TopoDB,
) )
...@@ -64,15 +64,15 @@ def register_linker(name, linker): ...@@ -64,15 +64,15 @@ def register_linker(name, linker):
exclude = [] exclude = []
if not config.cxx: if not config.cxx:
exclude = ["cxx_only"] exclude = ["cxx_only"]
OPT_NONE = OptimizationQuery(include=[], exclude=exclude) OPT_NONE = RewriteDatabaseQuery(include=[], exclude=exclude)
# Even if multiple merge optimizer call will be there, this shouldn't # Even if multiple merge optimizer call will be there, this shouldn't
# impact performance. # impact performance.
OPT_MERGE = OptimizationQuery(include=["merge"], exclude=exclude) OPT_MERGE = RewriteDatabaseQuery(include=["merge"], exclude=exclude)
OPT_FAST_RUN = OptimizationQuery(include=["fast_run"], exclude=exclude) OPT_FAST_RUN = RewriteDatabaseQuery(include=["fast_run"], exclude=exclude)
OPT_FAST_RUN_STABLE = OPT_FAST_RUN.requiring("stable") OPT_FAST_RUN_STABLE = OPT_FAST_RUN.requiring("stable")
OPT_FAST_COMPILE = OptimizationQuery(include=["fast_compile"], exclude=exclude) OPT_FAST_COMPILE = RewriteDatabaseQuery(include=["fast_compile"], exclude=exclude)
OPT_STABILIZE = OptimizationQuery(include=["fast_run"], exclude=exclude) OPT_STABILIZE = RewriteDatabaseQuery(include=["fast_run"], exclude=exclude)
OPT_STABILIZE.position_cutoff = 1.5000001 OPT_STABILIZE.position_cutoff = 1.5000001
OPT_NONE.name = "OPT_NONE" OPT_NONE.name = "OPT_NONE"
OPT_MERGE.name = "OPT_MERGE" OPT_MERGE.name = "OPT_MERGE"
...@@ -302,7 +302,7 @@ class Mode: ...@@ -302,7 +302,7 @@ class Mode:
def __init__( def __init__(
self, self,
linker: Optional[Union[str, Linker]] = None, linker: Optional[Union[str, Linker]] = None,
optimizer: Union[str, OptimizationQuery] = "default", optimizer: Union[str, RewriteDatabaseQuery] = "default",
db: RewriteDatabase = None, db: RewriteDatabase = None,
): ):
if linker is None: if linker is None:
...@@ -320,7 +320,7 @@ class Mode: ...@@ -320,7 +320,7 @@ class Mode:
# self.provided_optimizer - typically the `optimizer` arg. # self.provided_optimizer - typically the `optimizer` arg.
# But if the `optimizer` arg is keyword corresponding to a predefined # But if the `optimizer` arg is keyword corresponding to a predefined
# OptimizationQuery, then this stores the query # RewriteDatabaseQuery, then this stores the query
# self._optimizer - typically same as provided_optimizer?? # self._optimizer - typically same as provided_optimizer??
# self.__get_optimizer - returns self._optimizer (possibly querying # self.__get_optimizer - returns self._optimizer (possibly querying
...@@ -342,7 +342,7 @@ class Mode: ...@@ -342,7 +342,7 @@ class Mode:
self.linker = linker self.linker = linker
if isinstance(optimizer, str) or optimizer is None: if isinstance(optimizer, str) or optimizer is None:
optimizer = predefined_optimizers[optimizer] optimizer = predefined_optimizers[optimizer]
if isinstance(optimizer, OptimizationQuery): if isinstance(optimizer, RewriteDatabaseQuery):
self.provided_optimizer = optimizer self.provided_optimizer = optimizer
self._optimizer = optimizer self._optimizer = optimizer
self.call_time = 0 self.call_time = 0
...@@ -357,7 +357,7 @@ class Mode: ...@@ -357,7 +357,7 @@ class Mode:
) )
def __get_optimizer(self): def __get_optimizer(self):
if isinstance(self._optimizer, OptimizationQuery): if isinstance(self._optimizer, RewriteDatabaseQuery):
return self.optdb.query(self._optimizer) return self.optdb.query(self._optimizer)
else: else:
return self._optimizer return self._optimizer
...@@ -375,7 +375,7 @@ class Mode: ...@@ -375,7 +375,7 @@ class Mode:
link, opt = self.get_linker_optimizer( link, opt = self.get_linker_optimizer(
self.provided_linker, self.provided_optimizer self.provided_linker, self.provided_optimizer
) )
# N.B. opt might be a OptimizationQuery instance, not sure what else it might be... # N.B. opt might be a RewriteDatabaseQuery instance, not sure what else it might be...
# string? Optimizer? OptDB? who knows??? # string? Optimizer? OptDB? who knows???
return self.clone(optimizer=opt.including(*tags), linker=link) return self.clone(optimizer=opt.including(*tags), linker=link)
...@@ -448,11 +448,11 @@ else: ...@@ -448,11 +448,11 @@ else:
JAX = Mode( JAX = Mode(
JAXLinker(), JAXLinker(),
OptimizationQuery(include=["fast_run"], exclude=["cxx_only", "BlasOpt"]), RewriteDatabaseQuery(include=["fast_run"], exclude=["cxx_only", "BlasOpt"]),
) )
NUMBA = Mode( NUMBA = Mode(
NumbaLinker(), NumbaLinker(),
OptimizationQuery(include=["fast_run"], exclude=["cxx_only", "BlasOpt"]), RewriteDatabaseQuery(include=["fast_run"], exclude=["cxx_only", "BlasOpt"]),
) )
......
...@@ -15,6 +15,6 @@ from aesara.graph.type import Type ...@@ -15,6 +15,6 @@ from aesara.graph.type import Type
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import node_rewriter, graph_rewriter from aesara.graph.opt import node_rewriter, graph_rewriter
from aesara.graph.opt_utils import optimize_graph from aesara.graph.opt_utils import optimize_graph
from aesara.graph.optdb import OptimizationQuery from aesara.graph.optdb import RewriteDatabaseQuery
# isort: on # isort: on
...@@ -10,7 +10,7 @@ from aesara.graph.basic import ( ...@@ -10,7 +10,7 @@ from aesara.graph.basic import (
vars_between, vars_between,
) )
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.optdb import OptimizationQuery from aesara.graph.optdb import RewriteDatabaseQuery
def optimize_graph( def optimize_graph(
...@@ -34,7 +34,7 @@ def optimize_graph( ...@@ -34,7 +34,7 @@ def optimize_graph(
clone: clone:
Whether or not to clone the input graph before optimizing. Whether or not to clone the input graph before optimizing.
**kwargs: **kwargs:
Keyword arguments passed to the ``aesara.graph.optdb.OptimizationQuery`` object. Keyword arguments passed to the ``aesara.graph.optdb.RewriteDatabaseQuery`` object.
""" """
from aesara.compile import optdb from aesara.compile import optdb
...@@ -43,7 +43,7 @@ def optimize_graph( ...@@ -43,7 +43,7 @@ def optimize_graph(
fgraph = FunctionGraph(outputs=[fgraph], clone=clone) fgraph = FunctionGraph(outputs=[fgraph], clone=clone)
return_only_out = True return_only_out = True
canonicalize_opt = optdb.query(OptimizationQuery(include=include, **kwargs)) canonicalize_opt = optdb.query(RewriteDatabaseQuery(include=include, **kwargs))
_ = canonicalize_opt.optimize(fgraph) _ = canonicalize_opt.optimize(fgraph)
if custom_opt: if custom_opt:
......
...@@ -137,10 +137,10 @@ class RewriteDatabase: ...@@ -137,10 +137,10 @@ class RewriteDatabase:
return variables return variables
def query(self, *tags, **kwtags): def query(self, *tags, **kwtags):
if len(tags) >= 1 and isinstance(tags[0], OptimizationQuery): if len(tags) >= 1 and isinstance(tags[0], RewriteDatabaseQuery):
if len(tags) > 1 or kwtags: if len(tags) > 1 or kwtags:
raise TypeError( raise TypeError(
"If the first argument to query is an `OptimizationQuery`," "If the first argument to query is an `RewriteDatabaseQuery`,"
" there should be no other arguments." " there should be no other arguments."
) )
return self.__query__(tags[0]) return self.__query__(tags[0])
...@@ -153,7 +153,7 @@ class RewriteDatabase: ...@@ -153,7 +153,7 @@ class RewriteDatabase:
" characters: '+', '&' or '-'" " characters: '+', '&' or '-'"
) )
return self.__query__( return self.__query__(
OptimizationQuery( RewriteDatabaseQuery(
include=include, require=require, exclude=exclude, subquery=kwtags include=include, require=require, exclude=exclude, subquery=kwtags
) )
) )
...@@ -176,7 +176,7 @@ class RewriteDatabase: ...@@ -176,7 +176,7 @@ class RewriteDatabase:
print(" db", self.__db__, file=stream) print(" db", self.__db__, file=stream)
class OptimizationQuery: class RewriteDatabaseQuery:
"""An object that specifies a set of optimizations by tag/name.""" """An object that specifies a set of optimizations by tag/name."""
def __init__( def __init__(
...@@ -184,11 +184,11 @@ class OptimizationQuery: ...@@ -184,11 +184,11 @@ class OptimizationQuery:
include: Iterable[str], include: Iterable[str],
require: Optional[Union[OrderedSet, Sequence[str]]] = None, require: Optional[Union[OrderedSet, Sequence[str]]] = None,
exclude: Optional[Union[OrderedSet, Sequence[str]]] = None, exclude: Optional[Union[OrderedSet, Sequence[str]]] = None,
subquery: Optional[Dict[str, "OptimizationQuery"]] = None, subquery: Optional[Dict[str, "RewriteDatabaseQuery"]] = None,
position_cutoff: float = math.inf, position_cutoff: float = math.inf,
extra_optimizations: Optional[ extra_optimizations: Optional[
Sequence[ Sequence[
Tuple[Union["OptimizationQuery", OptimizersType], Union[int, float]] Tuple[Union["RewriteDatabaseQuery", OptimizersType], Union[int, float]]
] ]
] = None, ] = None,
): ):
...@@ -198,19 +198,19 @@ class OptimizationQuery: ...@@ -198,19 +198,19 @@ class OptimizationQuery:
========== ==========
include: include:
A set of tags such that every optimization obtained through this A set of tags such that every optimization obtained through this
`OptimizationQuery` must have **one** of the tags listed. This `RewriteDatabaseQuery` must have **one** of the tags listed. This
field is required and basically acts as a starting point for the field is required and basically acts as a starting point for the
search. search.
require: require:
A set of tags such that every optimization obtained through this A set of tags such that every optimization obtained through this
`OptimizationQuery` must have **all** of these tags. `RewriteDatabaseQuery` must have **all** of these tags.
exclude: exclude:
A set of tags such that every optimization obtained through this A set of tags such that every optimization obtained through this
``OptimizationQuery` must have **none** of these tags. ``RewriteDatabaseQuery` must have **none** of these tags.
subquery: subquery:
A dictionary mapping the name of a sub-database to a special A dictionary mapping the name of a sub-database to a special
`OptimizationQuery`. If no subquery is given for a sub-database, `RewriteDatabaseQuery`. If no subquery is given for a sub-database,
the original `OptimizationQuery` will be used again. the original `RewriteDatabaseQuery` will be used again.
position_cutoff: position_cutoff:
Only optimizations with position less than the cutoff are returned. Only optimizations with position less than the cutoff are returned.
extra_optimizations: extra_optimizations:
...@@ -229,7 +229,7 @@ class OptimizationQuery: ...@@ -229,7 +229,7 @@ class OptimizationQuery:
def __str__(self): def __str__(self):
return ( return (
"OptimizationQuery(" "RewriteDatabaseQuery("
+ f"inc={self.include},ex={self.exclude}," + f"inc={self.include},ex={self.exclude},"
+ f"require={self.require},subquery={self.subquery}," + f"require={self.require},subquery={self.subquery},"
+ f"position_cutoff={self.position_cutoff}," + f"position_cutoff={self.position_cutoff},"
...@@ -241,9 +241,9 @@ class OptimizationQuery: ...@@ -241,9 +241,9 @@ class OptimizationQuery:
if not hasattr(self, "extra_optimizations"): if not hasattr(self, "extra_optimizations"):
self.extra_optimizations = [] self.extra_optimizations = []
def including(self, *tags: str) -> "OptimizationQuery": def including(self, *tags: str) -> "RewriteDatabaseQuery":
"""Add rewrites with the given tags.""" """Add rewrites with the given tags."""
return OptimizationQuery( return RewriteDatabaseQuery(
self.include.union(tags), self.include.union(tags),
self.require, self.require,
self.exclude, self.exclude,
...@@ -252,9 +252,9 @@ class OptimizationQuery: ...@@ -252,9 +252,9 @@ class OptimizationQuery:
self.extra_optimizations, self.extra_optimizations,
) )
def excluding(self, *tags: str) -> "OptimizationQuery": def excluding(self, *tags: str) -> "RewriteDatabaseQuery":
"""Remove rewrites with the given tags.""" """Remove rewrites with the given tags."""
return OptimizationQuery( return RewriteDatabaseQuery(
self.include, self.include,
self.require, self.require,
self.exclude.union(tags), self.exclude.union(tags),
...@@ -263,9 +263,9 @@ class OptimizationQuery: ...@@ -263,9 +263,9 @@ class OptimizationQuery:
self.extra_optimizations, self.extra_optimizations,
) )
def requiring(self, *tags: str) -> "OptimizationQuery": def requiring(self, *tags: str) -> "RewriteDatabaseQuery":
"""Filter for rewrites with the given tags.""" """Filter for rewrites with the given tags."""
return OptimizationQuery( return RewriteDatabaseQuery(
self.include, self.include,
self.require.union(tags), self.require.union(tags),
self.exclude, self.exclude,
...@@ -275,10 +275,10 @@ class OptimizationQuery: ...@@ -275,10 +275,10 @@ class OptimizationQuery:
) )
def register( def register(
self, *optimizations: Tuple["OptimizationQuery", Union[int, float]] self, *optimizations: Tuple["RewriteDatabaseQuery", Union[int, float]]
) -> "OptimizationQuery": ) -> "RewriteDatabaseQuery":
"""Include the given optimizations.""" """Include the given optimizations."""
return OptimizationQuery( return RewriteDatabaseQuery(
self.include, self.include,
self.require, self.require,
self.exclude, self.exclude,
...@@ -417,13 +417,13 @@ class SequenceDB(RewriteDatabase): ...@@ -417,13 +417,13 @@ class SequenceDB(RewriteDatabase):
position_dict = self.__position__ position_dict = self.__position__
if len(tags) >= 1 and isinstance(tags[0], OptimizationQuery): if len(tags) >= 1 and isinstance(tags[0], RewriteDatabaseQuery):
# the call to super should have raise an error with a good message # the call to super should have raise an error with a good message
assert len(tags) == 1 assert len(tags) == 1
if getattr(tags[0], "position_cutoff", None): if getattr(tags[0], "position_cutoff", None):
position_cutoff = tags[0].position_cutoff position_cutoff = tags[0].position_cutoff
# The OptimizationQuery instance might contain extra optimizations which need # The RewriteDatabaseQuery instance might contain extra optimizations which need
# to be added the the sequence of optimizations (don't alter the # to be added the the sequence of optimizations (don't alter the
# original dictionary) # original dictionary)
if len(tags[0].extra_optimizations) > 0: if len(tags[0].extra_optimizations) > 0:
...@@ -544,14 +544,19 @@ DEPRECATED_NAMES = [ ...@@ -544,14 +544,19 @@ DEPRECATED_NAMES = [
), ),
( (
"Query", "Query",
"`Query` is deprecated; use `OptimizationQuery` instead.", "`Query` is deprecated; use `RewriteDatabaseQuery` instead.",
OptimizationQuery, RewriteDatabaseQuery,
), ),
( (
"OptimizationDatabase", "OptimizationDatabase",
"`OptimizationDatabase` is deprecated; use `RewriteDatabase` instead.", "`OptimizationDatabase` is deprecated; use `RewriteDatabase` instead.",
RewriteDatabase, RewriteDatabase,
), ),
(
"OptimizationQuery",
"`OptimizationQuery` is deprecated; use `RewriteDatabaseQuery` instead.",
RewriteDatabaseQuery,
),
] ]
......
...@@ -585,23 +585,23 @@ Definition of :obj:`optdb` ...@@ -585,23 +585,23 @@ Definition of :obj:`optdb`
:class:`SequenceDB <optdb.SequenceDB>`, :class:`SequenceDB <optdb.SequenceDB>`,
itself a subclass of :class:`RewriteDatabase <optdb.RewriteDatabase>`. itself a subclass of :class:`RewriteDatabase <optdb.RewriteDatabase>`.
There exist (for now) two types of :class:`RewriteDatabase`, :class:`SequenceDB` and :class:`EquilibriumDB`. There exist (for now) two types of :class:`RewriteDatabase`, :class:`SequenceDB` and :class:`EquilibriumDB`.
When given an appropriate :class:`OptimizationQuery`, :class:`RewriteDatabase` objects build an :class:`Optimizer` matching When given an appropriate :class:`RewriteDatabaseQuery`, :class:`RewriteDatabase` objects build an :class:`Optimizer` matching
the query. the query.
A :class:`SequenceDB` contains :class:`Optimizer` or :class:`RewriteDatabase` objects. Each of them A :class:`SequenceDB` contains :class:`Optimizer` or :class:`RewriteDatabase` objects. Each of them
has a name, an arbitrary number of tags and an integer representing their order has a name, an arbitrary number of tags and an integer representing their order
in the sequence. When a :class:`OptimizationQuery` is applied to a :class:`SequenceDB`, all :class:`Optimizer`\s whose in the sequence. When a :class:`RewriteDatabaseQuery` is applied to a :class:`SequenceDB`, all :class:`Optimizer`\s whose
tags match the query are inserted in proper order in a :class:`SequenceOptimizer`, which tags match the query are inserted in proper order in a :class:`SequenceOptimizer`, which
is returned. If the :class:`SequenceDB` contains :class:`RewriteDatabase` is returned. If the :class:`SequenceDB` contains :class:`RewriteDatabase`
instances, the :class:`OptimizationQuery` will be passed to them as well and the instances, the :class:`RewriteDatabaseQuery` will be passed to them as well and the
optimizers they return will be put in their places. optimizers they return will be put in their places.
An :class:`EquilibriumDB` contains :class:`NodeRewriter` or :class:`RewriteDatabase` objects. Each of them An :class:`EquilibriumDB` contains :class:`NodeRewriter` or :class:`RewriteDatabase` objects. Each of them
has a name and an arbitrary number of tags. When a :class:`OptimizationQuery` is applied to has a name and an arbitrary number of tags. When a :class:`RewriteDatabaseQuery` is applied to
an :class:`EquilibriumDB`, all :class:`NodeRewriter`\s that match the query are an :class:`EquilibriumDB`, all :class:`NodeRewriter`\s that match the query are
inserted into an :class:`EquilibriumGraphRewriter`, which is returned. If the inserted into an :class:`EquilibriumGraphRewriter`, which is returned. If the
:class:`SequenceDB` contains :class:`RewriteDatabase` instances, the :class:`SequenceDB` contains :class:`RewriteDatabase` instances, the
:class:`OptimizationQuery` will be passed to them as well and the :class:`RewriteDatabaseQuery` will be passed to them as well and the
:class:`NodeRewriter`\s they return will be put in their places :class:`NodeRewriter`\s they return will be put in their places
(note that as of yet no :class:`RewriteDatabase` can produce :class:`NodeRewriter` objects, so this (note that as of yet no :class:`RewriteDatabase` can produce :class:`NodeRewriter` objects, so this
is a moot point). is a moot point).
...@@ -613,68 +613,68 @@ optdb is a :class:`SequenceDB`, so, at the top level, Aesara applies a sequence ...@@ -613,68 +613,68 @@ optdb is a :class:`SequenceDB`, so, at the top level, Aesara applies a sequence
of global optimizations to the computation graphs. of global optimizations to the computation graphs.
:class:`OptimizationQuery` :class:`RewriteDatabaseQuery`
-------------------------- -----------------------------
A :class:`OptimizationQuery` is built by the following call: A :class:`RewriteDatabaseQuery` is built by the following call:
.. code-block:: python .. code-block:: python
aesara.graph.optdb.OptimizationQuery(include, require=None, exclude=None, subquery=None) aesara.graph.optdb.RewriteDatabaseQuery(include, require=None, exclude=None, subquery=None)
.. class:: OptimizationQuery .. class:: RewriteDatabaseQuery
.. attribute:: include .. attribute:: include
A set of tags (a tag being a string) such that every A set of tags (a tag being a string) such that every
optimization obtained through this :class:`OptimizationQuery` must have **one** of the tags optimization obtained through this :class:`RewriteDatabaseQuery` must have **one** of the tags
listed. This field is required and basically acts as a starting point listed. This field is required and basically acts as a starting point
for the search. for the search.
.. attribute:: require .. attribute:: require
A set of tags such that every optimization obtained A set of tags such that every optimization obtained
through this :class:`OptimizationQuery` must have **all** of these tags. through this :class:`RewriteDatabaseQuery` must have **all** of these tags.
.. attribute:: exclude .. attribute:: exclude
A set of tags such that every optimization obtained A set of tags such that every optimization obtained
through this :class:`OptimizationQuery` must have **none** of these tags. through this :class:`RewriteDatabaseQuery` must have **none** of these tags.
.. attribute:: subquery .. attribute:: subquery
:obj:`optdb` can contain sub-databases; subquery is a :obj:`optdb` can contain sub-databases; subquery is a
dictionary mapping the name of a sub-database to a special :class:`OptimizationQuery`. dictionary mapping the name of a sub-database to a special :class:`RewriteDatabaseQuery`.
If no subquery is given for a sub-database, the original :class:`OptimizationQuery` will be If no subquery is given for a sub-database, the original :class:`RewriteDatabaseQuery` will be
used again. used again.
Furthermore, a :class:`OptimizationQuery` object includes three methods, :meth:`including`, Furthermore, a :class:`RewriteDatabaseQuery` object includes three methods, :meth:`including`,
:meth:`requiring` and :meth:`excluding`, which each produce a new :class:`OptimizationQuery` object :meth:`requiring` and :meth:`excluding`, which each produce a new :class:`RewriteDatabaseQuery` object
with the include, require, and exclude sets refined to contain the new entries. with the include, require, and exclude sets refined to contain the new entries.
Examples Examples
-------- --------
Here are a few examples of how to use a :class:`OptimizationQuery` on :obj:`optdb` to produce an Here are a few examples of how to use a :class:`RewriteDatabaseQuery` on :obj:`optdb` to produce an
:class:`Optimizer`: :class:`Optimizer`:
.. testcode:: .. testcode::
from aesara.graph.optdb import OptimizationQuery from aesara.graph.optdb import RewriteDatabaseQuery
from aesara.compile import optdb from aesara.compile import optdb
# This is how the optimizer for the fast_run mode is defined # This is how the optimizer for the fast_run mode is defined
fast_run = optdb.query(OptimizationQuery(include=['fast_run'])) fast_run = optdb.query(RewriteDatabaseQuery(include=['fast_run']))
# This is how the optimizer for the fast_compile mode is defined # This is how the optimizer for the fast_compile mode is defined
fast_compile = optdb.query(OptimizationQuery(include=['fast_compile'])) fast_compile = optdb.query(RewriteDatabaseQuery(include=['fast_compile']))
# This is the same as fast_run but no optimizations will replace # This is the same as fast_run but no optimizations will replace
# any operation by an inplace version. This assumes, of course, # any operation by an inplace version. This assumes, of course,
# that all inplace operations are tagged as 'inplace' (as they # that all inplace operations are tagged as 'inplace' (as they
# should!) # should!)
fast_run_no_inplace = optdb.query(OptimizationQuery(include=['fast_run'], fast_run_no_inplace = optdb.query(RewriteDatabaseQuery(include=['fast_run'],
exclude=['inplace'])) exclude=['inplace']))
...@@ -733,7 +733,7 @@ optimizations: ...@@ -733,7 +733,7 @@ optimizations:
For each group, all optimizations of the group that are selected by For each group, all optimizations of the group that are selected by
the :class:`OptimizationQuery` will be applied on the graph over and over again until none the :class:`RewriteDatabaseQuery` will be applied on the graph over and over again until none
of them is applicable, so keep that in mind when designing it: check of them is applicable, so keep that in mind when designing it: check
carefully that your optimization leads to a fixpoint (a point where it carefully that your optimization leads to a fixpoint (a point where it
cannot apply anymore) at which point it returns ``False`` to indicate its cannot apply anymore) at which point it returns ``False`` to indicate its
......
from aesara.compile.function import function from aesara.compile.function import function
from aesara.compile.mode import AddFeatureOptimizer, Mode from aesara.compile.mode import AddFeatureOptimizer, Mode
from aesara.graph.features import NoOutputFromInplace from aesara.graph.features import NoOutputFromInplace
from aesara.graph.optdb import OptimizationQuery, SequenceDB from aesara.graph.optdb import RewriteDatabaseQuery, SequenceDB
from aesara.tensor.math import dot, tanh from aesara.tensor.math import dot, tanh
from aesara.tensor.type import matrix from aesara.tensor.type import matrix
def test_Mode_basic(): def test_Mode_basic():
db = SequenceDB() db = SequenceDB()
mode = Mode(linker="py", optimizer=OptimizationQuery(include=None), db=db) mode = Mode(linker="py", optimizer=RewriteDatabaseQuery(include=None), db=db)
assert mode.optdb is db assert mode.optdb is db
assert str(mode).startswith("Mode(linker=py, optimizer=OptimizationQuery") assert str(mode).startswith("Mode(linker=py, optimizer=RewriteDatabaseQuery")
def test_NoOutputFromInplace(): def test_NoOutputFromInplace():
......
...@@ -15,7 +15,7 @@ from aesara.configdefaults import config ...@@ -15,7 +15,7 @@ from aesara.configdefaults import config
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op, get_test_value from aesara.graph.op import Op, get_test_value
from aesara.graph.optdb import OptimizationQuery from aesara.graph.optdb import RewriteDatabaseQuery
from aesara.ifelse import ifelse from aesara.ifelse import ifelse
from aesara.link.jax import JAXLinker from aesara.link.jax import JAXLinker
from aesara.raise_op import assert_op from aesara.raise_op import assert_op
...@@ -56,7 +56,7 @@ from aesara.tensor.type import ( ...@@ -56,7 +56,7 @@ from aesara.tensor.type import (
jax = pytest.importorskip("jax") jax = pytest.importorskip("jax")
opts = OptimizationQuery(include=[None], exclude=["cxx_only", "BlasOpt"]) opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_mode = Mode(JAXLinker(), opts) jax_mode = Mode(JAXLinker(), opts)
py_mode = Mode("py", opts) py_mode = Mode("py", opts)
...@@ -1142,7 +1142,7 @@ def test_jax_BatchedDot(): ...@@ -1142,7 +1142,7 @@ def test_jax_BatchedDot():
# A dimension mismatch should raise a TypeError for compatibility # A dimension mismatch should raise a TypeError for compatibility
inputs = [get_test_value(a)[:-1], get_test_value(b)] inputs = [get_test_value(a)[:-1], get_test_value(b)]
opts = OptimizationQuery(include=[None], exclude=["cxx_only", "BlasOpt"]) opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_mode = Mode(JAXLinker(), opts) jax_mode = Mode(JAXLinker(), opts)
aesara_jax_fn = function(fgraph.inputs, fgraph.outputs, mode=jax_mode) aesara_jax_fn = function(fgraph.inputs, fgraph.outputs, mode=jax_mode)
with pytest.raises(TypeError): with pytest.raises(TypeError):
......
...@@ -24,7 +24,7 @@ from aesara.compile.sharedvalue import SharedVariable ...@@ -24,7 +24,7 @@ from aesara.compile.sharedvalue import SharedVariable
from aesara.graph.basic import Apply, Constant from aesara.graph.basic import Apply, Constant
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op, get_test_value from aesara.graph.op import Op, get_test_value
from aesara.graph.optdb import OptimizationQuery from aesara.graph.optdb import RewriteDatabaseQuery
from aesara.graph.type import Type from aesara.graph.type import Type
from aesara.ifelse import ifelse from aesara.ifelse import ifelse
from aesara.link.numba.dispatch import basic as numba_basic from aesara.link.numba.dispatch import basic as numba_basic
...@@ -92,7 +92,7 @@ my_multi_out.ufunc = MyMultiOut.impl ...@@ -92,7 +92,7 @@ my_multi_out.ufunc = MyMultiOut.impl
my_multi_out.ufunc.nin = 2 my_multi_out.ufunc.nin = 2
my_multi_out.ufunc.nout = 2 my_multi_out.ufunc.nout = 2
opts = OptimizationQuery(include=[None], exclude=["cxx_only", "BlasOpt"]) opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
numba_mode = Mode(NumbaLinker(), opts) numba_mode = Mode(NumbaLinker(), opts)
py_mode = Mode("py", opts) py_mode = Mode("py", opts)
......
...@@ -7,12 +7,12 @@ import aesara.tensor as aet ...@@ -7,12 +7,12 @@ import aesara.tensor as aet
from aesara import config from aesara import config
from aesara.compile.function import function from aesara.compile.function import function
from aesara.compile.mode import Mode from aesara.compile.mode import Mode
from aesara.graph.optdb import OptimizationQuery from aesara.graph.optdb import RewriteDatabaseQuery
from aesara.link.numba.linker import NumbaLinker from aesara.link.numba.linker import NumbaLinker
from aesara.tensor.math import Max from aesara.tensor.math import Max
opts = OptimizationQuery(include=[None], exclude=["cxx_only", "BlasOpt"]) opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
numba_mode = Mode(NumbaLinker(), opts) numba_mode = Mode(NumbaLinker(), opts)
py_mode = Mode("py", opts) py_mode = Mode("py", opts)
......
...@@ -14,7 +14,7 @@ from aesara.configdefaults import config ...@@ -14,7 +14,7 @@ from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable, graph_inputs from aesara.graph.basic import Constant, Variable, graph_inputs
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.op import get_test_value from aesara.graph.op import get_test_value
from aesara.graph.optdb import OptimizationQuery from aesara.graph.optdb import RewriteDatabaseQuery
from aesara.tensor.basic_opt import ShapeFeature from aesara.tensor.basic_opt import ShapeFeature
from aesara.tensor.random.basic import ( from aesara.tensor.random.basic import (
bernoulli, bernoulli,
...@@ -60,7 +60,7 @@ from aesara.tensor.type import iscalar, scalar, tensor ...@@ -60,7 +60,7 @@ from aesara.tensor.type import iscalar, scalar, tensor
from tests.unittest_tools import create_aesara_param from tests.unittest_tools import create_aesara_param
opts = OptimizationQuery(include=[None], exclude=["cxx_only", "BlasOpt"]) opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
py_mode = Mode("py", opts) py_mode = Mode("py", opts)
......
...@@ -8,7 +8,7 @@ from aesara.compile.mode import Mode ...@@ -8,7 +8,7 @@ from aesara.compile.mode import Mode
from aesara.graph.basic import Constant from aesara.graph.basic import Constant
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import EquilibriumGraphRewriter from aesara.graph.opt import EquilibriumGraphRewriter
from aesara.graph.optdb import OptimizationQuery from aesara.graph.optdb import RewriteDatabaseQuery
from aesara.tensor.elemwise import DimShuffle from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.random.basic import ( from aesara.tensor.random.basic import (
dirichlet, dirichlet,
...@@ -28,7 +28,7 @@ from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subte ...@@ -28,7 +28,7 @@ from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subte
from aesara.tensor.type import iscalar, vector from aesara.tensor.type import iscalar, vector
no_mode = Mode("py", OptimizationQuery(include=[], exclude=[])) no_mode = Mode("py", RewriteDatabaseQuery(include=[], exclude=[]))
def apply_local_opt_to_rv(opt, op_fn, dist_op, dist_params, size, rng, name=None): def apply_local_opt_to_rv(opt, op_fn, dist_op, dist_params, size, rng, name=None):
......
...@@ -3,7 +3,7 @@ import pytest ...@@ -3,7 +3,7 @@ import pytest
from aesara import config, function from aesara import config, function
from aesara.compile.mode import Mode from aesara.compile.mode import Mode
from aesara.graph.optdb import OptimizationQuery from aesara.graph.optdb import RewriteDatabaseQuery
from aesara.tensor.random.utils import RandomStream, broadcast_params from aesara.tensor.random.utils import RandomStream, broadcast_params
from aesara.tensor.type import matrix, tensor from aesara.tensor.type import matrix, tensor
from tests import unittest_tools as utt from tests import unittest_tools as utt
...@@ -11,7 +11,7 @@ from tests import unittest_tools as utt ...@@ -11,7 +11,7 @@ from tests import unittest_tools as utt
@pytest.fixture(scope="module", autouse=True) @pytest.fixture(scope="module", autouse=True)
def set_aesara_flags(): def set_aesara_flags():
opts = OptimizationQuery(include=[None], exclude=[]) opts = RewriteDatabaseQuery(include=[None], exclude=[])
py_mode = Mode("py", opts) py_mode = Mode("py", opts)
with config.change_flags(mode=py_mode, compute_test_value="warn"): with config.change_flags(mode=py_mode, compute_test_value="warn"):
yield yield
......
...@@ -18,7 +18,7 @@ from aesara.graph.fg import FunctionGraph ...@@ -18,7 +18,7 @@ from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.opt import check_stack_trace, node_rewriter, out2in from aesara.graph.opt import check_stack_trace, node_rewriter, out2in
from aesara.graph.opt_utils import optimize_graph from aesara.graph.opt_utils import optimize_graph
from aesara.graph.optdb import OptimizationQuery from aesara.graph.optdb import RewriteDatabaseQuery
from aesara.graph.type import Type from aesara.graph.type import Type
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.printing import pprint from aesara.printing import pprint
...@@ -141,15 +141,15 @@ mode_opt = get_mode(mode_opt) ...@@ -141,15 +141,15 @@ mode_opt = get_mode(mode_opt)
dimshuffle_lift = out2in(local_dimshuffle_lift) dimshuffle_lift = out2in(local_dimshuffle_lift)
_optimizer_stabilize = OptimizationQuery(include=["fast_run"]) _optimizer_stabilize = RewriteDatabaseQuery(include=["fast_run"])
_optimizer_stabilize.position_cutoff = 1.51 _optimizer_stabilize.position_cutoff = 1.51
_optimizer_stabilize = optdb.query(_optimizer_stabilize) _optimizer_stabilize = optdb.query(_optimizer_stabilize)
_optimizer_specialize = OptimizationQuery(include=["fast_run"]) _optimizer_specialize = RewriteDatabaseQuery(include=["fast_run"])
_optimizer_specialize.position_cutoff = 2.01 _optimizer_specialize.position_cutoff = 2.01
_optimizer_specialize = optdb.query(_optimizer_specialize) _optimizer_specialize = optdb.query(_optimizer_specialize)
_optimizer_fast_run = OptimizationQuery(include=["fast_run"]) _optimizer_fast_run = RewriteDatabaseQuery(include=["fast_run"])
_optimizer_fast_run = optdb.query(_optimizer_fast_run) _optimizer_fast_run = optdb.query(_optimizer_fast_run)
...@@ -352,7 +352,7 @@ def test_local_useless_dimshuffle_in_reshape(): ...@@ -352,7 +352,7 @@ def test_local_useless_dimshuffle_in_reshape():
class TestFusion: class TestFusion:
opts = OptimizationQuery( opts = RewriteDatabaseQuery(
include=[ include=[
"local_elemwise_fusion", "local_elemwise_fusion",
"composite_elemwise_fusion", "composite_elemwise_fusion",
...@@ -1099,7 +1099,7 @@ class TestFusion: ...@@ -1099,7 +1099,7 @@ class TestFusion:
def test_add_mul_fusion_inplace(self): def test_add_mul_fusion_inplace(self):
opts = OptimizationQuery( opts = RewriteDatabaseQuery(
include=[ include=[
"local_elemwise_fusion", "local_elemwise_fusion",
"composite_elemwise_fusion", "composite_elemwise_fusion",
...@@ -1165,7 +1165,7 @@ class TestFusion: ...@@ -1165,7 +1165,7 @@ class TestFusion:
""" """
opts = OptimizationQuery( opts = RewriteDatabaseQuery(
include=[ include=[
"local_elemwise_fusion", "local_elemwise_fusion",
"composite_elemwise_fusion", "composite_elemwise_fusion",
......
...@@ -9,7 +9,7 @@ from aesara import tensor as at ...@@ -9,7 +9,7 @@ from aesara import tensor as at
from aesara.compile.mode import Mode from aesara.compile.mode import Mode
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Constant, applys_between from aesara.graph.basic import Constant, applys_between
from aesara.graph.optdb import OptimizationQuery from aesara.graph.optdb import RewriteDatabaseQuery
from aesara.raise_op import Assert from aesara.raise_op import Assert
from aesara.tensor.elemwise import DimShuffle from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.extra_ops import ( from aesara.tensor.extra_ops import (
...@@ -1285,7 +1285,7 @@ class TestBroadcastTo(utt.InferShapeTester): ...@@ -1285,7 +1285,7 @@ class TestBroadcastTo(utt.InferShapeTester):
q = b[np.r_[0, 1, 3]] q = b[np.r_[0, 1, 3]]
e = at.set_subtensor(q, np.r_[0, 0, 0]) e = at.set_subtensor(q, np.r_[0, 0, 0])
opts = OptimizationQuery(include=["inplace"]) opts = RewriteDatabaseQuery(include=["inplace"])
py_mode = Mode("py", opts) py_mode = Mode("py", opts)
e_fn = function([d], e, mode=py_mode) e_fn = function([d], e, mode=py_mode)
......
...@@ -26,7 +26,7 @@ from aesara.graph.opt import ( ...@@ -26,7 +26,7 @@ from aesara.graph.opt import (
out2in, out2in,
) )
from aesara.graph.opt_utils import is_same_graph, optimize_graph from aesara.graph.opt_utils import is_same_graph, optimize_graph
from aesara.graph.optdb import OptimizationQuery from aesara.graph.optdb import RewriteDatabaseQuery
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.tensor import inplace from aesara.tensor import inplace
from aesara.tensor.basic import Alloc, join, switch from aesara.tensor.basic import Alloc, join, switch
...@@ -132,15 +132,15 @@ mode_opt = get_mode(mode_opt) ...@@ -132,15 +132,15 @@ mode_opt = get_mode(mode_opt)
dimshuffle_lift = out2in(local_dimshuffle_lift) dimshuffle_lift = out2in(local_dimshuffle_lift)
_optimizer_stabilize = OptimizationQuery(include=["fast_run"]) _optimizer_stabilize = RewriteDatabaseQuery(include=["fast_run"])
_optimizer_stabilize.position_cutoff = 1.51 _optimizer_stabilize.position_cutoff = 1.51
_optimizer_stabilize = optdb.query(_optimizer_stabilize) _optimizer_stabilize = optdb.query(_optimizer_stabilize)
_optimizer_specialize = OptimizationQuery(include=["fast_run"]) _optimizer_specialize = RewriteDatabaseQuery(include=["fast_run"])
_optimizer_specialize.position_cutoff = 2.01 _optimizer_specialize.position_cutoff = 2.01
_optimizer_specialize = optdb.query(_optimizer_specialize) _optimizer_specialize = optdb.query(_optimizer_specialize)
_optimizer_fast_run = OptimizationQuery(include=["fast_run"]) _optimizer_fast_run = RewriteDatabaseQuery(include=["fast_run"])
_optimizer_fast_run = optdb.query(_optimizer_fast_run) _optimizer_fast_run = optdb.query(_optimizer_fast_run)
...@@ -366,7 +366,7 @@ class TestAlgebraicCanonizer: ...@@ -366,7 +366,7 @@ class TestAlgebraicCanonizer:
# We must be sure that the AlgebraicCanonizer is working, but that we don't have other # We must be sure that the AlgebraicCanonizer is working, but that we don't have other
# optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion # optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion
mode = get_default_mode() mode = get_default_mode()
opt = OptimizationQuery(["canonicalize"]) opt = RewriteDatabaseQuery(["canonicalize"])
opt = opt.excluding("local_elemwise_fusion") opt = opt.excluding("local_elemwise_fusion")
mode = mode.__class__(linker=mode.linker, optimizer=opt) mode = mode.__class__(linker=mode.linker, optimizer=opt)
for id, [g, sym_inputs, val_inputs, nb_elemwise, out_dtype] in enumerate(cases): for id, [g, sym_inputs, val_inputs, nb_elemwise, out_dtype] in enumerate(cases):
...@@ -500,7 +500,7 @@ class TestAlgebraicCanonizer: ...@@ -500,7 +500,7 @@ class TestAlgebraicCanonizer:
# We must be sure that the AlgebraicCanonizer is working, but that we don't have other # We must be sure that the AlgebraicCanonizer is working, but that we don't have other
# optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion # optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion
mode = get_default_mode() mode = get_default_mode()
mode._optimizer = OptimizationQuery(["canonicalize"]) mode._optimizer = RewriteDatabaseQuery(["canonicalize"])
mode._optimizer = mode._optimizer.excluding("local_elemwise_fusion") mode._optimizer = mode._optimizer.excluding("local_elemwise_fusion")
for id, [g, sym_inputs, val_inputs, nb_elemwise, out_dtype] in enumerate(cases): for id, [g, sym_inputs, val_inputs, nb_elemwise, out_dtype] in enumerate(cases):
f = function( f = function(
...@@ -547,7 +547,7 @@ class TestAlgebraicCanonizer: ...@@ -547,7 +547,7 @@ class TestAlgebraicCanonizer:
# optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion # optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion
mode = get_default_mode() mode = get_default_mode()
opt = OptimizationQuery(["canonicalize"]) opt = RewriteDatabaseQuery(["canonicalize"])
opt = opt.including("ShapeOpt", "local_fill_to_alloc") opt = opt.including("ShapeOpt", "local_fill_to_alloc")
opt = opt.excluding("local_elemwise_fusion") opt = opt.excluding("local_elemwise_fusion")
mode = mode.__class__(linker=mode.linker, optimizer=opt) mode = mode.__class__(linker=mode.linker, optimizer=opt)
...@@ -907,7 +907,7 @@ class TestAlgebraicCanonizer: ...@@ -907,7 +907,7 @@ class TestAlgebraicCanonizer:
# optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion # optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion
mode = get_default_mode() mode = get_default_mode()
opt = OptimizationQuery(["canonicalize"]) opt = RewriteDatabaseQuery(["canonicalize"])
opt = opt.excluding("local_elemwise_fusion") opt = opt.excluding("local_elemwise_fusion")
mode = mode.__class__(linker=mode.linker, optimizer=opt) mode = mode.__class__(linker=mode.linker, optimizer=opt)
# test fail! # test fail!
...@@ -1074,7 +1074,7 @@ def test_cast_in_mul_canonizer(): ...@@ -1074,7 +1074,7 @@ def test_cast_in_mul_canonizer():
class TestFusion: class TestFusion:
opts = OptimizationQuery( opts = RewriteDatabaseQuery(
include=[ include=[
"local_elemwise_fusion", "local_elemwise_fusion",
"composite_elemwise_fusion", "composite_elemwise_fusion",
...@@ -1782,7 +1782,7 @@ class TestFusion: ...@@ -1782,7 +1782,7 @@ class TestFusion:
def test_add_mul_fusion_inplace(self): def test_add_mul_fusion_inplace(self):
opts = OptimizationQuery( opts = RewriteDatabaseQuery(
include=[ include=[
"local_elemwise_fusion", "local_elemwise_fusion",
"composite_elemwise_fusion", "composite_elemwise_fusion",
......
...@@ -12,7 +12,7 @@ from aesara.configdefaults import config ...@@ -12,7 +12,7 @@ from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable, ancestors from aesara.graph.basic import Constant, Variable, ancestors
from aesara.graph.opt import check_stack_trace from aesara.graph.opt import check_stack_trace
from aesara.graph.opt_utils import optimize_graph from aesara.graph.opt_utils import optimize_graph
from aesara.graph.optdb import OptimizationQuery from aesara.graph.optdb import RewriteDatabaseQuery
from aesara.graph.type import Type from aesara.graph.type import Type
from aesara.raise_op import Assert from aesara.raise_op import Assert
from aesara.tensor import inplace from aesara.tensor import inplace
...@@ -1994,7 +1994,7 @@ def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val): ...@@ -1994,7 +1994,7 @@ def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val):
y = specify_shape(x, s)[idx] y = specify_shape(x, s)[idx]
assert isinstance(y.owner.inputs[0].owner.op, SpecifyShape) assert isinstance(y.owner.inputs[0].owner.op, SpecifyShape)
opts = OptimizationQuery(include=[None]) opts = RewriteDatabaseQuery(include=[None])
no_opt_mode = Mode(optimizer=opts) no_opt_mode = Mode(optimizer=opts)
y_val_fn = function([x] + list(s), y, on_unused_input="ignore", mode=no_opt_mode) y_val_fn = function([x] + list(s), y, on_unused_input="ignore", mode=no_opt_mode)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论