Unverified 提交 5fb56bab authored 作者: Adhvaith Hundi's avatar Adhvaith Hundi 提交者: GitHub

Add overwrite_existing flag (#1119)

* Add 'overwrite_existing' flag to allow graph rewrites and include appropriate testing * Encapsulate test rewriters and use user-facing API --------- Co-authored-by: 's avatarRicardo Vieira <ricardo.vieira1994@gmail.com>
上级 b065112b
...@@ -35,6 +35,7 @@ class RewriteDatabase: ...@@ -35,6 +35,7 @@ class RewriteDatabase:
rewriter: Union["RewriteDatabase", RewritesType], rewriter: Union["RewriteDatabase", RewritesType],
*tags: str, *tags: str,
use_db_name_as_tag=True, use_db_name_as_tag=True,
overwrite_existing=False,
): ):
"""Register a new rewriter to the database. """Register a new rewriter to the database.
...@@ -56,7 +57,8 @@ class RewriteDatabase: ...@@ -56,7 +57,8 @@ class RewriteDatabase:
``local_remove_all_assert``. Setting `use_db_name_as_tag` to ``local_remove_all_assert``. Setting `use_db_name_as_tag` to
``False`` removes that behavior. This means that only the rewrite's name ``False`` removes that behavior. This means that only the rewrite's name
and/or its tags will enable it. and/or its tags will enable it.
overwrite_existing:
Overwrite the existing rewriter with a new one having the same name
""" """
if not isinstance( if not isinstance(
rewriter, rewriter,
...@@ -66,22 +68,27 @@ class RewriteDatabase: ...@@ -66,22 +68,27 @@ class RewriteDatabase:
): ):
raise TypeError(f"{rewriter} is not a valid rewrite type.") raise TypeError(f"{rewriter} is not a valid rewrite type.")
if name in self.__db__:
raise ValueError(f"The tag '{name}' is already present in the database.")
if use_db_name_as_tag: if use_db_name_as_tag:
if self.name is not None: if self.name is not None:
tags = (*tags, self.name) tags = (*tags, self.name)
rewriter.name = name rewriter.name = name
# This restriction is there because in many place we suppose that
# something in the RewriteDatabase is there only once. # if tag collides with name
if rewriter.name in self.__db__: if name in self.__db__ and name not in self._names:
raise ValueError( raise ValueError(f"The tag '{name}' is already present in the database.")
f"Tried to register {rewriter.name} again under the new name {name}. "
"The same rewrite cannot be registered multiple times in" if name in self.__db__ or rewriter.name in self.__db__:
" an `RewriteDatabase`; use `ProxyDB` instead." if overwrite_existing:
) self.remove_tags(name, *tags)
old_rewriter = self.__db__[name].pop()
self._names.remove(name)
self.__db__[old_rewriter.__class__.__name__].remove(old_rewriter)
else:
raise ValueError(
f"The tag '{name}' is already present in the database."
)
self.__db__[name] = OrderedSet([rewriter]) self.__db__[name] = OrderedSet([rewriter])
self._names.add(name) self._names.add(name)
self.__db__[rewriter.__class__.__name__].add(rewriter) self.__db__[rewriter.__class__.__name__].add(rewriter)
......
import pytest import pytest
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import GraphRewriter, SequentialGraphRewriter from pytensor.graph.rewriting.basic import GraphRewriter, SequentialGraphRewriter
from pytensor.graph.rewriting.db import ( from pytensor.graph.rewriting.db import (
EquilibriumDB, EquilibriumDB,
...@@ -17,6 +18,13 @@ class TestRewriter(GraphRewriter): ...@@ -17,6 +18,13 @@ class TestRewriter(GraphRewriter):
pass pass
class NewTestRewriter(GraphRewriter):
name = "bleh"
def apply(self, fgraph):
pass
class TestDB: class TestDB:
def test_register(self): def test_register(self):
db = RewriteDatabase() db = RewriteDatabase()
...@@ -31,7 +39,7 @@ class TestDB: ...@@ -31,7 +39,7 @@ class TestDB:
assert "c" in db assert "c" in db
with pytest.raises(ValueError, match=r"The tag.*"): with pytest.raises(ValueError, match=r"The tag.*"):
db.register("c", TestRewriter()) # name taken db.register("c", NewTestRewriter()) # name taken
with pytest.raises(ValueError, match=r"The tag.*"): with pytest.raises(ValueError, match=r"The tag.*"):
db.register("z", TestRewriter()) # name collides with tag db.register("z", TestRewriter()) # name collides with tag
...@@ -42,6 +50,40 @@ class TestDB: ...@@ -42,6 +50,40 @@ class TestDB:
with pytest.raises(TypeError, match=r".* is not a valid.*"): with pytest.raises(TypeError, match=r".* is not a valid.*"):
db.register("d", 1) db.register("d", 1)
def test_overwrite_existing(self):
class TestOverwrite1(GraphRewriter):
def apply(self, fgraph):
fgraph.counter[0] += 1
class TestOverwrite2(GraphRewriter):
def apply(self, fgraph):
fgraph.counter[1] += 1
db = SequenceDB()
fg = FunctionGraph([], [])
fg.counter = [0, 0]
db.register("a", TestRewriter(), "basic")
rewriter = db.query("+basic")
rewriter.rewrite(fg)
assert fg.counter == [0, 0]
with pytest.raises(ValueError, match=r"The tag.*"):
db.register("a", TestOverwrite1(), "basic")
rewriter = db.query("+basic")
rewriter.rewrite(fg)
assert fg.counter == [0, 0]
db.register("a", TestOverwrite1(), "basic", overwrite_existing=True)
rewriter = db.query("+basic")
rewriter.rewrite(fg)
assert fg.counter == [1, 0]
db.register("a", TestOverwrite2(), "basic", overwrite_existing=True)
rewriter = db.query("+basic")
rewriter.rewrite(fg)
assert fg.counter == [1, 1]
def test_EquilibriumDB(self): def test_EquilibriumDB(self):
eq_db = EquilibriumDB() eq_db = EquilibriumDB()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论