提交 92a3b2a6 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename OptimizationQuery to RewriteDatabaseQuery

上级 5dbfd046
......@@ -19,8 +19,8 @@ from aesara.graph.opt import (
from aesara.graph.optdb import (
EquilibriumDB,
LocalGroupDB,
OptimizationQuery,
RewriteDatabase,
RewriteDatabaseQuery,
SequenceDB,
TopoDB,
)
......@@ -64,15 +64,15 @@ def register_linker(name, linker):
exclude = []
if not config.cxx:
exclude = ["cxx_only"]
OPT_NONE = OptimizationQuery(include=[], exclude=exclude)
OPT_NONE = RewriteDatabaseQuery(include=[], exclude=exclude)
# Even if multiple merge optimizer call will be there, this shouldn't
# impact performance.
OPT_MERGE = OptimizationQuery(include=["merge"], exclude=exclude)
OPT_FAST_RUN = OptimizationQuery(include=["fast_run"], exclude=exclude)
OPT_MERGE = RewriteDatabaseQuery(include=["merge"], exclude=exclude)
OPT_FAST_RUN = RewriteDatabaseQuery(include=["fast_run"], exclude=exclude)
OPT_FAST_RUN_STABLE = OPT_FAST_RUN.requiring("stable")
OPT_FAST_COMPILE = OptimizationQuery(include=["fast_compile"], exclude=exclude)
OPT_STABILIZE = OptimizationQuery(include=["fast_run"], exclude=exclude)
OPT_FAST_COMPILE = RewriteDatabaseQuery(include=["fast_compile"], exclude=exclude)
OPT_STABILIZE = RewriteDatabaseQuery(include=["fast_run"], exclude=exclude)
OPT_STABILIZE.position_cutoff = 1.5000001
OPT_NONE.name = "OPT_NONE"
OPT_MERGE.name = "OPT_MERGE"
......@@ -302,7 +302,7 @@ class Mode:
def __init__(
self,
linker: Optional[Union[str, Linker]] = None,
optimizer: Union[str, OptimizationQuery] = "default",
optimizer: Union[str, RewriteDatabaseQuery] = "default",
db: RewriteDatabase = None,
):
if linker is None:
......@@ -320,7 +320,7 @@ class Mode:
# self.provided_optimizer - typically the `optimizer` arg.
# But if the `optimizer` arg is keyword corresponding to a predefined
# OptimizationQuery, then this stores the query
# RewriteDatabaseQuery, then this stores the query
# self._optimizer - typically same as provided_optimizer??
# self.__get_optimizer - returns self._optimizer (possibly querying
......@@ -342,7 +342,7 @@ class Mode:
self.linker = linker
if isinstance(optimizer, str) or optimizer is None:
optimizer = predefined_optimizers[optimizer]
if isinstance(optimizer, OptimizationQuery):
if isinstance(optimizer, RewriteDatabaseQuery):
self.provided_optimizer = optimizer
self._optimizer = optimizer
self.call_time = 0
......@@ -357,7 +357,7 @@ class Mode:
)
def __get_optimizer(self):
if isinstance(self._optimizer, OptimizationQuery):
if isinstance(self._optimizer, RewriteDatabaseQuery):
return self.optdb.query(self._optimizer)
else:
return self._optimizer
......@@ -375,7 +375,7 @@ class Mode:
link, opt = self.get_linker_optimizer(
self.provided_linker, self.provided_optimizer
)
# N.B. opt might be a OptimizationQuery instance, not sure what else it might be...
# N.B. opt might be a RewriteDatabaseQuery instance, not sure what else it might be...
# string? Optimizer? OptDB? who knows???
return self.clone(optimizer=opt.including(*tags), linker=link)
......@@ -448,11 +448,11 @@ else:
JAX = Mode(
JAXLinker(),
OptimizationQuery(include=["fast_run"], exclude=["cxx_only", "BlasOpt"]),
RewriteDatabaseQuery(include=["fast_run"], exclude=["cxx_only", "BlasOpt"]),
)
NUMBA = Mode(
NumbaLinker(),
OptimizationQuery(include=["fast_run"], exclude=["cxx_only", "BlasOpt"]),
RewriteDatabaseQuery(include=["fast_run"], exclude=["cxx_only", "BlasOpt"]),
)
......
......@@ -15,6 +15,6 @@ from aesara.graph.type import Type
from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import node_rewriter, graph_rewriter
from aesara.graph.opt_utils import optimize_graph
from aesara.graph.optdb import OptimizationQuery
from aesara.graph.optdb import RewriteDatabaseQuery
# isort: on
......@@ -10,7 +10,7 @@ from aesara.graph.basic import (
vars_between,
)
from aesara.graph.fg import FunctionGraph
from aesara.graph.optdb import OptimizationQuery
from aesara.graph.optdb import RewriteDatabaseQuery
def optimize_graph(
......@@ -34,7 +34,7 @@ def optimize_graph(
clone:
Whether or not to clone the input graph before optimizing.
**kwargs:
Keyword arguments passed to the ``aesara.graph.optdb.OptimizationQuery`` object.
Keyword arguments passed to the ``aesara.graph.optdb.RewriteDatabaseQuery`` object.
"""
from aesara.compile import optdb
......@@ -43,7 +43,7 @@ def optimize_graph(
fgraph = FunctionGraph(outputs=[fgraph], clone=clone)
return_only_out = True
canonicalize_opt = optdb.query(OptimizationQuery(include=include, **kwargs))
canonicalize_opt = optdb.query(RewriteDatabaseQuery(include=include, **kwargs))
_ = canonicalize_opt.optimize(fgraph)
if custom_opt:
......
......@@ -137,10 +137,10 @@ class RewriteDatabase:
return variables
def query(self, *tags, **kwtags):
if len(tags) >= 1 and isinstance(tags[0], OptimizationQuery):
if len(tags) >= 1 and isinstance(tags[0], RewriteDatabaseQuery):
if len(tags) > 1 or kwtags:
raise TypeError(
"If the first argument to query is an `OptimizationQuery`,"
"If the first argument to query is an `RewriteDatabaseQuery`,"
" there should be no other arguments."
)
return self.__query__(tags[0])
......@@ -153,7 +153,7 @@ class RewriteDatabase:
" characters: '+', '&' or '-'"
)
return self.__query__(
OptimizationQuery(
RewriteDatabaseQuery(
include=include, require=require, exclude=exclude, subquery=kwtags
)
)
......@@ -176,7 +176,7 @@ class RewriteDatabase:
print(" db", self.__db__, file=stream)
class OptimizationQuery:
class RewriteDatabaseQuery:
"""An object that specifies a set of optimizations by tag/name."""
def __init__(
......@@ -184,11 +184,11 @@ class OptimizationQuery:
include: Iterable[str],
require: Optional[Union[OrderedSet, Sequence[str]]] = None,
exclude: Optional[Union[OrderedSet, Sequence[str]]] = None,
subquery: Optional[Dict[str, "OptimizationQuery"]] = None,
subquery: Optional[Dict[str, "RewriteDatabaseQuery"]] = None,
position_cutoff: float = math.inf,
extra_optimizations: Optional[
Sequence[
Tuple[Union["OptimizationQuery", OptimizersType], Union[int, float]]
Tuple[Union["RewriteDatabaseQuery", OptimizersType], Union[int, float]]
]
] = None,
):
......@@ -198,19 +198,19 @@ class OptimizationQuery:
==========
include:
A set of tags such that every optimization obtained through this
`OptimizationQuery` must have **one** of the tags listed. This
`RewriteDatabaseQuery` must have **one** of the tags listed. This
field is required and basically acts as a starting point for the
search.
require:
A set of tags such that every optimization obtained through this
`OptimizationQuery` must have **all** of these tags.
`RewriteDatabaseQuery` must have **all** of these tags.
exclude:
A set of tags such that every optimization obtained through this
``OptimizationQuery` must have **none** of these tags.
``RewriteDatabaseQuery` must have **none** of these tags.
subquery:
A dictionary mapping the name of a sub-database to a special
`OptimizationQuery`. If no subquery is given for a sub-database,
the original `OptimizationQuery` will be used again.
`RewriteDatabaseQuery`. If no subquery is given for a sub-database,
the original `RewriteDatabaseQuery` will be used again.
position_cutoff:
Only optimizations with position less than the cutoff are returned.
extra_optimizations:
......@@ -229,7 +229,7 @@ class OptimizationQuery:
def __str__(self):
return (
"OptimizationQuery("
"RewriteDatabaseQuery("
+ f"inc={self.include},ex={self.exclude},"
+ f"require={self.require},subquery={self.subquery},"
+ f"position_cutoff={self.position_cutoff},"
......@@ -241,9 +241,9 @@ class OptimizationQuery:
if not hasattr(self, "extra_optimizations"):
self.extra_optimizations = []
def including(self, *tags: str) -> "OptimizationQuery":
def including(self, *tags: str) -> "RewriteDatabaseQuery":
"""Add rewrites with the given tags."""
return OptimizationQuery(
return RewriteDatabaseQuery(
self.include.union(tags),
self.require,
self.exclude,
......@@ -252,9 +252,9 @@ class OptimizationQuery:
self.extra_optimizations,
)
def excluding(self, *tags: str) -> "OptimizationQuery":
def excluding(self, *tags: str) -> "RewriteDatabaseQuery":
"""Remove rewrites with the given tags."""
return OptimizationQuery(
return RewriteDatabaseQuery(
self.include,
self.require,
self.exclude.union(tags),
......@@ -263,9 +263,9 @@ class OptimizationQuery:
self.extra_optimizations,
)
def requiring(self, *tags: str) -> "OptimizationQuery":
def requiring(self, *tags: str) -> "RewriteDatabaseQuery":
"""Filter for rewrites with the given tags."""
return OptimizationQuery(
return RewriteDatabaseQuery(
self.include,
self.require.union(tags),
self.exclude,
......@@ -275,10 +275,10 @@ class OptimizationQuery:
)
def register(
self, *optimizations: Tuple["OptimizationQuery", Union[int, float]]
) -> "OptimizationQuery":
self, *optimizations: Tuple["RewriteDatabaseQuery", Union[int, float]]
) -> "RewriteDatabaseQuery":
"""Include the given optimizations."""
return OptimizationQuery(
return RewriteDatabaseQuery(
self.include,
self.require,
self.exclude,
......@@ -417,13 +417,13 @@ class SequenceDB(RewriteDatabase):
position_dict = self.__position__
if len(tags) >= 1 and isinstance(tags[0], OptimizationQuery):
if len(tags) >= 1 and isinstance(tags[0], RewriteDatabaseQuery):
# the call to super should have raise an error with a good message
assert len(tags) == 1
if getattr(tags[0], "position_cutoff", None):
position_cutoff = tags[0].position_cutoff
# The OptimizationQuery instance might contain extra optimizations which need
# The RewriteDatabaseQuery instance might contain extra optimizations which need
# to be added the the sequence of optimizations (don't alter the
# original dictionary)
if len(tags[0].extra_optimizations) > 0:
......@@ -544,14 +544,19 @@ DEPRECATED_NAMES = [
),
(
"Query",
"`Query` is deprecated; use `OptimizationQuery` instead.",
OptimizationQuery,
"`Query` is deprecated; use `RewriteDatabaseQuery` instead.",
RewriteDatabaseQuery,
),
(
"OptimizationDatabase",
"`OptimizationDatabase` is deprecated; use `RewriteDatabase` instead.",
RewriteDatabase,
),
(
"OptimizationQuery",
"`OptimizationQuery` is deprecated; use `RewriteDatabaseQuery` instead.",
RewriteDatabaseQuery,
),
]
......
......@@ -585,23 +585,23 @@ Definition of :obj:`optdb`
:class:`SequenceDB <optdb.SequenceDB>`,
itself a subclass of :class:`RewriteDatabase <optdb.RewriteDatabase>`.
There exist (for now) two types of :class:`RewriteDatabase`, :class:`SequenceDB` and :class:`EquilibriumDB`.
When given an appropriate :class:`OptimizationQuery`, :class:`RewriteDatabase` objects build an :class:`Optimizer` matching
When given an appropriate :class:`RewriteDatabaseQuery`, :class:`RewriteDatabase` objects build an :class:`Optimizer` matching
the query.
A :class:`SequenceDB` contains :class:`Optimizer` or :class:`RewriteDatabase` objects. Each of them
has a name, an arbitrary number of tags and an integer representing their order
in the sequence. When a :class:`OptimizationQuery` is applied to a :class:`SequenceDB`, all :class:`Optimizer`\s whose
in the sequence. When a :class:`RewriteDatabaseQuery` is applied to a :class:`SequenceDB`, all :class:`Optimizer`\s whose
tags match the query are inserted in proper order in a :class:`SequenceOptimizer`, which
is returned. If the :class:`SequenceDB` contains :class:`RewriteDatabase`
instances, the :class:`OptimizationQuery` will be passed to them as well and the
instances, the :class:`RewriteDatabaseQuery` will be passed to them as well and the
optimizers they return will be put in their places.
An :class:`EquilibriumDB` contains :class:`NodeRewriter` or :class:`RewriteDatabase` objects. Each of them
has a name and an arbitrary number of tags. When a :class:`OptimizationQuery` is applied to
has a name and an arbitrary number of tags. When a :class:`RewriteDatabaseQuery` is applied to
an :class:`EquilibriumDB`, all :class:`NodeRewriter`\s that match the query are
inserted into an :class:`EquilibriumGraphRewriter`, which is returned. If the
:class:`SequenceDB` contains :class:`RewriteDatabase` instances, the
:class:`OptimizationQuery` will be passed to them as well and the
:class:`RewriteDatabaseQuery` will be passed to them as well and the
:class:`NodeRewriter`\s they return will be put in their places
(note that as of yet no :class:`RewriteDatabase` can produce :class:`NodeRewriter` objects, so this
is a moot point).
......@@ -613,68 +613,68 @@ optdb is a :class:`SequenceDB`, so, at the top level, Aesara applies a sequence
of global optimizations to the computation graphs.
:class:`OptimizationQuery`
--------------------------
:class:`RewriteDatabaseQuery`
-----------------------------
A :class:`OptimizationQuery` is built by the following call:
A :class:`RewriteDatabaseQuery` is built by the following call:
.. code-block:: python
aesara.graph.optdb.OptimizationQuery(include, require=None, exclude=None, subquery=None)
aesara.graph.optdb.RewriteDatabaseQuery(include, require=None, exclude=None, subquery=None)
.. class:: OptimizationQuery
.. class:: RewriteDatabaseQuery
.. attribute:: include
A set of tags (a tag being a string) such that every
optimization obtained through this :class:`OptimizationQuery` must have **one** of the tags
optimization obtained through this :class:`RewriteDatabaseQuery` must have **one** of the tags
listed. This field is required and basically acts as a starting point
for the search.
.. attribute:: require
A set of tags such that every optimization obtained
through this :class:`OptimizationQuery` must have **all** of these tags.
through this :class:`RewriteDatabaseQuery` must have **all** of these tags.
.. attribute:: exclude
A set of tags such that every optimization obtained
through this :class:`OptimizationQuery` must have **none** of these tags.
through this :class:`RewriteDatabaseQuery` must have **none** of these tags.
.. attribute:: subquery
:obj:`optdb` can contain sub-databases; subquery is a
dictionary mapping the name of a sub-database to a special :class:`OptimizationQuery`.
If no subquery is given for a sub-database, the original :class:`OptimizationQuery` will be
dictionary mapping the name of a sub-database to a special :class:`RewriteDatabaseQuery`.
If no subquery is given for a sub-database, the original :class:`RewriteDatabaseQuery` will be
used again.
Furthermore, a :class:`OptimizationQuery` object includes three methods, :meth:`including`,
:meth:`requiring` and :meth:`excluding`, which each produce a new :class:`OptimizationQuery` object
Furthermore, a :class:`RewriteDatabaseQuery` object includes three methods, :meth:`including`,
:meth:`requiring` and :meth:`excluding`, which each produce a new :class:`RewriteDatabaseQuery` object
with the include, require, and exclude sets refined to contain the new entries.
Examples
--------
Here are a few examples of how to use a :class:`OptimizationQuery` on :obj:`optdb` to produce an
Here are a few examples of how to use a :class:`RewriteDatabaseQuery` on :obj:`optdb` to produce an
:class:`Optimizer`:
.. testcode::
from aesara.graph.optdb import OptimizationQuery
from aesara.graph.optdb import RewriteDatabaseQuery
from aesara.compile import optdb
# This is how the optimizer for the fast_run mode is defined
fast_run = optdb.query(OptimizationQuery(include=['fast_run']))
fast_run = optdb.query(RewriteDatabaseQuery(include=['fast_run']))
# This is how the optimizer for the fast_compile mode is defined
fast_compile = optdb.query(OptimizationQuery(include=['fast_compile']))
fast_compile = optdb.query(RewriteDatabaseQuery(include=['fast_compile']))
# This is the same as fast_run but no optimizations will replace
# any operation by an inplace version. This assumes, of course,
# that all inplace operations are tagged as 'inplace' (as they
# should!)
fast_run_no_inplace = optdb.query(OptimizationQuery(include=['fast_run'],
fast_run_no_inplace = optdb.query(RewriteDatabaseQuery(include=['fast_run'],
exclude=['inplace']))
......@@ -733,7 +733,7 @@ optimizations:
For each group, all optimizations of the group that are selected by
the :class:`OptimizationQuery` will be applied on the graph over and over again until none
the :class:`RewriteDatabaseQuery` will be applied on the graph over and over again until none
of them is applicable, so keep that in mind when designing it: check
carefully that your optimization leads to a fixpoint (a point where it
cannot apply anymore) at which point it returns ``False`` to indicate its
......
from aesara.compile.function import function
from aesara.compile.mode import AddFeatureOptimizer, Mode
from aesara.graph.features import NoOutputFromInplace
from aesara.graph.optdb import OptimizationQuery, SequenceDB
from aesara.graph.optdb import RewriteDatabaseQuery, SequenceDB
from aesara.tensor.math import dot, tanh
from aesara.tensor.type import matrix
def test_Mode_basic():
db = SequenceDB()
mode = Mode(linker="py", optimizer=OptimizationQuery(include=None), db=db)
mode = Mode(linker="py", optimizer=RewriteDatabaseQuery(include=None), db=db)
assert mode.optdb is db
assert str(mode).startswith("Mode(linker=py, optimizer=OptimizationQuery")
assert str(mode).startswith("Mode(linker=py, optimizer=RewriteDatabaseQuery")
def test_NoOutputFromInplace():
......
......@@ -15,7 +15,7 @@ from aesara.configdefaults import config
from aesara.graph.basic import Apply
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op, get_test_value
from aesara.graph.optdb import OptimizationQuery
from aesara.graph.optdb import RewriteDatabaseQuery
from aesara.ifelse import ifelse
from aesara.link.jax import JAXLinker
from aesara.raise_op import assert_op
......@@ -56,7 +56,7 @@ from aesara.tensor.type import (
jax = pytest.importorskip("jax")
opts = OptimizationQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_mode = Mode(JAXLinker(), opts)
py_mode = Mode("py", opts)
......@@ -1142,7 +1142,7 @@ def test_jax_BatchedDot():
# A dimension mismatch should raise a TypeError for compatibility
inputs = [get_test_value(a)[:-1], get_test_value(b)]
opts = OptimizationQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_mode = Mode(JAXLinker(), opts)
aesara_jax_fn = function(fgraph.inputs, fgraph.outputs, mode=jax_mode)
with pytest.raises(TypeError):
......
......@@ -24,7 +24,7 @@ from aesara.compile.sharedvalue import SharedVariable
from aesara.graph.basic import Apply, Constant
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op, get_test_value
from aesara.graph.optdb import OptimizationQuery
from aesara.graph.optdb import RewriteDatabaseQuery
from aesara.graph.type import Type
from aesara.ifelse import ifelse
from aesara.link.numba.dispatch import basic as numba_basic
......@@ -92,7 +92,7 @@ my_multi_out.ufunc = MyMultiOut.impl
my_multi_out.ufunc.nin = 2
my_multi_out.ufunc.nout = 2
opts = OptimizationQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
numba_mode = Mode(NumbaLinker(), opts)
py_mode = Mode("py", opts)
......
......@@ -7,12 +7,12 @@ import aesara.tensor as aet
from aesara import config
from aesara.compile.function import function
from aesara.compile.mode import Mode
from aesara.graph.optdb import OptimizationQuery
from aesara.graph.optdb import RewriteDatabaseQuery
from aesara.link.numba.linker import NumbaLinker
from aesara.tensor.math import Max
opts = OptimizationQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
numba_mode = Mode(NumbaLinker(), opts)
py_mode = Mode("py", opts)
......
......@@ -14,7 +14,7 @@ from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable, graph_inputs
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import get_test_value
from aesara.graph.optdb import OptimizationQuery
from aesara.graph.optdb import RewriteDatabaseQuery
from aesara.tensor.basic_opt import ShapeFeature
from aesara.tensor.random.basic import (
bernoulli,
......@@ -60,7 +60,7 @@ from aesara.tensor.type import iscalar, scalar, tensor
from tests.unittest_tools import create_aesara_param
opts = OptimizationQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
py_mode = Mode("py", opts)
......
......@@ -8,7 +8,7 @@ from aesara.compile.mode import Mode
from aesara.graph.basic import Constant
from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import EquilibriumGraphRewriter
from aesara.graph.optdb import OptimizationQuery
from aesara.graph.optdb import RewriteDatabaseQuery
from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.random.basic import (
dirichlet,
......@@ -28,7 +28,7 @@ from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subte
from aesara.tensor.type import iscalar, vector
no_mode = Mode("py", OptimizationQuery(include=[], exclude=[]))
no_mode = Mode("py", RewriteDatabaseQuery(include=[], exclude=[]))
def apply_local_opt_to_rv(opt, op_fn, dist_op, dist_params, size, rng, name=None):
......
......@@ -3,7 +3,7 @@ import pytest
from aesara import config, function
from aesara.compile.mode import Mode
from aesara.graph.optdb import OptimizationQuery
from aesara.graph.optdb import RewriteDatabaseQuery
from aesara.tensor.random.utils import RandomStream, broadcast_params
from aesara.tensor.type import matrix, tensor
from tests import unittest_tools as utt
......@@ -11,7 +11,7 @@ from tests import unittest_tools as utt
@pytest.fixture(scope="module", autouse=True)
def set_aesara_flags():
opts = OptimizationQuery(include=[None], exclude=[])
opts = RewriteDatabaseQuery(include=[None], exclude=[])
py_mode = Mode("py", opts)
with config.change_flags(mode=py_mode, compute_test_value="warn"):
yield
......
......@@ -18,7 +18,7 @@ from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.opt import check_stack_trace, node_rewriter, out2in
from aesara.graph.opt_utils import optimize_graph
from aesara.graph.optdb import OptimizationQuery
from aesara.graph.optdb import RewriteDatabaseQuery
from aesara.graph.type import Type
from aesara.misc.safe_asarray import _asarray
from aesara.printing import pprint
......@@ -141,15 +141,15 @@ mode_opt = get_mode(mode_opt)
dimshuffle_lift = out2in(local_dimshuffle_lift)
_optimizer_stabilize = OptimizationQuery(include=["fast_run"])
_optimizer_stabilize = RewriteDatabaseQuery(include=["fast_run"])
_optimizer_stabilize.position_cutoff = 1.51
_optimizer_stabilize = optdb.query(_optimizer_stabilize)
_optimizer_specialize = OptimizationQuery(include=["fast_run"])
_optimizer_specialize = RewriteDatabaseQuery(include=["fast_run"])
_optimizer_specialize.position_cutoff = 2.01
_optimizer_specialize = optdb.query(_optimizer_specialize)
_optimizer_fast_run = OptimizationQuery(include=["fast_run"])
_optimizer_fast_run = RewriteDatabaseQuery(include=["fast_run"])
_optimizer_fast_run = optdb.query(_optimizer_fast_run)
......@@ -352,7 +352,7 @@ def test_local_useless_dimshuffle_in_reshape():
class TestFusion:
opts = OptimizationQuery(
opts = RewriteDatabaseQuery(
include=[
"local_elemwise_fusion",
"composite_elemwise_fusion",
......@@ -1099,7 +1099,7 @@ class TestFusion:
def test_add_mul_fusion_inplace(self):
opts = OptimizationQuery(
opts = RewriteDatabaseQuery(
include=[
"local_elemwise_fusion",
"composite_elemwise_fusion",
......@@ -1165,7 +1165,7 @@ class TestFusion:
"""
opts = OptimizationQuery(
opts = RewriteDatabaseQuery(
include=[
"local_elemwise_fusion",
"composite_elemwise_fusion",
......
......@@ -9,7 +9,7 @@ from aesara import tensor as at
from aesara.compile.mode import Mode
from aesara.configdefaults import config
from aesara.graph.basic import Constant, applys_between
from aesara.graph.optdb import OptimizationQuery
from aesara.graph.optdb import RewriteDatabaseQuery
from aesara.raise_op import Assert
from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.extra_ops import (
......@@ -1285,7 +1285,7 @@ class TestBroadcastTo(utt.InferShapeTester):
q = b[np.r_[0, 1, 3]]
e = at.set_subtensor(q, np.r_[0, 0, 0])
opts = OptimizationQuery(include=["inplace"])
opts = RewriteDatabaseQuery(include=["inplace"])
py_mode = Mode("py", opts)
e_fn = function([d], e, mode=py_mode)
......
......@@ -26,7 +26,7 @@ from aesara.graph.opt import (
out2in,
)
from aesara.graph.opt_utils import is_same_graph, optimize_graph
from aesara.graph.optdb import OptimizationQuery
from aesara.graph.optdb import RewriteDatabaseQuery
from aesara.misc.safe_asarray import _asarray
from aesara.tensor import inplace
from aesara.tensor.basic import Alloc, join, switch
......@@ -132,15 +132,15 @@ mode_opt = get_mode(mode_opt)
dimshuffle_lift = out2in(local_dimshuffle_lift)
_optimizer_stabilize = OptimizationQuery(include=["fast_run"])
_optimizer_stabilize = RewriteDatabaseQuery(include=["fast_run"])
_optimizer_stabilize.position_cutoff = 1.51
_optimizer_stabilize = optdb.query(_optimizer_stabilize)
_optimizer_specialize = OptimizationQuery(include=["fast_run"])
_optimizer_specialize = RewriteDatabaseQuery(include=["fast_run"])
_optimizer_specialize.position_cutoff = 2.01
_optimizer_specialize = optdb.query(_optimizer_specialize)
_optimizer_fast_run = OptimizationQuery(include=["fast_run"])
_optimizer_fast_run = RewriteDatabaseQuery(include=["fast_run"])
_optimizer_fast_run = optdb.query(_optimizer_fast_run)
......@@ -366,7 +366,7 @@ class TestAlgebraicCanonizer:
# We must be sure that the AlgebraicCanonizer is working, but that we don't have other
# optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion
mode = get_default_mode()
opt = OptimizationQuery(["canonicalize"])
opt = RewriteDatabaseQuery(["canonicalize"])
opt = opt.excluding("local_elemwise_fusion")
mode = mode.__class__(linker=mode.linker, optimizer=opt)
for id, [g, sym_inputs, val_inputs, nb_elemwise, out_dtype] in enumerate(cases):
......@@ -500,7 +500,7 @@ class TestAlgebraicCanonizer:
# We must be sure that the AlgebraicCanonizer is working, but that we don't have other
# optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion
mode = get_default_mode()
mode._optimizer = OptimizationQuery(["canonicalize"])
mode._optimizer = RewriteDatabaseQuery(["canonicalize"])
mode._optimizer = mode._optimizer.excluding("local_elemwise_fusion")
for id, [g, sym_inputs, val_inputs, nb_elemwise, out_dtype] in enumerate(cases):
f = function(
......@@ -547,7 +547,7 @@ class TestAlgebraicCanonizer:
# optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion
mode = get_default_mode()
opt = OptimizationQuery(["canonicalize"])
opt = RewriteDatabaseQuery(["canonicalize"])
opt = opt.including("ShapeOpt", "local_fill_to_alloc")
opt = opt.excluding("local_elemwise_fusion")
mode = mode.__class__(linker=mode.linker, optimizer=opt)
......@@ -907,7 +907,7 @@ class TestAlgebraicCanonizer:
# optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion
mode = get_default_mode()
opt = OptimizationQuery(["canonicalize"])
opt = RewriteDatabaseQuery(["canonicalize"])
opt = opt.excluding("local_elemwise_fusion")
mode = mode.__class__(linker=mode.linker, optimizer=opt)
# test fail!
......@@ -1074,7 +1074,7 @@ def test_cast_in_mul_canonizer():
class TestFusion:
opts = OptimizationQuery(
opts = RewriteDatabaseQuery(
include=[
"local_elemwise_fusion",
"composite_elemwise_fusion",
......@@ -1782,7 +1782,7 @@ class TestFusion:
def test_add_mul_fusion_inplace(self):
opts = OptimizationQuery(
opts = RewriteDatabaseQuery(
include=[
"local_elemwise_fusion",
"composite_elemwise_fusion",
......
......@@ -12,7 +12,7 @@ from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable, ancestors
from aesara.graph.opt import check_stack_trace
from aesara.graph.opt_utils import optimize_graph
from aesara.graph.optdb import OptimizationQuery
from aesara.graph.optdb import RewriteDatabaseQuery
from aesara.graph.type import Type
from aesara.raise_op import Assert
from aesara.tensor import inplace
......@@ -1994,7 +1994,7 @@ def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val):
y = specify_shape(x, s)[idx]
assert isinstance(y.owner.inputs[0].owner.op, SpecifyShape)
opts = OptimizationQuery(include=[None])
opts = RewriteDatabaseQuery(include=[None])
no_opt_mode = Mode(optimizer=opts)
y_val_fn = function([x] + list(s), y, on_unused_input="ignore", mode=no_opt_mode)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论