提交 c2672deb authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename Query to OptimizationQuery and DB to OptimizationDatabase,

上级 1c11dd44
......@@ -17,7 +17,13 @@ from aesara.graph.opt import (
MergeOptimizer,
NavigatorOptimizer,
)
from aesara.graph.optdb import EquilibriumDB, LocalGroupDB, Query, SequenceDB, TopoDB
from aesara.graph.optdb import (
EquilibriumDB,
LocalGroupDB,
OptimizationQuery,
SequenceDB,
TopoDB,
)
from aesara.link.basic import PerformLinker
from aesara.link.c.basic import CLinker, OpWiseCLinker
from aesara.link.jax.linker import JAXLinker
......@@ -58,19 +64,21 @@ def register_linker(name, linker):
exclude = []
if not config.cxx:
exclude = ["cxx_only"]
OPT_NONE = Query(include=[], exclude=exclude)
OPT_NONE = OptimizationQuery(include=[], exclude=exclude)
# Even if multiple merge optimizer call will be there, this shouldn't
# impact performance.
OPT_MERGE = Query(include=["merge"], exclude=exclude)
OPT_FAST_RUN = Query(include=["fast_run"], exclude=exclude)
OPT_MERGE = OptimizationQuery(include=["merge"], exclude=exclude)
OPT_FAST_RUN = OptimizationQuery(include=["fast_run"], exclude=exclude)
OPT_FAST_RUN_STABLE = OPT_FAST_RUN.requiring("stable")
# We need fast_compile_gpu here. As on the GPU, we don't have all
# operation that exist in fast_compile, but have some that get
# introduced in fast_run, we want those optimization to also run in
# fast_compile+gpu. We can't tag them just as 'gpu', as this would
# exclude them if we exclude 'gpu'.
OPT_FAST_COMPILE = Query(include=["fast_compile", "fast_compile_gpu"], exclude=exclude)
OPT_STABILIZE = Query(include=["fast_run"], exclude=exclude)
OPT_FAST_COMPILE = OptimizationQuery(
include=["fast_compile", "fast_compile_gpu"], exclude=exclude
)
OPT_STABILIZE = OptimizationQuery(include=["fast_run"], exclude=exclude)
OPT_STABILIZE.position_cutoff = 1.5000001
OPT_NONE.name = "OPT_NONE"
OPT_MERGE.name = "OPT_MERGE"
......@@ -297,7 +305,7 @@ class Mode:
# self.provided_optimizer - typically the `optimizer` arg.
# But if the `optimizer` arg is keyword corresponding to a predefined
# Query, then this stores the query
# OptimizationQuery, then this stores the query
# self._optimizer - typically same as provided_optimizer??
# self.__get_optimizer - returns self._optimizer (possibly querying
......@@ -316,7 +324,7 @@ class Mode:
self.linker = linker
if isinstance(optimizer, str) or optimizer is None:
optimizer = predefined_optimizers[optimizer]
if isinstance(optimizer, Query):
if isinstance(optimizer, OptimizationQuery):
self.provided_optimizer = optimizer
self._optimizer = optimizer
self.call_time = 0
......@@ -330,7 +338,7 @@ class Mode:
)
def __get_optimizer(self):
if isinstance(self._optimizer, Query):
if isinstance(self._optimizer, OptimizationQuery):
return optdb.query(self._optimizer)
else:
return self._optimizer
......@@ -348,7 +356,7 @@ class Mode:
link, opt = self.get_linker_optimizer(
self.provided_linker, self.provided_optimizer
)
# N.B. opt might be a Query instance, not sure what else it might be...
# N.B. opt might be a OptimizationQuery instance, not sure what else it might be...
# string? Optimizer? OptDB? who knows???
return self.clone(optimizer=opt.including(*tags), linker=link)
......@@ -421,9 +429,13 @@ if config.cxx:
else:
FAST_RUN = Mode("vm", "fast_run")
JAX = Mode(JAXLinker(), Query(include=["fast_run"], exclude=["cxx_only", "BlasOpt"]))
JAX = Mode(
JAXLinker(),
OptimizationQuery(include=["fast_run"], exclude=["cxx_only", "BlasOpt"]),
)
NUMBA = Mode(
NumbaLinker(), Query(include=["fast_run"], exclude=["cxx_only", "BlasOpt"])
NumbaLinker(),
OptimizationQuery(include=["fast_run"], exclude=["cxx_only", "BlasOpt"]),
)
......
from aesara.compile import optdb
from aesara.graph.opt import GraphToGPULocalOptGroup, TopoOptimizer, local_optimizer
from aesara.graph.optdb import DB, EquilibriumDB, LocalGroupDB, SequenceDB
from aesara.graph.optdb import (
EquilibriumDB,
LocalGroupDB,
OptimizationDatabase,
SequenceDB,
)
gpu_optimizer = EquilibriumDB()
......@@ -62,7 +67,7 @@ def register_opt2(tracks, *tags, **kwargs):
def f(local_opt):
name = (kwargs and kwargs.pop("name")) or local_opt.__name__
if isinstance(local_opt, DB):
if isinstance(local_opt, OptimizationDatabase):
opt = local_opt
else:
opt = local_optimizer(tracks)(local_opt)
......@@ -97,7 +102,7 @@ abstractconv_groupopt.__name__ = "gpuarray_abstractconv_opts"
register_opt("fast_compile")(abstractconv_groupopt)
class GraphToGPUDB(DB):
class GraphToGPUDB(OptimizationDatabase):
"""
Retrieves the list local optimizers based on the optimizer flag's value
from EquilibriumOptimizer by calling the method query.
......
......@@ -4,7 +4,7 @@ from typing import Sequence, Union
import aesara
from aesara.graph.basic import Variable, equal_computations, graph_inputs, vars_between
from aesara.graph.fg import FunctionGraph
from aesara.graph.optdb import Query
from aesara.graph.optdb import OptimizationQuery
def optimize_graph(
......@@ -28,7 +28,7 @@ def optimize_graph(
clone:
Whether or not to clone the input graph before optimizing.
**kwargs:
Keyword arguments passed to the ``aesara.graph.optdb.Query`` object.
Keyword arguments passed to the ``aesara.graph.optdb.OptimizationQuery`` object.
"""
from aesara.compile import optdb
......@@ -37,7 +37,7 @@ def optimize_graph(
fgraph = FunctionGraph(outputs=[fgraph], clone=clone)
return_only_out = True
canonicalize_opt = optdb.query(Query(include=include, **kwargs))
canonicalize_opt = optdb.query(OptimizationQuery(include=include, **kwargs))
_ = canonicalize_opt.optimize(fgraph)
if custom_opt:
......
......@@ -10,7 +10,13 @@ from aesara.misc.ordered_set import OrderedSet
from aesara.utils import DefaultOrderedDict
class DB:
class OptimizationDatabase:
"""A class that represents a collection/database of optimizations.
These databases can be used to logically organize sets of
(i.e. ``GlobalOptimizer``s and ``LocalOptimizer``)
"""
def __hash__(self):
if not hasattr(self, "_optimizer_idx"):
self._optimizer_idx = opt._optimizer_idx[0]
......@@ -24,7 +30,7 @@ class DB:
# (via obj.name by the thing doing the registering)
def register(self, name, obj, *tags, **kwargs):
"""
"""Register a new optimizer to the database.
Parameters
----------
......@@ -35,19 +41,21 @@ class DB:
tags
Tag name that allow to select the optimizer.
kwargs
If non empty, should contain only use_db_name_as_tag=False.
By default, all optimizations registered in EquilibriumDB
are selected when the EquilibriumDB name is used as a
tag. We do not want this behavior for some optimizer like
local_remove_all_assert. use_db_name_as_tag=False remove
that behavior. This mean only the optimizer name and the
tags specified will enable that optimization.
If non empty, should contain only ``use_db_name_as_tag=False``. By
default, all optimizations registered in ``EquilibriumDB`` are
selected when the ``EquilibriumDB`` name is used as a tag. We do
not want this behavior for some optimizer like
``local_remove_all_assert``. ``use_db_name_as_tag=False`` removes
that behavior. This mean only the optimizer name and the tags
specified will enable that optimization.
"""
# N.B. obj is not an instance of class `GlobalOptimizer`.
# It is an instance of a DB.In the tests for example,
# this is not always the case.
if not isinstance(obj, (DB, opt.GlobalOptimizer, opt.LocalOptimizer)):
if not isinstance(
obj, (OptimizationDatabase, opt.GlobalOptimizer, opt.LocalOptimizer)
):
raise TypeError("Object cannot be registered in OptDB", obj)
if name in self.__db__:
raise ValueError(
......@@ -99,8 +107,8 @@ class DB:
self.__db__[tag].remove(obj)
def __query__(self, q):
if not isinstance(q, Query):
raise TypeError("Expected a Query.", q)
if not isinstance(q, OptimizationQuery):
raise TypeError("Expected a OptimizationQuery.", q)
# The ordered set is needed for deterministic optimization.
variables = OrderedSet()
for tag in q.include:
......@@ -112,7 +120,7 @@ class DB:
remove = OrderedSet()
add = OrderedSet()
for obj in variables:
if isinstance(obj, DB):
if isinstance(obj, OptimizationDatabase):
def_sub_query = q
if q.extra_optimizations:
def_sub_query = copy.copy(q)
......@@ -128,10 +136,10 @@ class DB:
return variables
def query(self, *tags, **kwtags):
if len(tags) >= 1 and isinstance(tags[0], Query):
if len(tags) >= 1 and isinstance(tags[0], OptimizationQuery):
if len(tags) > 1 or kwtags:
raise TypeError(
"If the first argument to query is a Query,"
"If the first argument to query is a OptimizationQuery,"
" there should be no other arguments.",
tags,
kwtags,
......@@ -147,7 +155,9 @@ class DB:
tags,
)
return self.__query__(
Query(include=include, require=require, exclude=exclude, subquery=kwtags)
OptimizationQuery(
include=include, require=require, exclude=exclude, subquery=kwtags
)
)
def __getitem__(self, name):
......@@ -168,7 +178,11 @@ class DB:
print(" db", self.__db__, file=stream)
class Query:
# This is deprecated and will be removed.
DB = OptimizationDatabase
class OptimizationQuery:
"""
Parameters
......@@ -204,7 +218,7 @@ class Query:
def __str__(self):
return (
"Query{inc=%s,ex=%s,require=%s,subquery=%s,"
"OptimizationQuery{inc=%s,ex=%s,require=%s,subquery=%s,"
"position_cutoff=%f,extra_opts=%s}"
% (
self.include,
......@@ -223,7 +237,7 @@ class Query:
# add all opt with this tag
def including(self, *tags):
return Query(
return OptimizationQuery(
self.include.union(tags),
self.require,
self.exclude,
......@@ -234,7 +248,7 @@ class Query:
# remove all opt with this tag
def excluding(self, *tags):
return Query(
return OptimizationQuery(
self.include,
self.require,
self.exclude.union(tags),
......@@ -245,7 +259,7 @@ class Query:
# keep only opt with this tag.
def requiring(self, *tags):
return Query(
return OptimizationQuery(
self.include,
self.require.union(tags),
self.exclude,
......@@ -255,7 +269,7 @@ class Query:
)
def register(self, *optimizations):
return Query(
return OptimizationQuery(
self.include,
self.require,
self.exclude,
......@@ -265,7 +279,11 @@ class Query:
)
class EquilibriumDB(DB):
# This is deprecated and will be removed.
Query = OptimizationQuery
class EquilibriumDB(OptimizationDatabase):
"""
A set of potential optimizations which should be applied in an arbitrary
order until equilibrium is reached.
......@@ -331,7 +349,7 @@ class EquilibriumDB(DB):
)
class SequenceDB(DB):
class SequenceDB(OptimizationDatabase):
"""
A sequence of potential optimizations.
......@@ -378,13 +396,13 @@ class SequenceDB(DB):
position_cutoff = kwtags.pop("position_cutoff", config.optdb__position_cutoff)
position_dict = self.__position__
if len(tags) >= 1 and isinstance(tags[0], Query):
if len(tags) >= 1 and isinstance(tags[0], OptimizationQuery):
# the call to super should have raise an error with a good message
assert len(tags) == 1
if getattr(tags[0], "position_cutoff", None):
position_cutoff = tags[0].position_cutoff
# The Query instance might contain extra optimizations which need
# The OptimizationQuery instance might contain extra optimizations which need
# to be added the the sequence of optimizations (don't alter the
# original dictionary)
if len(tags[0].extra_optimizations) > 0:
......@@ -430,7 +448,7 @@ class SequenceDB(DB):
return sio.getvalue()
class LocalGroupDB(DB):
class LocalGroupDB(OptimizationDatabase):
"""
Generate a local optimizer of type LocalOptGroup instead
of a global optimizer.
......@@ -476,7 +494,7 @@ class LocalGroupDB(DB):
return ret
class TopoDB(DB):
class TopoDB(OptimizationDatabase):
"""
Generate a `GlobalOptimizer` of type TopoOptimizer.
......@@ -501,7 +519,7 @@ class TopoDB(DB):
)
class ProxyDB(DB):
class ProxyDB(OptimizationDatabase):
"""
Wrap an existing proxy.
......@@ -511,7 +529,7 @@ class ProxyDB(DB):
"""
def __init__(self, db):
assert isinstance(db, DB), ""
assert isinstance(db, OptimizationDatabase), ""
self.db = db
def query(self, *tags, **kwtags):
......
......@@ -395,97 +395,96 @@ Definition of optdb
optdb is an object which is an instance of
:class:`SequenceDB <optdb.SequenceDB>`,
itself a subclass of :class:`DB <optdb.DB>`.
There exist (for now) two types of DB, SequenceDB and EquilibriumDB.
When given an appropriate Query, DB objects build an Optimizer matching
itself a subclass of :class:`OptimizationDatabase <optdb.OptimizationDatabase>`.
There exist (for now) two types of OptimizationDatabase, SequenceDB and EquilibriumDB.
When given an appropriate OptimizationQuery, OptimizationDatabase objects build an Optimizer matching
the query.
A SequenceDB contains Optimizer or DB objects. Each of them has a
name, an arbitrary number of tags and an integer representing their
order in the sequence. When a Query is applied to a SequenceDB, all
Optimizers whose tags match the query are inserted in proper order in
a SequenceOptimizer, which is returned. If the SequenceDB contains DB
instances, the Query will be passed to them as well and the optimizers
they return will be put in their places.
A SequenceDB contains Optimizer or OptimizationDatabase objects. Each of them
has a name, an arbitrary number of tags and an integer representing their order
in the sequence. When a OptimizationQuery is applied to a SequenceDB, all Optimizers whose
tags match the query are inserted in proper order in a SequenceOptimizer, which
is returned. If the SequenceDB contains OptimizationDatabase instances, the OptimizationQuery will be passed
to them as well and the optimizers they return will be put in their places.
An EquilibriumDB contains LocalOptimizer or DB objects. Each of them
has a name and an arbitrary number of tags. When a Query is applied to
An EquilibriumDB contains LocalOptimizer or OptimizationDatabase objects. Each of them
has a name and an arbitrary number of tags. When a OptimizationQuery is applied to
an EquilibriumDB, all LocalOptimizers that match the query are
inserted into an EquilibriumOptimizer, which is returned. If the
SequenceDB contains DB instances, the Query will be passed to them as
SequenceDB contains OptimizationDatabase instances, the OptimizationQuery will be passed to them as
well and the LocalOptimizers they return will be put in their places
(note that as of yet no DB can produce LocalOptimizer objects, so this
(note that as of yet no OptimizationDatabase can produce LocalOptimizer objects, so this
is a moot point).
Aesara contains one principal DB object, :class:`optdb`, which
Aesara contains one principal OptimizationDatabase object, :class:`optdb`, which
contains all of Aesara's optimizers with proper tags. It is
recommended to insert new Optimizers in it. As mentioned previously,
optdb is a SequenceDB, so, at the top level, Aesara applies a sequence
of global optimizations to the computation graphs.
Query
OptimizationQuery
-----
A Query is built by the following call:
A OptimizationQuery is built by the following call:
.. code-block:: python
aesara.graph.optdb.Query(include, require=None, exclude=None, subquery=None)
aesara.graph.optdb.OptimizationQuery(include, require=None, exclude=None, subquery=None)
.. class:: Query
.. class:: OptimizationQuery
.. attribute:: include
A set of tags (a tag being a string) such that every
optimization obtained through this Query must have **one** of the tags
optimization obtained through this OptimizationQuery must have **one** of the tags
listed. This field is required and basically acts as a starting point
for the search.
.. attribute:: require
A set of tags such that every optimization obtained
through this Query must have **all** of these tags.
through this OptimizationQuery must have **all** of these tags.
.. attribute:: exclude
A set of tags such that every optimization obtained
through this Query must have **none** of these tags.
through this OptimizationQuery must have **none** of these tags.
.. attribute:: subquery
optdb can contain sub-databases; subquery is a
dictionary mapping the name of a sub-database to a special Query.
If no subquery is given for a sub-database, the original Query will be
dictionary mapping the name of a sub-database to a special OptimizationQuery.
If no subquery is given for a sub-database, the original OptimizationQuery will be
used again.
Furthermore, a Query object includes three methods, ``including``,
``requiring`` and ``excluding`` which each produce a new Query object
Furthermore, a OptimizationQuery object includes three methods, ``including``,
``requiring`` and ``excluding`` which each produce a new OptimizationQuery object
with include, require and exclude sets refined to contain the new [WRITEME]
Examples
--------
Here are a few examples of how to use a Query on optdb to produce an
Here are a few examples of how to use a OptimizationQuery on optdb to produce an
Optimizer:
.. testcode::
from aesara.graph.optdb import Query
from aesara.graph.optdb import OptimizationQuery
from aesara.compile import optdb
# This is how the optimizer for the fast_run mode is defined
fast_run = optdb.query(Query(include=['fast_run']))
fast_run = optdb.query(OptimizationQuery(include=['fast_run']))
# This is how the optimizer for the fast_compile mode is defined
fast_compile = optdb.query(Query(include=['fast_compile']))
fast_compile = optdb.query(OptimizationQuery(include=['fast_compile']))
# This is the same as fast_run but no optimizations will replace
# any operation by an inplace version. This assumes, of course,
# that all inplace operations are tagged as 'inplace' (as they
# should!)
fast_run_no_inplace = optdb.query(Query(include=['fast_run'],
fast_run_no_inplace = optdb.query(OptimizationQuery(include=['fast_run'],
exclude=['inplace']))
......@@ -544,7 +543,7 @@ optimizations:
For each group, all optimizations of the group that are selected by
the Query will be applied on the graph over and over again until none
the OptimizationQuery will be applied on the graph over and over again until none
of them is applicable, so keep that in mind when designing it: check
carefully that your optimization leads to a fixpoint (a point where it
cannot apply anymore) at which point it returns ``False`` to indicate its
......
import pytest
from aesara.graph.optdb import DB, opt
from aesara.graph.optdb import OptimizationDatabase, opt
class TestDB:
......@@ -11,7 +11,7 @@ class TestDB:
def apply(self, fgraph):
pass
db = DB()
db = OptimizationDatabase()
db.register("a", Opt())
db.register("b", Opt())
......
......@@ -13,7 +13,7 @@ from aesara.configdefaults import config
from aesara.graph.basic import Apply
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op, get_test_value
from aesara.graph.optdb import Query
from aesara.graph.optdb import OptimizationQuery
from aesara.ifelse import ifelse
from aesara.link.jax import JAXLinker
from aesara.scalar.basic import Composite
......@@ -52,7 +52,7 @@ from aesara.tensor.type import (
jax = pytest.importorskip("jax")
opts = Query(include=[None], exclude=["cxx_only", "BlasOpt"])
opts = OptimizationQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_mode = Mode(JAXLinker(), opts)
py_mode = Mode("py", opts)
......@@ -1057,7 +1057,7 @@ def test_jax_BatchedDot():
# A dimension mismatch should raise a TypeError for compatibility
inputs = [get_test_value(a)[:-1], get_test_value(b)]
opts = Query(include=[None], exclude=["cxx_only", "BlasOpt"])
opts = OptimizationQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_mode = Mode(JAXLinker(), opts)
aesara_jax_fn = function(fgraph.inputs, fgraph.outputs, mode=jax_mode)
with pytest.raises(TypeError):
......
......@@ -21,7 +21,7 @@ from aesara.compile.sharedvalue import SharedVariable
from aesara.graph.basic import Apply, Constant
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.optdb import Query
from aesara.graph.optdb import OptimizationQuery
from aesara.graph.type import Type
from aesara.link.numba.dispatch import create_numba_signature, get_numba_type
from aesara.link.numba.linker import NumbaLinker
......@@ -70,7 +70,7 @@ class MyMultiOut(Op):
outputs[1][0] = res2
opts = Query(include=[None], exclude=["cxx_only", "BlasOpt"])
opts = OptimizationQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
numba_mode = Mode(NumbaLinker(), opts)
py_mode = Mode("py", opts)
......
......@@ -8,7 +8,7 @@ from aesara.compile.mode import Mode
from aesara.graph.basic import Constant
from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import EquilibriumOptimizer
from aesara.graph.optdb import Query
from aesara.graph.optdb import OptimizationQuery
from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.random.basic import (
dirichlet,
......@@ -27,8 +27,10 @@ from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subte
from aesara.tensor.type import iscalar, vector
inplace_mode = Mode("py", Query(include=["random_make_inplace"], exclude=[]))
no_mode = Mode("py", Query(include=[], exclude=[]))
inplace_mode = Mode(
"py", OptimizationQuery(include=["random_make_inplace"], exclude=[])
)
no_mode = Mode("py", OptimizationQuery(include=[], exclude=[]))
def test_inplace_optimization():
......
......@@ -3,7 +3,7 @@ import pytest
from aesara import config, function
from aesara.compile.mode import Mode
from aesara.graph.optdb import Query
from aesara.graph.optdb import OptimizationQuery
from aesara.tensor.random.utils import RandomStream, broadcast_params
from aesara.tensor.type import matrix, tensor
from tests import unittest_tools as utt
......@@ -11,7 +11,7 @@ from tests import unittest_tools as utt
@pytest.fixture(scope="module", autouse=True)
def set_aesara_flags():
opts = Query(include=[None], exclude=[])
opts = OptimizationQuery(include=[None], exclude=[])
py_mode = Mode("py", opts)
with config.change_flags(mode=py_mode, compute_test_value="warn"):
yield
......
......@@ -19,7 +19,7 @@ from aesara.graph.basic import Apply, Constant
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.opt import check_stack_trace, local_optimizer, out2in
from aesara.graph.optdb import Query
from aesara.graph.optdb import OptimizationQuery
from aesara.misc.safe_asarray import _asarray
from aesara.tensor import inplace
from aesara.tensor.basic import (
......@@ -140,15 +140,15 @@ mode_opt = get_mode(mode_opt)
dimshuffle_lift = out2in(local_dimshuffle_lift)
_optimizer_stabilize = Query(include=["fast_run"])
_optimizer_stabilize = OptimizationQuery(include=["fast_run"])
_optimizer_stabilize.position_cutoff = 1.51
_optimizer_stabilize = optdb.query(_optimizer_stabilize)
_optimizer_specialize = Query(include=["fast_run"])
_optimizer_specialize = OptimizationQuery(include=["fast_run"])
_optimizer_specialize.position_cutoff = 2.01
_optimizer_specialize = optdb.query(_optimizer_specialize)
_optimizer_fast_run = Query(include=["fast_run"])
_optimizer_fast_run = OptimizationQuery(include=["fast_run"])
_optimizer_fast_run = optdb.query(_optimizer_fast_run)
......@@ -351,7 +351,7 @@ def test_local_useless_dimshuffle_in_reshape():
class TestFusion:
opts = Query(
opts = OptimizationQuery(
include=[
"local_elemwise_fusion",
"composite_elemwise_fusion",
......@@ -1125,7 +1125,7 @@ class TestFusion:
def test_add_mul_fusion_inplace(self):
opts = Query(
opts = OptimizationQuery(
include=[
"local_elemwise_fusion",
"composite_elemwise_fusion",
......
......@@ -9,7 +9,7 @@ from aesara.compile.mode import Mode
from aesara.configdefaults import config
from aesara.gradient import grad
from aesara.graph.basic import applys_between
from aesara.graph.optdb import Query
from aesara.graph.optdb import OptimizationQuery
from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.extra_ops import (
Bartlett,
......@@ -1169,7 +1169,7 @@ class TestBroadcastTo(utt.InferShapeTester):
q = b[np.r_[0, 1, 3]]
e = aet.set_subtensor(q, np.r_[0, 0, 0])
opts = Query(include=["inplace"])
opts = OptimizationQuery(include=["inplace"])
py_mode = Mode("py", opts)
e_fn = function([d], e, mode=py_mode)
......
......@@ -20,7 +20,7 @@ from aesara.graph.basic import Constant
from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import LocalOptGroup, TopoOptimizer, check_stack_trace, out2in
from aesara.graph.opt_utils import is_same_graph
from aesara.graph.optdb import Query
from aesara.graph.optdb import OptimizationQuery
from aesara.misc.safe_asarray import _asarray
from aesara.tensor import inplace
from aesara.tensor.basic import Alloc, join, switch
......@@ -124,15 +124,15 @@ mode_opt = get_mode(mode_opt)
dimshuffle_lift = out2in(local_dimshuffle_lift)
_optimizer_stabilize = Query(include=["fast_run"])
_optimizer_stabilize = OptimizationQuery(include=["fast_run"])
_optimizer_stabilize.position_cutoff = 1.51
_optimizer_stabilize = optdb.query(_optimizer_stabilize)
_optimizer_specialize = Query(include=["fast_run"])
_optimizer_specialize = OptimizationQuery(include=["fast_run"])
_optimizer_specialize.position_cutoff = 2.01
_optimizer_specialize = optdb.query(_optimizer_specialize)
_optimizer_fast_run = Query(include=["fast_run"])
_optimizer_fast_run = OptimizationQuery(include=["fast_run"])
_optimizer_fast_run = optdb.query(_optimizer_fast_run)
......@@ -351,7 +351,7 @@ class TestAlgebraicCanonize:
# We must be sure that the AlgebraicCanonizer is working, but that we don't have other
# optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion
mode = get_default_mode()
opt = Query(["canonicalize"])
opt = OptimizationQuery(["canonicalize"])
opt = opt.excluding("local_elemwise_fusion")
mode = mode.__class__(linker=mode.linker, optimizer=opt)
for id, [g, sym_inputs, val_inputs, nb_elemwise, out_dtype] in enumerate(cases):
......@@ -486,7 +486,7 @@ class TestAlgebraicCanonize:
# We must be sure that the AlgebraicCanonizer is working, but that we don't have other
# optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion
mode = get_default_mode()
mode._optimizer = Query(["canonicalize"])
mode._optimizer = OptimizationQuery(["canonicalize"])
mode._optimizer = mode._optimizer.excluding("local_elemwise_fusion")
for id, [g, sym_inputs, val_inputs, nb_elemwise, out_dtype] in enumerate(cases):
f = function(
......@@ -534,7 +534,7 @@ class TestAlgebraicCanonize:
# optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion
mode = get_default_mode()
opt = Query(["canonicalize"])
opt = OptimizationQuery(["canonicalize"])
opt = opt.including("ShapeOpt", "local_fill_to_alloc")
opt = opt.excluding("local_elemwise_fusion")
mode = mode.__class__(linker=mode.linker, optimizer=opt)
......@@ -897,7 +897,7 @@ class TestAlgebraicCanonize:
# optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion
mode = get_default_mode()
opt = Query(["canonicalize"])
opt = OptimizationQuery(["canonicalize"])
opt = opt.excluding("local_elemwise_fusion")
mode = mode.__class__(linker=mode.linker, optimizer=opt)
# test fail!
......@@ -1051,7 +1051,7 @@ def test_cast_in_mul_canonizer():
class TestFusion:
opts = Query(
opts = OptimizationQuery(
include=[
"local_elemwise_fusion",
"composite_elemwise_fusion",
......@@ -1762,7 +1762,7 @@ class TestFusion:
def test_add_mul_fusion_inplace(self):
opts = Query(
opts = OptimizationQuery(
include=[
"local_elemwise_fusion",
"composite_elemwise_fusion",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论