提交 ba5336f6 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Graph replace: Handle redundant replacements

上级 5946b4de
import warnings
from collections.abc import Iterable, Mapping, Sequence
from functools import partial, singledispatch
from functools import singledispatch
from typing import cast, overload
from pytensor.graph.basic import (
......@@ -169,6 +169,9 @@ def graph_replace(
fg_replace = {equiv[c]: c for c in conditions}
# add the replacements on top of input mappings
fg_replace.update({equiv[r]: v for r, v in replace_dict.items() if r in equiv})
# Filter out replacements whose keys are not in the FunctionGraph
# This can happen when a replacement makes an ancestor replacement redundant
fg_replace = {k: v for k, v in fg_replace.items() if k in fg.variables}
# replacements have to be done in reverse topological order so that nested
# expressions get recursively replaced correctly
......@@ -183,11 +186,13 @@ def graph_replace(
toposort = fg.toposort()
def toposort_key(
fg: FunctionGraph, ts: list[Apply], pair: tuple[Variable, Variable]
pair: tuple[Variable, Variable],
toposort=toposort,
fg=fg,
) -> int:
key, _ = pair
if key.owner is not None:
return ts.index(key.owner)
if (node := key.owner) is not None:
return toposort.index(node) # type: ignore[no-any-return]
else:
if key in fg.variables:
return -1
......@@ -197,7 +202,7 @@ def graph_replace(
sorted_replacements = sorted(
fg_replace.items(),
# sort based on the fg toposort, if a variable has no owner, it goes first
key=partial(toposort_key, fg, toposort),
key=toposort_key,
reverse=True,
)
fg.replace_all(sorted_replacements, import_missing=True)
......
......@@ -15,6 +15,7 @@ from pytensor.graph.traversal import graph_inputs
from pytensor.tensor import dvector, fvector, vector
from tests import unittest_tools as utt
from tests.graph.utils import MyOp, MyVariable, op_multiple_outputs
from tests.unittest_tools import assert_equal_computations
class TestCloneReplace:
......@@ -233,6 +234,24 @@ class TestGraphReplace:
with pytest.raises(ValueError, match="Some replacements were not used"):
graph_replace([out], {fake: x.clone()}, strict=True)
def test_replace_var_and_ancestor(self):
"""Replacing both a variable and its ancestor should not crash.
When x depends on a and y only depends on a through x,
replacing both x and a should work: x->xx makes a->aa a no-op.
"""
op = MyOp("op")
a = MyVariable("a")
x = op(a) # x depends on a
y = op(x) # y depends on x (and transitively on a)
new_a = MyVariable("new_a")
new_x = MyVariable("new_x")
[new_y] = graph_replace([y], {a: new_a, x: new_x})
assert new_y.owner.inputs[0] is new_x
assert_equal_computations([new_y], [op(new_x)])
class TestVectorizeGraph:
def test_basic(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论