提交 c2672deb authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename Query to OptimizationQuery and DB to OptimizationDatabase,

上级 1c11dd44
...@@ -17,7 +17,13 @@ from aesara.graph.opt import ( ...@@ -17,7 +17,13 @@ from aesara.graph.opt import (
MergeOptimizer, MergeOptimizer,
NavigatorOptimizer, NavigatorOptimizer,
) )
from aesara.graph.optdb import EquilibriumDB, LocalGroupDB, Query, SequenceDB, TopoDB from aesara.graph.optdb import (
EquilibriumDB,
LocalGroupDB,
OptimizationQuery,
SequenceDB,
TopoDB,
)
from aesara.link.basic import PerformLinker from aesara.link.basic import PerformLinker
from aesara.link.c.basic import CLinker, OpWiseCLinker from aesara.link.c.basic import CLinker, OpWiseCLinker
from aesara.link.jax.linker import JAXLinker from aesara.link.jax.linker import JAXLinker
...@@ -58,19 +64,21 @@ def register_linker(name, linker): ...@@ -58,19 +64,21 @@ def register_linker(name, linker):
exclude = [] exclude = []
if not config.cxx: if not config.cxx:
exclude = ["cxx_only"] exclude = ["cxx_only"]
OPT_NONE = Query(include=[], exclude=exclude) OPT_NONE = OptimizationQuery(include=[], exclude=exclude)
# Even if multiple merge optimizer call will be there, this shouldn't # Even if multiple merge optimizer call will be there, this shouldn't
# impact performance. # impact performance.
OPT_MERGE = Query(include=["merge"], exclude=exclude) OPT_MERGE = OptimizationQuery(include=["merge"], exclude=exclude)
OPT_FAST_RUN = Query(include=["fast_run"], exclude=exclude) OPT_FAST_RUN = OptimizationQuery(include=["fast_run"], exclude=exclude)
OPT_FAST_RUN_STABLE = OPT_FAST_RUN.requiring("stable") OPT_FAST_RUN_STABLE = OPT_FAST_RUN.requiring("stable")
# We need fast_compile_gpu here. As on the GPU, we don't have all # We need fast_compile_gpu here. As on the GPU, we don't have all
# operation that exist in fast_compile, but have some that get # operation that exist in fast_compile, but have some that get
# introduced in fast_run, we want those optimization to also run in # introduced in fast_run, we want those optimization to also run in
# fast_compile+gpu. We can't tag them just as 'gpu', as this would # fast_compile+gpu. We can't tag them just as 'gpu', as this would
# exclude them if we exclude 'gpu'. # exclude them if we exclude 'gpu'.
OPT_FAST_COMPILE = Query(include=["fast_compile", "fast_compile_gpu"], exclude=exclude) OPT_FAST_COMPILE = OptimizationQuery(
OPT_STABILIZE = Query(include=["fast_run"], exclude=exclude) include=["fast_compile", "fast_compile_gpu"], exclude=exclude
)
OPT_STABILIZE = OptimizationQuery(include=["fast_run"], exclude=exclude)
OPT_STABILIZE.position_cutoff = 1.5000001 OPT_STABILIZE.position_cutoff = 1.5000001
OPT_NONE.name = "OPT_NONE" OPT_NONE.name = "OPT_NONE"
OPT_MERGE.name = "OPT_MERGE" OPT_MERGE.name = "OPT_MERGE"
...@@ -297,7 +305,7 @@ class Mode: ...@@ -297,7 +305,7 @@ class Mode:
# self.provided_optimizer - typically the `optimizer` arg. # self.provided_optimizer - typically the `optimizer` arg.
# But if the `optimizer` arg is keyword corresponding to a predefined # But if the `optimizer` arg is keyword corresponding to a predefined
# Query, then this stores the query # OptimizationQuery, then this stores the query
# self._optimizer - typically same as provided_optimizer?? # self._optimizer - typically same as provided_optimizer??
# self.__get_optimizer - returns self._optimizer (possibly querying # self.__get_optimizer - returns self._optimizer (possibly querying
...@@ -316,7 +324,7 @@ class Mode: ...@@ -316,7 +324,7 @@ class Mode:
self.linker = linker self.linker = linker
if isinstance(optimizer, str) or optimizer is None: if isinstance(optimizer, str) or optimizer is None:
optimizer = predefined_optimizers[optimizer] optimizer = predefined_optimizers[optimizer]
if isinstance(optimizer, Query): if isinstance(optimizer, OptimizationQuery):
self.provided_optimizer = optimizer self.provided_optimizer = optimizer
self._optimizer = optimizer self._optimizer = optimizer
self.call_time = 0 self.call_time = 0
...@@ -330,7 +338,7 @@ class Mode: ...@@ -330,7 +338,7 @@ class Mode:
) )
def __get_optimizer(self): def __get_optimizer(self):
if isinstance(self._optimizer, Query): if isinstance(self._optimizer, OptimizationQuery):
return optdb.query(self._optimizer) return optdb.query(self._optimizer)
else: else:
return self._optimizer return self._optimizer
...@@ -348,7 +356,7 @@ class Mode: ...@@ -348,7 +356,7 @@ class Mode:
link, opt = self.get_linker_optimizer( link, opt = self.get_linker_optimizer(
self.provided_linker, self.provided_optimizer self.provided_linker, self.provided_optimizer
) )
# N.B. opt might be a Query instance, not sure what else it might be... # N.B. opt might be a OptimizationQuery instance, not sure what else it might be...
# string? Optimizer? OptDB? who knows??? # string? Optimizer? OptDB? who knows???
return self.clone(optimizer=opt.including(*tags), linker=link) return self.clone(optimizer=opt.including(*tags), linker=link)
...@@ -421,9 +429,13 @@ if config.cxx: ...@@ -421,9 +429,13 @@ if config.cxx:
else: else:
FAST_RUN = Mode("vm", "fast_run") FAST_RUN = Mode("vm", "fast_run")
JAX = Mode(JAXLinker(), Query(include=["fast_run"], exclude=["cxx_only", "BlasOpt"])) JAX = Mode(
JAXLinker(),
OptimizationQuery(include=["fast_run"], exclude=["cxx_only", "BlasOpt"]),
)
NUMBA = Mode( NUMBA = Mode(
NumbaLinker(), Query(include=["fast_run"], exclude=["cxx_only", "BlasOpt"]) NumbaLinker(),
OptimizationQuery(include=["fast_run"], exclude=["cxx_only", "BlasOpt"]),
) )
......
from aesara.compile import optdb from aesara.compile import optdb
from aesara.graph.opt import GraphToGPULocalOptGroup, TopoOptimizer, local_optimizer from aesara.graph.opt import GraphToGPULocalOptGroup, TopoOptimizer, local_optimizer
from aesara.graph.optdb import DB, EquilibriumDB, LocalGroupDB, SequenceDB from aesara.graph.optdb import (
EquilibriumDB,
LocalGroupDB,
OptimizationDatabase,
SequenceDB,
)
gpu_optimizer = EquilibriumDB() gpu_optimizer = EquilibriumDB()
...@@ -62,7 +67,7 @@ def register_opt2(tracks, *tags, **kwargs): ...@@ -62,7 +67,7 @@ def register_opt2(tracks, *tags, **kwargs):
def f(local_opt): def f(local_opt):
name = (kwargs and kwargs.pop("name")) or local_opt.__name__ name = (kwargs and kwargs.pop("name")) or local_opt.__name__
if isinstance(local_opt, DB): if isinstance(local_opt, OptimizationDatabase):
opt = local_opt opt = local_opt
else: else:
opt = local_optimizer(tracks)(local_opt) opt = local_optimizer(tracks)(local_opt)
...@@ -97,7 +102,7 @@ abstractconv_groupopt.__name__ = "gpuarray_abstractconv_opts" ...@@ -97,7 +102,7 @@ abstractconv_groupopt.__name__ = "gpuarray_abstractconv_opts"
register_opt("fast_compile")(abstractconv_groupopt) register_opt("fast_compile")(abstractconv_groupopt)
class GraphToGPUDB(DB): class GraphToGPUDB(OptimizationDatabase):
""" """
Retrieves the list local optimizers based on the optimizer flag's value Retrieves the list local optimizers based on the optimizer flag's value
from EquilibriumOptimizer by calling the method query. from EquilibriumOptimizer by calling the method query.
......
...@@ -4,7 +4,7 @@ from typing import Sequence, Union ...@@ -4,7 +4,7 @@ from typing import Sequence, Union
import aesara import aesara
from aesara.graph.basic import Variable, equal_computations, graph_inputs, vars_between from aesara.graph.basic import Variable, equal_computations, graph_inputs, vars_between
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.optdb import Query from aesara.graph.optdb import OptimizationQuery
def optimize_graph( def optimize_graph(
...@@ -28,7 +28,7 @@ def optimize_graph( ...@@ -28,7 +28,7 @@ def optimize_graph(
clone: clone:
Whether or not to clone the input graph before optimizing. Whether or not to clone the input graph before optimizing.
**kwargs: **kwargs:
Keyword arguments passed to the ``aesara.graph.optdb.Query`` object. Keyword arguments passed to the ``aesara.graph.optdb.OptimizationQuery`` object.
""" """
from aesara.compile import optdb from aesara.compile import optdb
...@@ -37,7 +37,7 @@ def optimize_graph( ...@@ -37,7 +37,7 @@ def optimize_graph(
fgraph = FunctionGraph(outputs=[fgraph], clone=clone) fgraph = FunctionGraph(outputs=[fgraph], clone=clone)
return_only_out = True return_only_out = True
canonicalize_opt = optdb.query(Query(include=include, **kwargs)) canonicalize_opt = optdb.query(OptimizationQuery(include=include, **kwargs))
_ = canonicalize_opt.optimize(fgraph) _ = canonicalize_opt.optimize(fgraph)
if custom_opt: if custom_opt:
......
...@@ -10,7 +10,13 @@ from aesara.misc.ordered_set import OrderedSet ...@@ -10,7 +10,13 @@ from aesara.misc.ordered_set import OrderedSet
from aesara.utils import DefaultOrderedDict from aesara.utils import DefaultOrderedDict
class DB: class OptimizationDatabase:
"""A class that represents a collection/database of optimizations.
These databases can be used to logically organize sets of
(i.e. ``GlobalOptimizer``s and ``LocalOptimizer``)
"""
def __hash__(self): def __hash__(self):
if not hasattr(self, "_optimizer_idx"): if not hasattr(self, "_optimizer_idx"):
self._optimizer_idx = opt._optimizer_idx[0] self._optimizer_idx = opt._optimizer_idx[0]
...@@ -24,7 +30,7 @@ class DB: ...@@ -24,7 +30,7 @@ class DB:
# (via obj.name by the thing doing the registering) # (via obj.name by the thing doing the registering)
def register(self, name, obj, *tags, **kwargs): def register(self, name, obj, *tags, **kwargs):
""" """Register a new optimizer to the database.
Parameters Parameters
---------- ----------
...@@ -35,19 +41,21 @@ class DB: ...@@ -35,19 +41,21 @@ class DB:
tags tags
Tag name that allow to select the optimizer. Tag name that allow to select the optimizer.
kwargs kwargs
If non empty, should contain only use_db_name_as_tag=False. If non empty, should contain only ``use_db_name_as_tag=False``. By
By default, all optimizations registered in EquilibriumDB default, all optimizations registered in ``EquilibriumDB`` are
are selected when the EquilibriumDB name is used as a selected when the ``EquilibriumDB`` name is used as a tag. We do
tag. We do not want this behavior for some optimizer like not want this behavior for some optimizer like
local_remove_all_assert. use_db_name_as_tag=False remove ``local_remove_all_assert``. ``use_db_name_as_tag=False`` removes
that behavior. This mean only the optimizer name and the that behavior. This mean only the optimizer name and the tags
tags specified will enable that optimization. specified will enable that optimization.
""" """
# N.B. obj is not an instance of class `GlobalOptimizer`. # N.B. obj is not an instance of class `GlobalOptimizer`.
# It is an instance of a DB.In the tests for example, # It is an instance of a DB.In the tests for example,
# this is not always the case. # this is not always the case.
if not isinstance(obj, (DB, opt.GlobalOptimizer, opt.LocalOptimizer)): if not isinstance(
obj, (OptimizationDatabase, opt.GlobalOptimizer, opt.LocalOptimizer)
):
raise TypeError("Object cannot be registered in OptDB", obj) raise TypeError("Object cannot be registered in OptDB", obj)
if name in self.__db__: if name in self.__db__:
raise ValueError( raise ValueError(
...@@ -99,8 +107,8 @@ class DB: ...@@ -99,8 +107,8 @@ class DB:
self.__db__[tag].remove(obj) self.__db__[tag].remove(obj)
def __query__(self, q): def __query__(self, q):
if not isinstance(q, Query): if not isinstance(q, OptimizationQuery):
raise TypeError("Expected a Query.", q) raise TypeError("Expected a OptimizationQuery.", q)
# The ordered set is needed for deterministic optimization. # The ordered set is needed for deterministic optimization.
variables = OrderedSet() variables = OrderedSet()
for tag in q.include: for tag in q.include:
...@@ -112,7 +120,7 @@ class DB: ...@@ -112,7 +120,7 @@ class DB:
remove = OrderedSet() remove = OrderedSet()
add = OrderedSet() add = OrderedSet()
for obj in variables: for obj in variables:
if isinstance(obj, DB): if isinstance(obj, OptimizationDatabase):
def_sub_query = q def_sub_query = q
if q.extra_optimizations: if q.extra_optimizations:
def_sub_query = copy.copy(q) def_sub_query = copy.copy(q)
...@@ -128,10 +136,10 @@ class DB: ...@@ -128,10 +136,10 @@ class DB:
return variables return variables
def query(self, *tags, **kwtags): def query(self, *tags, **kwtags):
if len(tags) >= 1 and isinstance(tags[0], Query): if len(tags) >= 1 and isinstance(tags[0], OptimizationQuery):
if len(tags) > 1 or kwtags: if len(tags) > 1 or kwtags:
raise TypeError( raise TypeError(
"If the first argument to query is a Query," "If the first argument to query is a OptimizationQuery,"
" there should be no other arguments.", " there should be no other arguments.",
tags, tags,
kwtags, kwtags,
...@@ -147,7 +155,9 @@ class DB: ...@@ -147,7 +155,9 @@ class DB:
tags, tags,
) )
return self.__query__( return self.__query__(
Query(include=include, require=require, exclude=exclude, subquery=kwtags) OptimizationQuery(
include=include, require=require, exclude=exclude, subquery=kwtags
)
) )
def __getitem__(self, name): def __getitem__(self, name):
...@@ -168,7 +178,11 @@ class DB: ...@@ -168,7 +178,11 @@ class DB:
print(" db", self.__db__, file=stream) print(" db", self.__db__, file=stream)
class Query: # This is deprecated and will be removed.
DB = OptimizationDatabase
class OptimizationQuery:
""" """
Parameters Parameters
...@@ -204,7 +218,7 @@ class Query: ...@@ -204,7 +218,7 @@ class Query:
def __str__(self): def __str__(self):
return ( return (
"Query{inc=%s,ex=%s,require=%s,subquery=%s," "OptimizationQuery{inc=%s,ex=%s,require=%s,subquery=%s,"
"position_cutoff=%f,extra_opts=%s}" "position_cutoff=%f,extra_opts=%s}"
% ( % (
self.include, self.include,
...@@ -223,7 +237,7 @@ class Query: ...@@ -223,7 +237,7 @@ class Query:
# add all opt with this tag # add all opt with this tag
def including(self, *tags): def including(self, *tags):
return Query( return OptimizationQuery(
self.include.union(tags), self.include.union(tags),
self.require, self.require,
self.exclude, self.exclude,
...@@ -234,7 +248,7 @@ class Query: ...@@ -234,7 +248,7 @@ class Query:
# remove all opt with this tag # remove all opt with this tag
def excluding(self, *tags): def excluding(self, *tags):
return Query( return OptimizationQuery(
self.include, self.include,
self.require, self.require,
self.exclude.union(tags), self.exclude.union(tags),
...@@ -245,7 +259,7 @@ class Query: ...@@ -245,7 +259,7 @@ class Query:
# keep only opt with this tag. # keep only opt with this tag.
def requiring(self, *tags): def requiring(self, *tags):
return Query( return OptimizationQuery(
self.include, self.include,
self.require.union(tags), self.require.union(tags),
self.exclude, self.exclude,
...@@ -255,7 +269,7 @@ class Query: ...@@ -255,7 +269,7 @@ class Query:
) )
def register(self, *optimizations): def register(self, *optimizations):
return Query( return OptimizationQuery(
self.include, self.include,
self.require, self.require,
self.exclude, self.exclude,
...@@ -265,7 +279,11 @@ class Query: ...@@ -265,7 +279,11 @@ class Query:
) )
class EquilibriumDB(DB): # This is deprecated and will be removed.
Query = OptimizationQuery
class EquilibriumDB(OptimizationDatabase):
""" """
A set of potential optimizations which should be applied in an arbitrary A set of potential optimizations which should be applied in an arbitrary
order until equilibrium is reached. order until equilibrium is reached.
...@@ -331,7 +349,7 @@ class EquilibriumDB(DB): ...@@ -331,7 +349,7 @@ class EquilibriumDB(DB):
) )
class SequenceDB(DB): class SequenceDB(OptimizationDatabase):
""" """
A sequence of potential optimizations. A sequence of potential optimizations.
...@@ -378,13 +396,13 @@ class SequenceDB(DB): ...@@ -378,13 +396,13 @@ class SequenceDB(DB):
position_cutoff = kwtags.pop("position_cutoff", config.optdb__position_cutoff) position_cutoff = kwtags.pop("position_cutoff", config.optdb__position_cutoff)
position_dict = self.__position__ position_dict = self.__position__
if len(tags) >= 1 and isinstance(tags[0], Query): if len(tags) >= 1 and isinstance(tags[0], OptimizationQuery):
# the call to super should have raise an error with a good message # the call to super should have raise an error with a good message
assert len(tags) == 1 assert len(tags) == 1
if getattr(tags[0], "position_cutoff", None): if getattr(tags[0], "position_cutoff", None):
position_cutoff = tags[0].position_cutoff position_cutoff = tags[0].position_cutoff
# The Query instance might contain extra optimizations which need # The OptimizationQuery instance might contain extra optimizations which need
# to be added the the sequence of optimizations (don't alter the # to be added the the sequence of optimizations (don't alter the
# original dictionary) # original dictionary)
if len(tags[0].extra_optimizations) > 0: if len(tags[0].extra_optimizations) > 0:
...@@ -430,7 +448,7 @@ class SequenceDB(DB): ...@@ -430,7 +448,7 @@ class SequenceDB(DB):
return sio.getvalue() return sio.getvalue()
class LocalGroupDB(DB): class LocalGroupDB(OptimizationDatabase):
""" """
Generate a local optimizer of type LocalOptGroup instead Generate a local optimizer of type LocalOptGroup instead
of a global optimizer. of a global optimizer.
...@@ -476,7 +494,7 @@ class LocalGroupDB(DB): ...@@ -476,7 +494,7 @@ class LocalGroupDB(DB):
return ret return ret
class TopoDB(DB): class TopoDB(OptimizationDatabase):
""" """
Generate a `GlobalOptimizer` of type TopoOptimizer. Generate a `GlobalOptimizer` of type TopoOptimizer.
...@@ -501,7 +519,7 @@ class TopoDB(DB): ...@@ -501,7 +519,7 @@ class TopoDB(DB):
) )
class ProxyDB(DB): class ProxyDB(OptimizationDatabase):
""" """
Wrap an existing proxy. Wrap an existing proxy.
...@@ -511,7 +529,7 @@ class ProxyDB(DB): ...@@ -511,7 +529,7 @@ class ProxyDB(DB):
""" """
def __init__(self, db): def __init__(self, db):
assert isinstance(db, DB), "" assert isinstance(db, OptimizationDatabase), ""
self.db = db self.db = db
def query(self, *tags, **kwtags): def query(self, *tags, **kwtags):
......
...@@ -395,97 +395,96 @@ Definition of optdb ...@@ -395,97 +395,96 @@ Definition of optdb
optdb is an object which is an instance of optdb is an object which is an instance of
:class:`SequenceDB <optdb.SequenceDB>`, :class:`SequenceDB <optdb.SequenceDB>`,
itself a subclass of :class:`DB <optdb.DB>`. itself a subclass of :class:`OptimizationDatabase <optdb.OptimizationDatabase>`.
There exist (for now) two types of DB, SequenceDB and EquilibriumDB. There exist (for now) two types of OptimizationDatabase, SequenceDB and EquilibriumDB.
When given an appropriate Query, DB objects build an Optimizer matching When given an appropriate OptimizationQuery, OptimizationDatabase objects build an Optimizer matching
the query. the query.
A SequenceDB contains Optimizer or DB objects. Each of them has a A SequenceDB contains Optimizer or OptimizationDatabase objects. Each of them
name, an arbitrary number of tags and an integer representing their has a name, an arbitrary number of tags and an integer representing their order
order in the sequence. When a Query is applied to a SequenceDB, all in the sequence. When a OptimizationQuery is applied to a SequenceDB, all Optimizers whose
Optimizers whose tags match the query are inserted in proper order in tags match the query are inserted in proper order in a SequenceOptimizer, which
a SequenceOptimizer, which is returned. If the SequenceDB contains DB is returned. If the SequenceDB contains OptimizationDatabase instances, the OptimizationQuery will be passed
instances, the Query will be passed to them as well and the optimizers to them as well and the optimizers they return will be put in their places.
they return will be put in their places.
An EquilibriumDB contains LocalOptimizer or DB objects. Each of them An EquilibriumDB contains LocalOptimizer or OptimizationDatabase objects. Each of them
has a name and an arbitrary number of tags. When a Query is applied to has a name and an arbitrary number of tags. When a OptimizationQuery is applied to
an EquilibriumDB, all LocalOptimizers that match the query are an EquilibriumDB, all LocalOptimizers that match the query are
inserted into an EquilibriumOptimizer, which is returned. If the inserted into an EquilibriumOptimizer, which is returned. If the
SequenceDB contains DB instances, the Query will be passed to them as SequenceDB contains OptimizationDatabase instances, the OptimizationQuery will be passed to them as
well and the LocalOptimizers they return will be put in their places well and the LocalOptimizers they return will be put in their places
(note that as of yet no DB can produce LocalOptimizer objects, so this (note that as of yet no OptimizationDatabase can produce LocalOptimizer objects, so this
is a moot point). is a moot point).
Aesara contains one principal DB object, :class:`optdb`, which Aesara contains one principal OptimizationDatabase object, :class:`optdb`, which
contains all of Aesara's optimizers with proper tags. It is contains all of Aesara's optimizers with proper tags. It is
recommended to insert new Optimizers in it. As mentioned previously, recommended to insert new Optimizers in it. As mentioned previously,
optdb is a SequenceDB, so, at the top level, Aesara applies a sequence optdb is a SequenceDB, so, at the top level, Aesara applies a sequence
of global optimizations to the computation graphs. of global optimizations to the computation graphs.
Query OptimizationQuery
----- -----
A Query is built by the following call: A OptimizationQuery is built by the following call:
.. code-block:: python .. code-block:: python
aesara.graph.optdb.Query(include, require=None, exclude=None, subquery=None) aesara.graph.optdb.OptimizationQuery(include, require=None, exclude=None, subquery=None)
.. class:: Query .. class:: OptimizationQuery
.. attribute:: include .. attribute:: include
A set of tags (a tag being a string) such that every A set of tags (a tag being a string) such that every
optimization obtained through this Query must have **one** of the tags optimization obtained through this OptimizationQuery must have **one** of the tags
listed. This field is required and basically acts as a starting point listed. This field is required and basically acts as a starting point
for the search. for the search.
.. attribute:: require .. attribute:: require
A set of tags such that every optimization obtained A set of tags such that every optimization obtained
through this Query must have **all** of these tags. through this OptimizationQuery must have **all** of these tags.
.. attribute:: exclude .. attribute:: exclude
A set of tags such that every optimization obtained A set of tags such that every optimization obtained
through this Query must have **none** of these tags. through this OptimizationQuery must have **none** of these tags.
.. attribute:: subquery .. attribute:: subquery
optdb can contain sub-databases; subquery is a optdb can contain sub-databases; subquery is a
dictionary mapping the name of a sub-database to a special Query. dictionary mapping the name of a sub-database to a special OptimizationQuery.
If no subquery is given for a sub-database, the original Query will be If no subquery is given for a sub-database, the original OptimizationQuery will be
used again. used again.
Furthermore, a Query object includes three methods, ``including``, Furthermore, a OptimizationQuery object includes three methods, ``including``,
``requiring`` and ``excluding`` which each produce a new Query object ``requiring`` and ``excluding`` which each produce a new OptimizationQuery object
with include, require and exclude sets refined to contain the new [WRITEME] with include, require and exclude sets refined to contain the new [WRITEME]
Examples Examples
-------- --------
Here are a few examples of how to use a Query on optdb to produce an Here are a few examples of how to use a OptimizationQuery on optdb to produce an
Optimizer: Optimizer:
.. testcode:: .. testcode::
from aesara.graph.optdb import Query from aesara.graph.optdb import OptimizationQuery
from aesara.compile import optdb from aesara.compile import optdb
# This is how the optimizer for the fast_run mode is defined # This is how the optimizer for the fast_run mode is defined
fast_run = optdb.query(Query(include=['fast_run'])) fast_run = optdb.query(OptimizationQuery(include=['fast_run']))
# This is how the optimizer for the fast_compile mode is defined # This is how the optimizer for the fast_compile mode is defined
fast_compile = optdb.query(Query(include=['fast_compile'])) fast_compile = optdb.query(OptimizationQuery(include=['fast_compile']))
# This is the same as fast_run but no optimizations will replace # This is the same as fast_run but no optimizations will replace
# any operation by an inplace version. This assumes, of course, # any operation by an inplace version. This assumes, of course,
# that all inplace operations are tagged as 'inplace' (as they # that all inplace operations are tagged as 'inplace' (as they
# should!) # should!)
fast_run_no_inplace = optdb.query(Query(include=['fast_run'], fast_run_no_inplace = optdb.query(OptimizationQuery(include=['fast_run'],
exclude=['inplace'])) exclude=['inplace']))
...@@ -544,7 +543,7 @@ optimizations: ...@@ -544,7 +543,7 @@ optimizations:
For each group, all optimizations of the group that are selected by For each group, all optimizations of the group that are selected by
the Query will be applied on the graph over and over again until none the OptimizationQuery will be applied on the graph over and over again until none
of them is applicable, so keep that in mind when designing it: check of them is applicable, so keep that in mind when designing it: check
carefully that your optimization leads to a fixpoint (a point where it carefully that your optimization leads to a fixpoint (a point where it
cannot apply anymore) at which point it returns ``False`` to indicate its cannot apply anymore) at which point it returns ``False`` to indicate its
......
import pytest import pytest
from aesara.graph.optdb import DB, opt from aesara.graph.optdb import OptimizationDatabase, opt
class TestDB: class TestDB:
...@@ -11,7 +11,7 @@ class TestDB: ...@@ -11,7 +11,7 @@ class TestDB:
def apply(self, fgraph): def apply(self, fgraph):
pass pass
db = DB() db = OptimizationDatabase()
db.register("a", Opt()) db.register("a", Opt())
db.register("b", Opt()) db.register("b", Opt())
......
...@@ -13,7 +13,7 @@ from aesara.configdefaults import config ...@@ -13,7 +13,7 @@ from aesara.configdefaults import config
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op, get_test_value from aesara.graph.op import Op, get_test_value
from aesara.graph.optdb import Query from aesara.graph.optdb import OptimizationQuery
from aesara.ifelse import ifelse from aesara.ifelse import ifelse
from aesara.link.jax import JAXLinker from aesara.link.jax import JAXLinker
from aesara.scalar.basic import Composite from aesara.scalar.basic import Composite
...@@ -52,7 +52,7 @@ from aesara.tensor.type import ( ...@@ -52,7 +52,7 @@ from aesara.tensor.type import (
jax = pytest.importorskip("jax") jax = pytest.importorskip("jax")
opts = Query(include=[None], exclude=["cxx_only", "BlasOpt"]) opts = OptimizationQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_mode = Mode(JAXLinker(), opts) jax_mode = Mode(JAXLinker(), opts)
py_mode = Mode("py", opts) py_mode = Mode("py", opts)
...@@ -1057,7 +1057,7 @@ def test_jax_BatchedDot(): ...@@ -1057,7 +1057,7 @@ def test_jax_BatchedDot():
# A dimension mismatch should raise a TypeError for compatibility # A dimension mismatch should raise a TypeError for compatibility
inputs = [get_test_value(a)[:-1], get_test_value(b)] inputs = [get_test_value(a)[:-1], get_test_value(b)]
opts = Query(include=[None], exclude=["cxx_only", "BlasOpt"]) opts = OptimizationQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_mode = Mode(JAXLinker(), opts) jax_mode = Mode(JAXLinker(), opts)
aesara_jax_fn = function(fgraph.inputs, fgraph.outputs, mode=jax_mode) aesara_jax_fn = function(fgraph.inputs, fgraph.outputs, mode=jax_mode)
with pytest.raises(TypeError): with pytest.raises(TypeError):
......
...@@ -21,7 +21,7 @@ from aesara.compile.sharedvalue import SharedVariable ...@@ -21,7 +21,7 @@ from aesara.compile.sharedvalue import SharedVariable
from aesara.graph.basic import Apply, Constant from aesara.graph.basic import Apply, Constant
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.optdb import Query from aesara.graph.optdb import OptimizationQuery
from aesara.graph.type import Type from aesara.graph.type import Type
from aesara.link.numba.dispatch import create_numba_signature, get_numba_type from aesara.link.numba.dispatch import create_numba_signature, get_numba_type
from aesara.link.numba.linker import NumbaLinker from aesara.link.numba.linker import NumbaLinker
...@@ -70,7 +70,7 @@ class MyMultiOut(Op): ...@@ -70,7 +70,7 @@ class MyMultiOut(Op):
outputs[1][0] = res2 outputs[1][0] = res2
opts = Query(include=[None], exclude=["cxx_only", "BlasOpt"]) opts = OptimizationQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
numba_mode = Mode(NumbaLinker(), opts) numba_mode = Mode(NumbaLinker(), opts)
py_mode = Mode("py", opts) py_mode = Mode("py", opts)
......
...@@ -8,7 +8,7 @@ from aesara.compile.mode import Mode ...@@ -8,7 +8,7 @@ from aesara.compile.mode import Mode
from aesara.graph.basic import Constant from aesara.graph.basic import Constant
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import EquilibriumOptimizer from aesara.graph.opt import EquilibriumOptimizer
from aesara.graph.optdb import Query from aesara.graph.optdb import OptimizationQuery
from aesara.tensor.elemwise import DimShuffle from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.random.basic import ( from aesara.tensor.random.basic import (
dirichlet, dirichlet,
...@@ -27,8 +27,10 @@ from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subte ...@@ -27,8 +27,10 @@ from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subte
from aesara.tensor.type import iscalar, vector from aesara.tensor.type import iscalar, vector
inplace_mode = Mode("py", Query(include=["random_make_inplace"], exclude=[])) inplace_mode = Mode(
no_mode = Mode("py", Query(include=[], exclude=[])) "py", OptimizationQuery(include=["random_make_inplace"], exclude=[])
)
no_mode = Mode("py", OptimizationQuery(include=[], exclude=[]))
def test_inplace_optimization(): def test_inplace_optimization():
......
...@@ -3,7 +3,7 @@ import pytest ...@@ -3,7 +3,7 @@ import pytest
from aesara import config, function from aesara import config, function
from aesara.compile.mode import Mode from aesara.compile.mode import Mode
from aesara.graph.optdb import Query from aesara.graph.optdb import OptimizationQuery
from aesara.tensor.random.utils import RandomStream, broadcast_params from aesara.tensor.random.utils import RandomStream, broadcast_params
from aesara.tensor.type import matrix, tensor from aesara.tensor.type import matrix, tensor
from tests import unittest_tools as utt from tests import unittest_tools as utt
...@@ -11,7 +11,7 @@ from tests import unittest_tools as utt ...@@ -11,7 +11,7 @@ from tests import unittest_tools as utt
@pytest.fixture(scope="module", autouse=True) @pytest.fixture(scope="module", autouse=True)
def set_aesara_flags(): def set_aesara_flags():
opts = Query(include=[None], exclude=[]) opts = OptimizationQuery(include=[None], exclude=[])
py_mode = Mode("py", opts) py_mode = Mode("py", opts)
with config.change_flags(mode=py_mode, compute_test_value="warn"): with config.change_flags(mode=py_mode, compute_test_value="warn"):
yield yield
......
...@@ -19,7 +19,7 @@ from aesara.graph.basic import Apply, Constant ...@@ -19,7 +19,7 @@ from aesara.graph.basic import Apply, Constant
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.opt import check_stack_trace, local_optimizer, out2in from aesara.graph.opt import check_stack_trace, local_optimizer, out2in
from aesara.graph.optdb import Query from aesara.graph.optdb import OptimizationQuery
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.tensor import inplace from aesara.tensor import inplace
from aesara.tensor.basic import ( from aesara.tensor.basic import (
...@@ -140,15 +140,15 @@ mode_opt = get_mode(mode_opt) ...@@ -140,15 +140,15 @@ mode_opt = get_mode(mode_opt)
dimshuffle_lift = out2in(local_dimshuffle_lift) dimshuffle_lift = out2in(local_dimshuffle_lift)
_optimizer_stabilize = Query(include=["fast_run"]) _optimizer_stabilize = OptimizationQuery(include=["fast_run"])
_optimizer_stabilize.position_cutoff = 1.51 _optimizer_stabilize.position_cutoff = 1.51
_optimizer_stabilize = optdb.query(_optimizer_stabilize) _optimizer_stabilize = optdb.query(_optimizer_stabilize)
_optimizer_specialize = Query(include=["fast_run"]) _optimizer_specialize = OptimizationQuery(include=["fast_run"])
_optimizer_specialize.position_cutoff = 2.01 _optimizer_specialize.position_cutoff = 2.01
_optimizer_specialize = optdb.query(_optimizer_specialize) _optimizer_specialize = optdb.query(_optimizer_specialize)
_optimizer_fast_run = Query(include=["fast_run"]) _optimizer_fast_run = OptimizationQuery(include=["fast_run"])
_optimizer_fast_run = optdb.query(_optimizer_fast_run) _optimizer_fast_run = optdb.query(_optimizer_fast_run)
...@@ -351,7 +351,7 @@ def test_local_useless_dimshuffle_in_reshape(): ...@@ -351,7 +351,7 @@ def test_local_useless_dimshuffle_in_reshape():
class TestFusion: class TestFusion:
opts = Query( opts = OptimizationQuery(
include=[ include=[
"local_elemwise_fusion", "local_elemwise_fusion",
"composite_elemwise_fusion", "composite_elemwise_fusion",
...@@ -1125,7 +1125,7 @@ class TestFusion: ...@@ -1125,7 +1125,7 @@ class TestFusion:
def test_add_mul_fusion_inplace(self): def test_add_mul_fusion_inplace(self):
opts = Query( opts = OptimizationQuery(
include=[ include=[
"local_elemwise_fusion", "local_elemwise_fusion",
"composite_elemwise_fusion", "composite_elemwise_fusion",
......
...@@ -9,7 +9,7 @@ from aesara.compile.mode import Mode ...@@ -9,7 +9,7 @@ from aesara.compile.mode import Mode
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.gradient import grad from aesara.gradient import grad
from aesara.graph.basic import applys_between from aesara.graph.basic import applys_between
from aesara.graph.optdb import Query from aesara.graph.optdb import OptimizationQuery
from aesara.tensor.elemwise import DimShuffle from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.extra_ops import ( from aesara.tensor.extra_ops import (
Bartlett, Bartlett,
...@@ -1169,7 +1169,7 @@ class TestBroadcastTo(utt.InferShapeTester): ...@@ -1169,7 +1169,7 @@ class TestBroadcastTo(utt.InferShapeTester):
q = b[np.r_[0, 1, 3]] q = b[np.r_[0, 1, 3]]
e = aet.set_subtensor(q, np.r_[0, 0, 0]) e = aet.set_subtensor(q, np.r_[0, 0, 0])
opts = Query(include=["inplace"]) opts = OptimizationQuery(include=["inplace"])
py_mode = Mode("py", opts) py_mode = Mode("py", opts)
e_fn = function([d], e, mode=py_mode) e_fn = function([d], e, mode=py_mode)
......
...@@ -20,7 +20,7 @@ from aesara.graph.basic import Constant ...@@ -20,7 +20,7 @@ from aesara.graph.basic import Constant
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import LocalOptGroup, TopoOptimizer, check_stack_trace, out2in from aesara.graph.opt import LocalOptGroup, TopoOptimizer, check_stack_trace, out2in
from aesara.graph.opt_utils import is_same_graph from aesara.graph.opt_utils import is_same_graph
from aesara.graph.optdb import Query from aesara.graph.optdb import OptimizationQuery
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.tensor import inplace from aesara.tensor import inplace
from aesara.tensor.basic import Alloc, join, switch from aesara.tensor.basic import Alloc, join, switch
...@@ -124,15 +124,15 @@ mode_opt = get_mode(mode_opt) ...@@ -124,15 +124,15 @@ mode_opt = get_mode(mode_opt)
dimshuffle_lift = out2in(local_dimshuffle_lift) dimshuffle_lift = out2in(local_dimshuffle_lift)
_optimizer_stabilize = Query(include=["fast_run"]) _optimizer_stabilize = OptimizationQuery(include=["fast_run"])
_optimizer_stabilize.position_cutoff = 1.51 _optimizer_stabilize.position_cutoff = 1.51
_optimizer_stabilize = optdb.query(_optimizer_stabilize) _optimizer_stabilize = optdb.query(_optimizer_stabilize)
_optimizer_specialize = Query(include=["fast_run"]) _optimizer_specialize = OptimizationQuery(include=["fast_run"])
_optimizer_specialize.position_cutoff = 2.01 _optimizer_specialize.position_cutoff = 2.01
_optimizer_specialize = optdb.query(_optimizer_specialize) _optimizer_specialize = optdb.query(_optimizer_specialize)
_optimizer_fast_run = Query(include=["fast_run"]) _optimizer_fast_run = OptimizationQuery(include=["fast_run"])
_optimizer_fast_run = optdb.query(_optimizer_fast_run) _optimizer_fast_run = optdb.query(_optimizer_fast_run)
...@@ -351,7 +351,7 @@ class TestAlgebraicCanonize: ...@@ -351,7 +351,7 @@ class TestAlgebraicCanonize:
# We must be sure that the AlgebraicCanonizer is working, but that we don't have other # We must be sure that the AlgebraicCanonizer is working, but that we don't have other
# optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion # optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion
mode = get_default_mode() mode = get_default_mode()
opt = Query(["canonicalize"]) opt = OptimizationQuery(["canonicalize"])
opt = opt.excluding("local_elemwise_fusion") opt = opt.excluding("local_elemwise_fusion")
mode = mode.__class__(linker=mode.linker, optimizer=opt) mode = mode.__class__(linker=mode.linker, optimizer=opt)
for id, [g, sym_inputs, val_inputs, nb_elemwise, out_dtype] in enumerate(cases): for id, [g, sym_inputs, val_inputs, nb_elemwise, out_dtype] in enumerate(cases):
...@@ -486,7 +486,7 @@ class TestAlgebraicCanonize: ...@@ -486,7 +486,7 @@ class TestAlgebraicCanonize:
# We must be sure that the AlgebraicCanonizer is working, but that we don't have other # We must be sure that the AlgebraicCanonizer is working, but that we don't have other
# optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion # optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion
mode = get_default_mode() mode = get_default_mode()
mode._optimizer = Query(["canonicalize"]) mode._optimizer = OptimizationQuery(["canonicalize"])
mode._optimizer = mode._optimizer.excluding("local_elemwise_fusion") mode._optimizer = mode._optimizer.excluding("local_elemwise_fusion")
for id, [g, sym_inputs, val_inputs, nb_elemwise, out_dtype] in enumerate(cases): for id, [g, sym_inputs, val_inputs, nb_elemwise, out_dtype] in enumerate(cases):
f = function( f = function(
...@@ -534,7 +534,7 @@ class TestAlgebraicCanonize: ...@@ -534,7 +534,7 @@ class TestAlgebraicCanonize:
# optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion # optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion
mode = get_default_mode() mode = get_default_mode()
opt = Query(["canonicalize"]) opt = OptimizationQuery(["canonicalize"])
opt = opt.including("ShapeOpt", "local_fill_to_alloc") opt = opt.including("ShapeOpt", "local_fill_to_alloc")
opt = opt.excluding("local_elemwise_fusion") opt = opt.excluding("local_elemwise_fusion")
mode = mode.__class__(linker=mode.linker, optimizer=opt) mode = mode.__class__(linker=mode.linker, optimizer=opt)
...@@ -897,7 +897,7 @@ class TestAlgebraicCanonize: ...@@ -897,7 +897,7 @@ class TestAlgebraicCanonize:
# optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion # optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion
mode = get_default_mode() mode = get_default_mode()
opt = Query(["canonicalize"]) opt = OptimizationQuery(["canonicalize"])
opt = opt.excluding("local_elemwise_fusion") opt = opt.excluding("local_elemwise_fusion")
mode = mode.__class__(linker=mode.linker, optimizer=opt) mode = mode.__class__(linker=mode.linker, optimizer=opt)
# test fail! # test fail!
...@@ -1051,7 +1051,7 @@ def test_cast_in_mul_canonizer(): ...@@ -1051,7 +1051,7 @@ def test_cast_in_mul_canonizer():
class TestFusion: class TestFusion:
opts = Query( opts = OptimizationQuery(
include=[ include=[
"local_elemwise_fusion", "local_elemwise_fusion",
"composite_elemwise_fusion", "composite_elemwise_fusion",
...@@ -1762,7 +1762,7 @@ class TestFusion: ...@@ -1762,7 +1762,7 @@ class TestFusion:
def test_add_mul_fusion_inplace(self): def test_add_mul_fusion_inplace(self):
opts = Query( opts = OptimizationQuery(
include=[ include=[
"local_elemwise_fusion", "local_elemwise_fusion",
"composite_elemwise_fusion", "composite_elemwise_fusion",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论