提交 6bedb2b1 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Clean up aesara.graph.optdb docstrings and signatures

上级 c2672deb
import copy
import math
import sys
from functools import cmp_to_key
from io import StringIO
from typing import Dict, Optional
from typing import Dict, Optional, Sequence, Union
from aesara.configdefaults import config
from aesara.graph import opt
......@@ -10,78 +11,75 @@ from aesara.misc.ordered_set import OrderedSet
from aesara.utils import DefaultOrderedDict
OptimizersType = Union[opt.GlobalOptimizer, opt.LocalOptimizer]
class OptimizationDatabase:
"""A class that represents a collection/database of optimizations.
These databases can be used to logically organize sets of
(i.e. ``GlobalOptimizer``s and ``LocalOptimizer``)
These databases are used to logically organize collections of optimizers
(i.e. ``GlobalOptimizer``s and ``LocalOptimizer``s).
"""
def __hash__(self):
if not hasattr(self, "_optimizer_idx"):
self._optimizer_idx = opt._optimizer_idx[0]
opt._optimizer_idx[0] += 1
return self._optimizer_idx
def __init__(self):
    """Create an empty database with no registered optimizers or tags."""
    # Maps each registered name, tag and optimizer class name to the ordered
    # set of optimizers stored under that key.
    self.__db__ = DefaultOrderedDict(OrderedSet)
    # The names (as opposed to tags) under which optimizers were registered.
    self._names = set()
    # This will be reset by `self.register` (via `obj.name` by the thing
    # doing the registering)
    self.name = None
def register(
    self,
    name: str,
    optimizer: "Union[OptimizationDatabase, OptimizersType]",
    *tags: str,
    use_db_name_as_tag=True,
):
    """Register a new optimizer to the database.

    Parameters
    ----------
    name:
        Name of the optimizer.  It must not already be used as a name or
        a tag in this database.
    optimizer:
        The optimizer to register.  Its ``name`` attribute is overwritten
        with `name`.
    tags:
        Tag names that allow the optimizer to be selected.
    use_db_name_as_tag:
        Add the database's name as a tag, so that its name can be used in a
        query.
        By default, all optimizations registered in ``EquilibriumDB`` are
        selected when the ``"EquilibriumDB"`` name is used as a tag. We do
        not want this behavior for some optimizers like
        ``local_remove_all_assert``. Setting `use_db_name_as_tag` to
        ``False`` removes that behavior. This means only the optimizer name
        and the tags specified will enable that optimization.

    Raises
    ------
    TypeError
        If `optimizer` is not an ``OptimizationDatabase``,
        ``GlobalOptimizer`` or ``LocalOptimizer``.
    ValueError
        If `name` is already a key (name or tag) in the database.

    """
    if not isinstance(
        optimizer, (OptimizationDatabase, opt.GlobalOptimizer, opt.LocalOptimizer)
    ):
        raise TypeError(f"{optimizer} is not a valid optimizer type.")

    if name in self.__db__:
        raise ValueError(f"The tag '{name}' is already present in the database.")

    if use_db_name_as_tag and self.name is not None:
        tags = tags + (self.name,)

    # NOTE: the name is overwritten unconditionally, so uniqueness within
    # this database is enforced solely by the `name in self.__db__` check
    # above; re-testing `optimizer.name` after this assignment could never
    # fail.
    optimizer.name = name

    self.__db__[name] = OrderedSet([optimizer])
    self._names.add(name)
    # Also index the optimizer under its class name.
    self.__db__[optimizer.__class__.__name__].add(optimizer)
    self.add_tags(name, *tags)
def add_tags(self, name, *tags):
......@@ -91,7 +89,7 @@ class OptimizationDatabase:
for tag in tags:
if tag in self._names:
raise ValueError(
"The tag of the object collides with a name.", obj, tag
f"The tag '{tag}' for the {obj} collides with an existing name."
)
self.__db__[tag].add(obj)
......@@ -102,13 +100,11 @@ class OptimizationDatabase:
for tag in tags:
if tag in self._names:
raise ValueError(
"The tag of the object collides with a name.", obj, tag
f"The tag '{tag}' for the {obj} collides with an existing name."
)
self.__db__[tag].remove(obj)
def __query__(self, q):
if not isinstance(q, OptimizationQuery):
raise TypeError("Expected a OptimizationQuery.", q)
# The ordered set is needed for deterministic optimization.
variables = OrderedSet()
for tag in q.include:
......@@ -139,10 +135,8 @@ class OptimizationDatabase:
if len(tags) >= 1 and isinstance(tags[0], OptimizationQuery):
if len(tags) > 1 or kwtags:
raise TypeError(
"If the first argument to query is a OptimizationQuery,"
" there should be no other arguments.",
tags,
kwtags,
"If the first argument to query is an `OptimizationQuery`,"
" there should be no other arguments."
)
return self.__query__(tags[0])
include = [tag[1:] for tag in tags if tag.startswith("+")]
......@@ -151,8 +145,7 @@ class OptimizationDatabase:
if len(include) + len(require) + len(exclude) < len(tags):
raise ValueError(
"All tags must start with one of the following"
" characters: '+', '&' or '-'",
tags,
" characters: '+', '&' or '-'"
)
return self.__query__(
OptimizationQuery(
......@@ -177,31 +170,54 @@ class OptimizationDatabase:
print(" names", self._names, file=stream)
print(" db", self.__db__, file=stream)
def __hash__(self):
if not hasattr(self, "_optimizer_idx"):
self._optimizer_idx = opt._optimizer_idx[0]
opt._optimizer_idx[0] += 1
return self._optimizer_idx
# This is deprecated and will be removed.
DB = OptimizationDatabase
class OptimizationQuery:
"""
Parameters
----------
position_cutoff : float
Used by SequenceDB to keep only optimizer that are positioned before
the cut_off point.
"""
"""An object that specifies a set of optimizations by tag/name."""
def __init__(
self,
include,
require=None,
exclude=None,
subquery=None,
position_cutoff=math.inf,
extra_optimizations=None,
include: Sequence[str],
require: Optional[Sequence[str]] = None,
exclude: Optional[Sequence[str]] = None,
subquery: Optional[Dict[str, "OptimizationQuery"]] = None,
position_cutoff: float = math.inf,
extra_optimizations: Optional[Sequence[OptimizersType]] = None,
):
"""
Parameters
==========
include:
A set of tags such that every optimization obtained through this
``OptimizationQuery`` must have **one** of the tags listed. This
field is required and basically acts as a starting point for the
search.
require:
A set of tags such that every optimization obtained through this
``OptimizationQuery`` must have **all** of these tags.
exclude:
A set of tags such that every optimization obtained through this
``OptimizationQuery`` must have **none** of these tags.
subquery:
A dictionary mapping the name of a sub-database to a special
``OptimizationQuery``. If no subquery is given for a sub-database,
the original ``OptimizationQuery`` will be used again.
position_cutoff:
Only optimizations with position less than the cutoff are returned.
extra_optimizations:
Extra optimizations to be added.
"""
self.include = OrderedSet(include)
self.require = require or OrderedSet()
self.exclude = exclude or OrderedSet()
......@@ -218,16 +234,11 @@ class OptimizationQuery:
def __str__(self):
return (
"OptimizationQuery{inc=%s,ex=%s,require=%s,subquery=%s,"
"position_cutoff=%f,extra_opts=%s}"
% (
self.include,
self.exclude,
self.require,
self.subquery,
self.position_cutoff,
self.extra_optimizations,
)
"OptimizationQuery("
+ f"inc={self.include},ex={self.exclude},"
+ f"require={self.require},subquery={self.subquery},"
+ f"position_cutoff={self.position_cutoff},"
+ f"extra_opts={self.extra_optimizations})"
)
def __setstate__(self, state):
......@@ -314,17 +325,29 @@ class EquilibriumDB(OptimizationDatabase):
"""
def __init__(self, ignore_newtrees=True, tracks_on_change_inputs=False):
    """
    Parameters
    ----------
    ignore_newtrees:
        If ``False``, we will apply local opt on new nodes introduced during
        local optimization application.  This could result in fewer fgraph
        iterations, but this doesn't mean it will be faster globally.
    tracks_on_change_inputs:
        If ``True``, we will re-apply local opt on nodes whose inputs
        changed during local optimization application.  This could
        result in fewer fgraph iterations, but this doesn't mean it
        will be faster globally.

    """
    super().__init__()
    self.ignore_newtrees = ignore_newtrees
    self.tracks_on_change_inputs = tracks_on_change_inputs
    # Per-name `final_opt` / `cleanup` flags, filled in by `register`.
    self.__final__ = {}
    self.__cleanup__ = {}
def register(self, name, obj, *tags, **kwtags):
final_opt = kwtags.pop("final_opt", False)
cleanup = kwtags.pop("cleanup", False)
# An opt should not be final and clean up
assert not (final_opt and cleanup)
def register(self, name, obj, *tags, final_opt=False, cleanup=False, **kwtags):
if final_opt and cleanup:
raise ValueError("`final_opt` and `cleanup` cannot both be true.")
super().register(name, obj, *tags, **kwtags)
self.__final__[name] = final_opt
self.__cleanup__[name] = cleanup
......@@ -350,8 +373,7 @@ class EquilibriumDB(OptimizationDatabase):
class SequenceDB(OptimizationDatabase):
"""
A sequence of potential optimizations.
"""A sequence of potential optimizations.
Retrieve a sequence of optimizations (a SeqOptimizer) by calling query().
......@@ -371,18 +393,21 @@ class SequenceDB(OptimizationDatabase):
self.__position__ = {}
self.failure_callback = failure_callback
def register(self, name, obj, position: Union[str, int, float], *tags, **kwargs):
    """Register `obj` under `name` and record its position in the sequence.

    Parameters
    ----------
    name:
        Name under which `obj` is registered.
    obj:
        The optimizer to register.
    position:
        Either a number, or the string ``"last"``.  ``"last"`` assigns a
        position one greater than the largest position recorded so far
        (or ``0`` when nothing has been registered yet).
    tags:
        Tags forwarded to the parent class's ``register``.
    kwargs:
        Keyword arguments forwarded to the parent class's ``register``.

    Raises
    ------
    TypeError
        If `position` is neither ``"last"`` nor an ``int``/``float``.

    """
    place_last = position == "last"

    # Validate before touching the database so that a bad `position` cannot
    # leave `obj` registered without a corresponding position entry.
    if not place_last and not isinstance(position, (int, float)):
        raise TypeError(f"`position` must be numeric; got {position}")

    super().register(name, obj, *tags, **kwargs)

    if place_last:
        # `default=-1` yields position 0 for the very first entry.
        self.__position__[name] = max(self.__position__.values(), default=-1) + 1
    else:
        self.__position__[name] = position
def query(self, *tags, **kwtags):
def query(
self, *tags, position_cutoff: Optional[Union[int, float]] = None, **kwtags
):
"""
Parameters
......@@ -393,7 +418,9 @@ class SequenceDB(OptimizationDatabase):
"""
opts = super().query(*tags, **kwtags)
position_cutoff = kwtags.pop("position_cutoff", config.optdb__position_cutoff)
if position_cutoff is None:
position_cutoff = config.optdb__position_cutoff
position_dict = self.__position__
if len(tags) >= 1 and isinstance(tags[0], OptimizationQuery):
......@@ -421,10 +448,12 @@ class SequenceDB(OptimizationDatabase):
opts = [o for o in opts if position_dict[o.name] < position_cutoff]
opts.sort(key=lambda obj: (position_dict[obj.name], obj.name))
kwargs = {}
if self.failure_callback:
kwargs["failure_callback"] = self.failure_callback
ret = self.seq_opt(opts, **kwargs)
ret = self.seq_opt(opts, failure_callback=self.failure_callback)
else:
ret = self.seq_opt(opts)
if hasattr(tags[0], "name"):
ret.name = tags[0].name
return ret
......@@ -436,11 +465,11 @@ class SequenceDB(OptimizationDatabase):
def c(a, b):
return (a[1] > b[1]) - (a[1] < b[1])
positions.sort(c)
positions.sort(key=cmp_to_key(c))
print(" position", positions, file=stream)
print(" names", self._names, file=stream)
print(" db", self.__db__, file=stream)
print("\tposition", positions, file=stream)
print("\tnames", self._names, file=stream)
print("\tdb", self.__db__, file=stream)
def __str__(self):
sio = StringIO()
......@@ -448,7 +477,7 @@ class SequenceDB(OptimizationDatabase):
return sio.getvalue()
class LocalGroupDB(OptimizationDatabase):
class LocalGroupDB(SequenceDB):
"""
Generate a local optimizer of type LocalOptGroup instead
of a global optimizer.
......@@ -463,31 +492,17 @@ class LocalGroupDB(OptimizationDatabase):
profile: bool = False,
local_opt=opt.LocalOptGroup,
):
super().__init__()
self.failure_callback = None
super().__init__(failure_callback=None)
self.apply_all_opts = apply_all_opts
self.profile = profile
self.__position__: Dict = {}
self.local_opt = local_opt
self.__name__: str = ""
def register(self, name, obj, *tags, position="last", **kwargs):
    """Register `obj` under `name`, with `position` given by keyword.

    This only adapts the call signature: `position` is keyword-only here
    (defaulting to ``"last"``) and is passed to the parent class's
    ``register`` as its third positional argument, ahead of `tags`.
    Any remaining keyword arguments are forwarded unchanged.
    """
    super().register(name, obj, position, *tags, **kwargs)
def query(self, *tags, **kwtags):
# For the new `useless` optimizer
opts = list(super().query(*tags, **kwtags))
opts.sort(key=lambda obj: (self.__position__[obj.name], obj.name))
ret = self.local_opt(
*opts, apply_all_opts=self.apply_all_opts, profile=self.profile
)
......@@ -495,11 +510,7 @@ class LocalGroupDB(OptimizationDatabase):
class TopoDB(OptimizationDatabase):
"""
Generate a `GlobalOptimizer` of type TopoOptimizer.
"""
"""Generate a `GlobalOptimizer` of type TopoOptimizer."""
def __init__(
self, db, order="in_to_out", ignore_newtrees=False, failure_callback=None
......@@ -520,16 +531,17 @@ class TopoDB(OptimizationDatabase):
class ProxyDB(OptimizationDatabase):
"""
Wrap an existing proxy.
"""A object that wraps an existing ``OptimizationDatabase``.
This is needed as we can't register the same DB mutiple times in
different positions in a SequentialDB.
This is needed because we can't register the same ``OptimizationDatabase``
multiple times in different positions in a ``SequenceDB``.
"""
def __init__(self, db):
    """Store the database that this proxy wraps.

    Parameters
    ----------
    db:
        The ``OptimizationDatabase`` to wrap.

    Raises
    ------
    TypeError
        If `db` is not an ``OptimizationDatabase``.

    """
    # An explicit raise (rather than `assert`) keeps the check active when
    # Python runs with assertions disabled.
    if not isinstance(db, OptimizationDatabase):
        raise TypeError("`db` must be an `OptimizationDatabase`.")
    self.db = db
def query(self, *tags, **kwtags):
......
import pytest
from aesara.graph.optdb import OptimizationDatabase, opt
from aesara.graph import opt
from aesara.graph.optdb import (
EquilibriumDB,
LocalGroupDB,
OptimizationDatabase,
ProxyDB,
SequenceDB,
)
class TestOpt(opt.GlobalOptimizer):
    """A do-nothing optimizer used as a registration fixture."""

    name = "blah"

    def apply(self, fgraph):
        pass


class TestDB:
    """Tests for registration and validation in the optimization databases."""

    def test_register(self):
        db = OptimizationDatabase()
        db.register("a", TestOpt())

        db.register("b", TestOpt())

        db.register("c", TestOpt(), "z", "asdf")

        assert "a" in db
        assert "b" in db
        assert "c" in db

        with pytest.raises(ValueError, match=r"The tag.*"):
            db.register("c", TestOpt())  # name taken

        with pytest.raises(ValueError, match=r"The tag.*"):
            db.register("z", TestOpt())  # name collides with tag

        with pytest.raises(ValueError, match=r"The tag.*"):
            db.register("u", TestOpt(), "b")  # name new but tag collides with name

        with pytest.raises(TypeError, match=r".* is not a valid.*"):
            db.register("d", 1)

    def test_EquilibriumDB(self):
        eq_db = EquilibriumDB()

        with pytest.raises(ValueError, match=r"`final_opt` and.*"):
            eq_db.register("d", TestOpt(), final_opt=True, cleanup=True)

    def test_SequenceDB(self):
        seq_db = SequenceDB(failure_callback=None)

        res = seq_db.query("+a")

        assert isinstance(res, opt.SeqOptimizer)
        assert res.data == []

        seq_db.register("b", TestOpt(), 1)

        from io import StringIO

        out_file = StringIO()
        seq_db.print_summary(stream=out_file)

        res = out_file.getvalue()

        assert str(id(seq_db)) in res
        assert "names {'b'}" in res

        with pytest.raises(TypeError, match=r"`position` must be.*"):
            seq_db.register("c", TestOpt(), object())

    def test_LocalGroupDB(self):
        lg_db = LocalGroupDB()

        # `position` is keyword-only for `LocalGroupDB.register`; a bare
        # positional value here would be treated as a tag instead.
        lg_db.register("a", TestOpt(), position=1)

        assert "a" in lg_db.__position__
        assert lg_db.__position__["a"] == 1

        with pytest.raises(TypeError, match=r"`position` must be.*"):
            lg_db.register("b", TestOpt(), position=object())

    def test_ProxyDB(self):
        with pytest.raises(TypeError, match=r"`db` must be.*"):
            ProxyDB(object())
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论