提交 6bedb2b1 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Clean up aesara.graph.optdb docstrings and signatures

上级 c2672deb
import copy import copy
import math import math
import sys import sys
from functools import cmp_to_key
from io import StringIO from io import StringIO
from typing import Dict, Optional from typing import Dict, Optional, Sequence, Union
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph import opt from aesara.graph import opt
...@@ -10,78 +11,75 @@ from aesara.misc.ordered_set import OrderedSet ...@@ -10,78 +11,75 @@ from aesara.misc.ordered_set import OrderedSet
from aesara.utils import DefaultOrderedDict from aesara.utils import DefaultOrderedDict
OptimizersType = Union[opt.GlobalOptimizer, opt.LocalOptimizer]
class OptimizationDatabase: class OptimizationDatabase:
"""A class that represents a collection/database of optimizations. """A class that represents a collection/database of optimizations.
These databases can be used to logically organize sets of These databases are used to logically organize collections of optimizers
(i.e. ``GlobalOptimizer``s and ``LocalOptimizer``) (i.e. ``GlobalOptimizer`` and ``LocalOptimizer`` instances).
""" """
def __hash__(self):
if not hasattr(self, "_optimizer_idx"):
self._optimizer_idx = opt._optimizer_idx[0]
opt._optimizer_idx[0] += 1
return self._optimizer_idx
def __init__(self): def __init__(self):
self.__db__ = DefaultOrderedDict(OrderedSet) self.__db__ = DefaultOrderedDict(OrderedSet)
self._names = set() self._names = set()
self.name = None # will be reset by register # This will be reset by `self.register` (via `obj.name` by the thing
# (via obj.name by the thing doing the registering) # doing the registering)
self.name = None
def register(self, name, obj, *tags, **kwargs): def register(
self,
name: str,
optimizer: Union["OptimizationDatabase", OptimizersType],
*tags: str,
use_db_name_as_tag=True,
):
"""Register a new optimizer to the database. """Register a new optimizer to the database.
Parameters Parameters
---------- ----------
name : str name:
Name of the optimizer. Name of the optimizer.
obj optimizer:
The optimizer to register. The optimizer to register.
tags tags:
Tag name that allow to select the optimizer. Tag names that allow the optimizer to be selected.
kwargs use_db_name_as_tag:
If non empty, should contain only ``use_db_name_as_tag=False``. By Add the database's name as a tag, so that its name can be used in a
default, all optimizations registered in ``EquilibriumDB`` are query.
selected when the ``EquilibriumDB`` name is used as a tag. We do By default, all optimizations registered in ``EquilibriumDB`` are
not want this behavior for some optimizer like selected when the ``"EquilibriumDB"`` name is used as a tag. We do
``local_remove_all_assert``. ``use_db_name_as_tag=False`` removes not want this behavior for some optimizers like
that behavior. This mean only the optimizer name ``False`` removes that behavior. This means only the optimizer name
specified will enable that optimization. ``False`` removes that behavior. This mean only the optimizer name
and the tags specified will enable that optimization.
""" """
# N.B. obj is not an instance of class `GlobalOptimizer`.
# It is an instance of a DB.In the tests for example,
# this is not always the case.
if not isinstance( if not isinstance(
obj, (OptimizationDatabase, opt.GlobalOptimizer, opt.LocalOptimizer) optimizer, (OptimizationDatabase, opt.GlobalOptimizer, opt.LocalOptimizer)
): ):
raise TypeError("Object cannot be registered in OptDB", obj) raise TypeError(f"{optimizer} is not a valid optimizer type.")
if name in self.__db__: if name in self.__db__:
raise ValueError( raise ValueError(f"The tag '{name}' is already present in the database.")
"The name of the object cannot be an existing"
" tag or the name of another existing object.", if use_db_name_as_tag:
obj,
name,
)
if kwargs:
assert "use_db_name_as_tag" in kwargs
assert kwargs["use_db_name_as_tag"] is False
else:
if self.name is not None: if self.name is not None:
tags = tags + (self.name,) tags = tags + (self.name,)
obj.name = name
optimizer.name = name
# This restriction is there because in many place we suppose that # This restriction is there because in many place we suppose that
# something in the DB is there only once. # something in the OptimizationDatabase is there only once.
if obj.name in self.__db__: if optimizer.name in self.__db__:
raise ValueError( raise ValueError(
f"Tried to register {obj.name} again under the new name {name}. " f"Tried to register {optimizer.name} again under the new name {name}. "
"You can't register the same optimization multiple time in a DB. " "The same optimization cannot be registered multiple times in"
"Use ProxyDB to work around that." " an ``OptimizationDatabase``; use ProxyDB instead."
) )
self.__db__[name] = OrderedSet([obj]) self.__db__[name] = OrderedSet([optimizer])
self._names.add(name) self._names.add(name)
self.__db__[obj.__class__.__name__].add(obj) self.__db__[optimizer.__class__.__name__].add(optimizer)
self.add_tags(name, *tags) self.add_tags(name, *tags)
def add_tags(self, name, *tags): def add_tags(self, name, *tags):
...@@ -91,7 +89,7 @@ class OptimizationDatabase: ...@@ -91,7 +89,7 @@ class OptimizationDatabase:
for tag in tags: for tag in tags:
if tag in self._names: if tag in self._names:
raise ValueError( raise ValueError(
"The tag of the object collides with a name.", obj, tag f"The tag '{tag}' for the {obj} collides with an existing name."
) )
self.__db__[tag].add(obj) self.__db__[tag].add(obj)
...@@ -102,13 +100,11 @@ class OptimizationDatabase: ...@@ -102,13 +100,11 @@ class OptimizationDatabase:
for tag in tags: for tag in tags:
if tag in self._names: if tag in self._names:
raise ValueError( raise ValueError(
"The tag of the object collides with a name.", obj, tag f"The tag '{tag}' for the {obj} collides with an existing name."
) )
self.__db__[tag].remove(obj) self.__db__[tag].remove(obj)
def __query__(self, q): def __query__(self, q):
if not isinstance(q, OptimizationQuery):
raise TypeError("Expected a OptimizationQuery.", q)
# The ordered set is needed for deterministic optimization. # The ordered set is needed for deterministic optimization.
variables = OrderedSet() variables = OrderedSet()
for tag in q.include: for tag in q.include:
...@@ -139,10 +135,8 @@ class OptimizationDatabase: ...@@ -139,10 +135,8 @@ class OptimizationDatabase:
if len(tags) >= 1 and isinstance(tags[0], OptimizationQuery): if len(tags) >= 1 and isinstance(tags[0], OptimizationQuery):
if len(tags) > 1 or kwtags: if len(tags) > 1 or kwtags:
raise TypeError( raise TypeError(
"If the first argument to query is a OptimizationQuery," "If the first argument to query is an `OptimizationQuery`,"
" there should be no other arguments.", " there should be no other arguments."
tags,
kwtags,
) )
return self.__query__(tags[0]) return self.__query__(tags[0])
include = [tag[1:] for tag in tags if tag.startswith("+")] include = [tag[1:] for tag in tags if tag.startswith("+")]
...@@ -151,8 +145,7 @@ class OptimizationDatabase: ...@@ -151,8 +145,7 @@ class OptimizationDatabase:
if len(include) + len(require) + len(exclude) < len(tags): if len(include) + len(require) + len(exclude) < len(tags):
raise ValueError( raise ValueError(
"All tags must start with one of the following" "All tags must start with one of the following"
" characters: '+', '&' or '-'", " characters: '+', '&' or '-'"
tags,
) )
return self.__query__( return self.__query__(
OptimizationQuery( OptimizationQuery(
...@@ -177,31 +170,54 @@ class OptimizationDatabase: ...@@ -177,31 +170,54 @@ class OptimizationDatabase:
print(" names", self._names, file=stream) print(" names", self._names, file=stream)
print(" db", self.__db__, file=stream) print(" db", self.__db__, file=stream)
def __hash__(self):
if not hasattr(self, "_optimizer_idx"):
self._optimizer_idx = opt._optimizer_idx[0]
opt._optimizer_idx[0] += 1
return self._optimizer_idx
# This is deprecated and will be removed. # This is deprecated and will be removed.
DB = OptimizationDatabase DB = OptimizationDatabase
class OptimizationQuery: class OptimizationQuery:
"""An object that specifies a set of optimizations by tag/name."""
def __init__(
self,
include: Sequence[str],
require: Optional[Sequence[str]] = None,
exclude: Optional[Sequence[str]] = None,
subquery: Optional[Dict[str, "OptimizationQuery"]] = None,
position_cutoff: float = math.inf,
extra_optimizations: Optional[Sequence[OptimizersType]] = None,
):
""" """
Parameters Parameters
---------- ----------
position_cutoff : float include:
Used by SequenceDB to keep only optimizer that are positioned before A set of tags such that every optimization obtained through this
the cut_off point. ``OptimizationQuery`` must have **one** of the tags listed. This
field is required and basically acts as a starting point for the
search.
require:
A set of tags such that every optimization obtained through this
``OptimizationQuery`` must have **all** of these tags.
exclude:
A set of tags such that every optimization obtained through this
``OptimizationQuery`` must have **none** of these tags.
subquery:
A dictionary mapping the name of a sub-database to a special
``OptimizationQuery``. If no subquery is given for a sub-database,
the original ``OptimizationQuery`` will be used again.
position_cutoff:
Only optimizations with position less than the cutoff are returned.
extra_optimizations:
Extra optimizations to be added.
""" """
def __init__(
self,
include,
require=None,
exclude=None,
subquery=None,
position_cutoff=math.inf,
extra_optimizations=None,
):
self.include = OrderedSet(include) self.include = OrderedSet(include)
self.require = require or OrderedSet() self.require = require or OrderedSet()
self.exclude = exclude or OrderedSet() self.exclude = exclude or OrderedSet()
...@@ -218,16 +234,11 @@ class OptimizationQuery: ...@@ -218,16 +234,11 @@ class OptimizationQuery:
def __str__(self): def __str__(self):
return ( return (
"OptimizationQuery{inc=%s,ex=%s,require=%s,subquery=%s," "OptimizationQuery("
"position_cutoff=%f,extra_opts=%s}" + f"inc={self.include},ex={self.exclude},"
% ( + f"require={self.require},subquery={self.subquery},"
self.include, + f"position_cutoff={self.position_cutoff},"
self.exclude, + f"extra_opts={self.extra_optimizations})"
self.require,
self.subquery,
self.position_cutoff,
self.extra_optimizations,
)
) )
def __setstate__(self, state): def __setstate__(self, state):
...@@ -314,17 +325,29 @@ class EquilibriumDB(OptimizationDatabase): ...@@ -314,17 +325,29 @@ class EquilibriumDB(OptimizationDatabase):
""" """
def __init__(self, ignore_newtrees=True, tracks_on_change_inputs=False): def __init__(self, ignore_newtrees=True, tracks_on_change_inputs=False):
"""
Parameters
----------
ignore_newtrees:
If False, local optimizations are applied to new nodes introduced while
local optimizations are being applied. This could result in fewer fgraph
iterations, but that doesn't mean it will be faster overall.
tracks_on_change_inputs:
If True, local optimizations are re-applied to nodes whose inputs
changed while local optimizations were being applied. This could
result in fewer fgraph iterations, but that doesn't mean it
will be faster overall.
"""
super().__init__() super().__init__()
self.ignore_newtrees = ignore_newtrees self.ignore_newtrees = ignore_newtrees
self.tracks_on_change_inputs = tracks_on_change_inputs self.tracks_on_change_inputs = tracks_on_change_inputs
self.__final__ = {} self.__final__ = {}
self.__cleanup__ = {} self.__cleanup__ = {}
def register(self, name, obj, *tags, **kwtags): def register(self, name, obj, *tags, final_opt=False, cleanup=False, **kwtags):
final_opt = kwtags.pop("final_opt", False) if final_opt and cleanup:
cleanup = kwtags.pop("cleanup", False) raise ValueError("`final_opt` and `cleanup` cannot both be true.")
# An opt should not be final and clean up
assert not (final_opt and cleanup)
super().register(name, obj, *tags, **kwtags) super().register(name, obj, *tags, **kwtags)
self.__final__[name] = final_opt self.__final__[name] = final_opt
self.__cleanup__[name] = cleanup self.__cleanup__[name] = cleanup
...@@ -350,8 +373,7 @@ class EquilibriumDB(OptimizationDatabase): ...@@ -350,8 +373,7 @@ class EquilibriumDB(OptimizationDatabase):
class SequenceDB(OptimizationDatabase): class SequenceDB(OptimizationDatabase):
""" """A sequence of potential optimizations.
A sequence of potential optimizations.
Retrieve a sequence of optimizations (a SeqOptimizer) by calling query(). Retrieve a sequence of optimizations (a SeqOptimizer) by calling query().
...@@ -371,18 +393,21 @@ class SequenceDB(OptimizationDatabase): ...@@ -371,18 +393,21 @@ class SequenceDB(OptimizationDatabase):
self.__position__ = {} self.__position__ = {}
self.failure_callback = failure_callback self.failure_callback = failure_callback
def register(self, name, obj, position, *tags): def register(self, name, obj, position: Union[str, int, float], *tags, **kwargs):
super().register(name, obj, *tags) super().register(name, obj, *tags, **kwargs)
if position == "last": if position == "last":
if len(self.__position__) == 0: if len(self.__position__) == 0:
self.__position__[name] = 0 self.__position__[name] = 0
else: else:
self.__position__[name] = max(self.__position__.values()) + 1 self.__position__[name] = max(self.__position__.values()) + 1
else: elif isinstance(position, (int, float)):
assert isinstance(position, ((int,), float))
self.__position__[name] = position self.__position__[name] = position
else:
raise TypeError(f"`position` must be numeric; got {position}")
def query(self, *tags, **kwtags): def query(
self, *tags, position_cutoff: Optional[Union[int, float]] = None, **kwtags
):
""" """
Parameters Parameters
...@@ -393,7 +418,9 @@ class SequenceDB(OptimizationDatabase): ...@@ -393,7 +418,9 @@ class SequenceDB(OptimizationDatabase):
""" """
opts = super().query(*tags, **kwtags) opts = super().query(*tags, **kwtags)
position_cutoff = kwtags.pop("position_cutoff", config.optdb__position_cutoff) if position_cutoff is None:
position_cutoff = config.optdb__position_cutoff
position_dict = self.__position__ position_dict = self.__position__
if len(tags) >= 1 and isinstance(tags[0], OptimizationQuery): if len(tags) >= 1 and isinstance(tags[0], OptimizationQuery):
...@@ -421,10 +448,12 @@ class SequenceDB(OptimizationDatabase): ...@@ -421,10 +448,12 @@ class SequenceDB(OptimizationDatabase):
opts = [o for o in opts if position_dict[o.name] < position_cutoff] opts = [o for o in opts if position_dict[o.name] < position_cutoff]
opts.sort(key=lambda obj: (position_dict[obj.name], obj.name)) opts.sort(key=lambda obj: (position_dict[obj.name], obj.name))
kwargs = {}
if self.failure_callback: if self.failure_callback:
kwargs["failure_callback"] = self.failure_callback ret = self.seq_opt(opts, failure_callback=self.failure_callback)
ret = self.seq_opt(opts, **kwargs) else:
ret = self.seq_opt(opts)
if hasattr(tags[0], "name"): if hasattr(tags[0], "name"):
ret.name = tags[0].name ret.name = tags[0].name
return ret return ret
...@@ -436,11 +465,11 @@ class SequenceDB(OptimizationDatabase): ...@@ -436,11 +465,11 @@ class SequenceDB(OptimizationDatabase):
def c(a, b): def c(a, b):
return (a[1] > b[1]) - (a[1] < b[1]) return (a[1] > b[1]) - (a[1] < b[1])
positions.sort(c) positions.sort(key=cmp_to_key(c))
print(" position", positions, file=stream) print("\tposition", positions, file=stream)
print(" names", self._names, file=stream) print("\tnames", self._names, file=stream)
print(" db", self.__db__, file=stream) print("\tdb", self.__db__, file=stream)
def __str__(self): def __str__(self):
sio = StringIO() sio = StringIO()
...@@ -448,7 +477,7 @@ class SequenceDB(OptimizationDatabase): ...@@ -448,7 +477,7 @@ class SequenceDB(OptimizationDatabase):
return sio.getvalue() return sio.getvalue()
class LocalGroupDB(OptimizationDatabase): class LocalGroupDB(SequenceDB):
""" """
Generate a local optimizer of type LocalOptGroup instead Generate a local optimizer of type LocalOptGroup instead
of a global optimizer. of a global optimizer.
...@@ -463,31 +492,17 @@ class LocalGroupDB(OptimizationDatabase): ...@@ -463,31 +492,17 @@ class LocalGroupDB(OptimizationDatabase):
profile: bool = False, profile: bool = False,
local_opt=opt.LocalOptGroup, local_opt=opt.LocalOptGroup,
): ):
super().__init__() super().__init__(failure_callback=None)
self.failure_callback = None
self.apply_all_opts = apply_all_opts self.apply_all_opts = apply_all_opts
self.profile = profile self.profile = profile
self.__position__: Dict = {}
self.local_opt = local_opt self.local_opt = local_opt
self.__name__: str = "" self.__name__: str = ""
def register(self, name, obj, *tags, **kwargs): def register(self, name, obj, *tags, position="last", **kwargs):
super().register(name, obj, *tags) super().register(name, obj, position, *tags, **kwargs)
position = kwargs.pop("position", "last")
if position == "last":
if len(self.__position__) == 0:
self.__position__[name] = 0
else:
self.__position__[name] = max(self.__position__.values()) + 1
else:
assert isinstance(position, ((int,), float))
self.__position__[name] = position
def query(self, *tags, **kwtags): def query(self, *tags, **kwtags):
# For the new `useless` optimizer
opts = list(super().query(*tags, **kwtags)) opts = list(super().query(*tags, **kwtags))
opts.sort(key=lambda obj: (self.__position__[obj.name], obj.name))
ret = self.local_opt( ret = self.local_opt(
*opts, apply_all_opts=self.apply_all_opts, profile=self.profile *opts, apply_all_opts=self.apply_all_opts, profile=self.profile
) )
...@@ -495,11 +510,7 @@ class LocalGroupDB(OptimizationDatabase): ...@@ -495,11 +510,7 @@ class LocalGroupDB(OptimizationDatabase):
class TopoDB(OptimizationDatabase): class TopoDB(OptimizationDatabase):
""" """Generate a `GlobalOptimizer` of type TopoOptimizer."""
Generate a `GlobalOptimizer` of type TopoOptimizer.
"""
def __init__( def __init__(
self, db, order="in_to_out", ignore_newtrees=False, failure_callback=None self, db, order="in_to_out", ignore_newtrees=False, failure_callback=None
...@@ -520,16 +531,17 @@ class TopoDB(OptimizationDatabase): ...@@ -520,16 +531,17 @@ class TopoDB(OptimizationDatabase):
class ProxyDB(OptimizationDatabase): class ProxyDB(OptimizationDatabase):
""" """An object that wraps an existing ``OptimizationDatabase``.
Wrap an existing proxy.
This is needed as we can't register the same DB mutiple times in This is needed because we can't register the same ``OptimizationDatabase``
different positions in a SequentialDB. multiple times in different positions in a ``SequenceDB``.
""" """
def __init__(self, db): def __init__(self, db):
assert isinstance(db, OptimizationDatabase), "" if not isinstance(db, OptimizationDatabase):
raise TypeError("`db` must be an `OptimizationDatabase`.")
self.db = db self.db = db
def query(self, *tags, **kwtags): def query(self, *tags, **kwtags):
......
import pytest import pytest
from aesara.graph.optdb import OptimizationDatabase, opt from aesara.graph import opt
from aesara.graph.optdb import (
EquilibriumDB,
LocalGroupDB,
OptimizationDatabase,
ProxyDB,
SequenceDB,
)
class TestDB: class TestOpt(opt.GlobalOptimizer):
def test_name_clashes(self):
class Opt(opt.GlobalOptimizer): # inheritance buys __hash__
name = "blah" name = "blah"
def apply(self, fgraph): def apply(self, fgraph):
pass pass
class TestDB:
def test_register(self):
db = OptimizationDatabase() db = OptimizationDatabase()
db.register("a", Opt()) db.register("a", TestOpt())
db.register("b", Opt()) db.register("b", TestOpt())
db.register("c", Opt(), "z", "asdf") db.register("c", TestOpt(), "z", "asdf")
assert "a" in db assert "a" in db
assert "b" in db assert "b" in db
assert "c" in db assert "c" in db
with pytest.raises(ValueError, match=r"The name.*"): with pytest.raises(ValueError, match=r"The tag.*"):
db.register("c", Opt()) # name taken db.register("c", TestOpt()) # name taken
with pytest.raises(ValueError, match=r"The name.*"): with pytest.raises(ValueError, match=r"The tag.*"):
db.register("z", Opt()) # name collides with tag db.register("z", TestOpt()) # name collides with tag
with pytest.raises(ValueError, match=r"The tag.*"): with pytest.raises(ValueError, match=r"The tag.*"):
db.register("u", Opt(), "b") # name new but tag collides with name db.register("u", TestOpt(), "b") # name new but tag collides with name
with pytest.raises(TypeError, match=r".* is not a valid.*"):
db.register("d", 1)
def test_EquilibriumDB(self):
eq_db = EquilibriumDB()
with pytest.raises(ValueError, match=r"`final_opt` and.*"):
eq_db.register("d", TestOpt(), final_opt=True, cleanup=True)
def test_SequenceDB(self):
seq_db = SequenceDB(failure_callback=None)
res = seq_db.query("+a")
assert isinstance(res, opt.SeqOptimizer)
assert res.data == []
seq_db.register("b", TestOpt(), 1)
from io import StringIO
out_file = StringIO()
seq_db.print_summary(stream=out_file)
res = out_file.getvalue()
assert str(id(seq_db)) in res
assert "names {'b'}" in res
with pytest.raises(TypeError, match=r"`position` must be.*"):
seq_db.register("c", TestOpt(), object())
def test_LocalGroupDB(self):
lg_db = LocalGroupDB()
lg_db.register("a", TestOpt(), 1)
assert "a" in lg_db.__position__
with pytest.raises(TypeError, match=r"`position` must be.*"):
lg_db.register("b", TestOpt(), position=object())
def test_ProxyDB(self):
with pytest.raises(TypeError, match=r"`db` must be.*"):
ProxyDB(object())
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论