Unverified 提交 5fb56bab authored 作者: Adhvaith Hundi's avatar Adhvaith Hundi 提交者: GitHub

Add overwrite_existing flag (#1119)

* Add 'overwrite_existing' flag to allow graph rewrites and include appropriate testing * Encapsulate test rewriters and use user-facing API --------- Co-authored-by: 's avatarRicardo Vieira <ricardo.vieira1994@gmail.com>
上级 b065112b
......@@ -35,6 +35,7 @@ class RewriteDatabase:
rewriter: Union["RewriteDatabase", RewritesType],
*tags: str,
use_db_name_as_tag=True,
overwrite_existing=False,
):
"""Register a new rewriter to the database.
......@@ -56,7 +57,8 @@ class RewriteDatabase:
``local_remove_all_assert``. Setting `use_db_name_as_tag` to
``False`` removes that behavior. This means that only the rewrite's name
and/or its tags will enable it.
overwrite_existing:
Overwrite the existing rewriter with a new one having the same name
"""
if not isinstance(
rewriter,
......@@ -66,22 +68,27 @@ class RewriteDatabase:
):
raise TypeError(f"{rewriter} is not a valid rewrite type.")
if name in self.__db__:
raise ValueError(f"The tag '{name}' is already present in the database.")
if use_db_name_as_tag:
if self.name is not None:
tags = (*tags, self.name)
rewriter.name = name
# This restriction is there because in many place we suppose that
# something in the RewriteDatabase is there only once.
if rewriter.name in self.__db__:
raise ValueError(
f"Tried to register {rewriter.name} again under the new name {name}. "
"The same rewrite cannot be registered multiple times in"
" an `RewriteDatabase`; use `ProxyDB` instead."
)
# if tag collides with name
if name in self.__db__ and name not in self._names:
raise ValueError(f"The tag '{name}' is already present in the database.")
if name in self.__db__ or rewriter.name in self.__db__:
if overwrite_existing:
self.remove_tags(name, *tags)
old_rewriter = self.__db__[name].pop()
self._names.remove(name)
self.__db__[old_rewriter.__class__.__name__].remove(old_rewriter)
else:
raise ValueError(
f"The tag '{name}' is already present in the database."
)
self.__db__[name] = OrderedSet([rewriter])
self._names.add(name)
self.__db__[rewriter.__class__.__name__].add(rewriter)
......
import pytest
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import GraphRewriter, SequentialGraphRewriter
from pytensor.graph.rewriting.db import (
EquilibriumDB,
......@@ -17,6 +18,13 @@ class TestRewriter(GraphRewriter):
pass
class NewTestRewriter(GraphRewriter):
name = "bleh"
def apply(self, fgraph):
pass
class TestDB:
def test_register(self):
db = RewriteDatabase()
......@@ -31,7 +39,7 @@ class TestDB:
assert "c" in db
with pytest.raises(ValueError, match=r"The tag.*"):
db.register("c", TestRewriter()) # name taken
db.register("c", NewTestRewriter()) # name taken
with pytest.raises(ValueError, match=r"The tag.*"):
db.register("z", TestRewriter()) # name collides with tag
......@@ -42,6 +50,40 @@ class TestDB:
with pytest.raises(TypeError, match=r".* is not a valid.*"):
db.register("d", 1)
def test_overwrite_existing(self):
class TestOverwrite1(GraphRewriter):
def apply(self, fgraph):
fgraph.counter[0] += 1
class TestOverwrite2(GraphRewriter):
def apply(self, fgraph):
fgraph.counter[1] += 1
db = SequenceDB()
fg = FunctionGraph([], [])
fg.counter = [0, 0]
db.register("a", TestRewriter(), "basic")
rewriter = db.query("+basic")
rewriter.rewrite(fg)
assert fg.counter == [0, 0]
with pytest.raises(ValueError, match=r"The tag.*"):
db.register("a", TestOverwrite1(), "basic")
rewriter = db.query("+basic")
rewriter.rewrite(fg)
assert fg.counter == [0, 0]
db.register("a", TestOverwrite1(), "basic", overwrite_existing=True)
rewriter = db.query("+basic")
rewriter.rewrite(fg)
assert fg.counter == [1, 0]
db.register("a", TestOverwrite2(), "basic", overwrite_existing=True)
rewriter = db.query("+basic")
rewriter.rewrite(fg)
assert fg.counter == [1, 1]
def test_EquilibriumDB(self):
eq_db = EquilibriumDB()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论