提交 ba5336f6 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Graph replace: Handle redundant replacements

上级 5946b4de
import warnings import warnings
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from functools import partial, singledispatch from functools import singledispatch
from typing import cast, overload from typing import cast, overload
from pytensor.graph.basic import ( from pytensor.graph.basic import (
...@@ -169,6 +169,9 @@ def graph_replace( ...@@ -169,6 +169,9 @@ def graph_replace(
fg_replace = {equiv[c]: c for c in conditions} fg_replace = {equiv[c]: c for c in conditions}
# add the replacements on top of input mappings # add the replacements on top of input mappings
fg_replace.update({equiv[r]: v for r, v in replace_dict.items() if r in equiv}) fg_replace.update({equiv[r]: v for r, v in replace_dict.items() if r in equiv})
# Filter out replacements whose keys are not in the FunctionGraph
# This can happen when a replacement makes an ancestor replacement redundant
fg_replace = {k: v for k, v in fg_replace.items() if k in fg.variables}
# replacements have to be done in reverse topological order so that nested # replacements have to be done in reverse topological order so that nested
# expressions get recursively replaced correctly # expressions get recursively replaced correctly
...@@ -183,11 +186,13 @@ def graph_replace( ...@@ -183,11 +186,13 @@ def graph_replace(
toposort = fg.toposort() toposort = fg.toposort()
def toposort_key( def toposort_key(
fg: FunctionGraph, ts: list[Apply], pair: tuple[Variable, Variable] pair: tuple[Variable, Variable],
toposort=toposort,
fg=fg,
) -> int: ) -> int:
key, _ = pair key, _ = pair
if key.owner is not None: if (node := key.owner) is not None:
return ts.index(key.owner) return toposort.index(node) # type: ignore[no-any-return]
else: else:
if key in fg.variables: if key in fg.variables:
return -1 return -1
...@@ -197,7 +202,7 @@ def graph_replace( ...@@ -197,7 +202,7 @@ def graph_replace(
sorted_replacements = sorted( sorted_replacements = sorted(
fg_replace.items(), fg_replace.items(),
# sort based on the fg toposort, if a variable has no owner, it goes first # sort based on the fg toposort, if a variable has no owner, it goes first
key=partial(toposort_key, fg, toposort), key=toposort_key,
reverse=True, reverse=True,
) )
fg.replace_all(sorted_replacements, import_missing=True) fg.replace_all(sorted_replacements, import_missing=True)
......
...@@ -15,6 +15,7 @@ from pytensor.graph.traversal import graph_inputs ...@@ -15,6 +15,7 @@ from pytensor.graph.traversal import graph_inputs
from pytensor.tensor import dvector, fvector, vector from pytensor.tensor import dvector, fvector, vector
from tests import unittest_tools as utt from tests import unittest_tools as utt
from tests.graph.utils import MyOp, MyVariable, op_multiple_outputs from tests.graph.utils import MyOp, MyVariable, op_multiple_outputs
from tests.unittest_tools import assert_equal_computations
class TestCloneReplace: class TestCloneReplace:
...@@ -233,6 +234,24 @@ class TestGraphReplace: ...@@ -233,6 +234,24 @@ class TestGraphReplace:
with pytest.raises(ValueError, match="Some replacements were not used"): with pytest.raises(ValueError, match="Some replacements were not used"):
graph_replace([out], {fake: x.clone()}, strict=True) graph_replace([out], {fake: x.clone()}, strict=True)
def test_replace_var_and_ancestor(self):
"""Replacing both a variable and its ancestor should not crash.
When x depends on a and y only depends on a through x,
replacing both x and a should work: x->xx makes a->aa a no-op.
"""
op = MyOp("op")
a = MyVariable("a")
x = op(a) # x depends on a
y = op(x) # y depends on x (and transitively on a)
new_a = MyVariable("new_a")
new_x = MyVariable("new_x")
[new_y] = graph_replace([y], {a: new_a, x: new_x})
assert new_y.owner.inputs[0] is new_x
assert_equal_computations([new_y], [op(new_x)])
class TestVectorizeGraph: class TestVectorizeGraph:
def test_basic(self): def test_basic(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论