提交 befc177d authored 作者: Maxim Kochurov's avatar Maxim Kochurov 提交者: Ricardo Vieira

add graph_replace function

上级 1b673567
......@@ -74,7 +74,7 @@ __api_version__ = 1
# isort: off
from pytensor.graph.basic import Variable
from pytensor.graph.replace import clone_replace
from pytensor.graph.replace import clone_replace, graph_replace
# isort: on
......
......@@ -9,7 +9,7 @@ from pytensor.graph.basic import (
clone,
ancestors,
)
from pytensor.graph.replace import clone_replace
from pytensor.graph.replace import clone_replace, graph_replace
from pytensor.graph.op import Op
from pytensor.graph.type import Type
from pytensor.graph.fg import FunctionGraph
......
from functools import partial
from typing import (
Collection,
Dict,
......@@ -10,7 +11,8 @@ from typing import (
cast,
)
from pytensor.graph.basic import Constant, Variable
from pytensor.graph.basic import Constant, Variable, truncated_graph_inputs
from pytensor.graph.fg import FunctionGraph
def clone_replace(
......@@ -58,3 +60,92 @@ def clone_replace(
_, outs, _ = rebuild_collect_shared(_outs, [], new_replace, [], **rebuild_kwds)
return cast(List[Variable], outs)
def graph_replace(
outputs: Sequence[Variable],
replace: Dict[Variable, Variable],
*,
strict=True,
) -> List[Variable]:
"""Replace variables in ``outputs`` by ``replace``.
Parameters
----------
outputs: Sequence[Variable]
Output graph
replace: Dict[Variable, Variable]
Replace mapping
strict: bool
Raise an error if some replacements were not used
return_unused: bool
Return replacements that were not used
Returns
-------
List[Variable]
Output graph with subgraphs replaced
Raises
------
ValueError
If some replacemens could not be applied and strict is True
"""
# collect minimum graph inputs which is required to compute outputs
# and depend on replacements
# additionally remove constants, they do not matter in clone get equiv
conditions = [
c
for c in truncated_graph_inputs(outputs, replace)
if not isinstance(c, Constant)
]
# for the function graph we need the clean graph where
# inputs do not have owners
# this is exactly the reason to clone conditions
equiv = {c: c.clone(name=f"i-{i}") for i, c in enumerate(conditions)}
# some replace keys may dissapear
# the reason is they are outside the graph
# clone the graph but preserve the equiv mapping
fg = FunctionGraph(
conditions,
outputs,
# clone_get_equiv kwargs
copy_orphans=False,
copy_inputs=False,
memo=equiv,
)
# replace the conditions back
fg_replace = {equiv[c]: c for c in conditions}
# add the replacements on top of input mappings
fg_replace.update({equiv[r]: v for r, v in replace.items() if r in equiv})
# replacements have to be done in reverse topological order so that nested
# expressions get recursively replaced correctly
# some replacements may be initially outside the graph
# but later introduced by a replacement
# So far FunctionGraph does these replacements inplace it is thus unsafe
# apply them using fg.replace, it may change the original graph
if strict:
non_fg_replace = {r: v for r, v in replace.items() if r not in equiv}
if non_fg_replace:
raise ValueError(f"Some replacements were not used: {non_fg_replace}")
toposort = fg.toposort()
def toposort_key(fg: FunctionGraph, ts, pair):
key, _ = pair
if key.owner is not None:
return ts.index(key.owner)
else:
if key in fg.variables:
return -1
else:
raise ValueError(f"{key} is not a part of graph")
sorted_replacements = sorted(
tuple(fg_replace.items()),
# sort based on the fg toposort, if a variable has no owner, it goes first
key=partial(toposort_key, fg, toposort),
reverse=True,
)
fg.replace_all(sorted_replacements, import_missing=True)
return list(fg.outputs)
......@@ -4,7 +4,7 @@ import pytest
import pytensor.tensor as pt
from pytensor import config, function, shared
from pytensor.graph.basic import graph_inputs
from pytensor.graph.replace import clone_replace
from pytensor.graph.replace import clone_replace, graph_replace
from pytensor.tensor import dvector, fvector, vector
from tests import unittest_tools as utt
from tests.graph.utils import MyOp, MyVariable
......@@ -133,3 +133,82 @@ class TestCloneReplace:
utt.assert_allclose(
test(x, pt.sum((x + 1) ** 2), mention_y=True), 1.21000003815
)
class TestGraphReplace:
def test_graph_replace(self):
x = MyVariable("x")
y = MyVariable("y")
z = MyVariable("z")
w = MyVariable("w")
MyOp("zop")(z)
x2 = MyOp("xop")(x, w)
x2.name = "x2"
y2 = MyOp("yop")(y)
y2.name = "y2"
yc = graph_replace([x2], {x: y2})[0]
assert yc.owner.inputs[0] is y2
# the old reference is kept
assert yc.owner.inputs[1] is w
# test replace itself
yc = graph_replace([x2], {x2: y2})[0]
assert yc is y2
assert yc.owner.inputs[0] is y
assert len(yc.owner.inputs) == 1
# the case where inputs have to be replaced in reverse topological order
o = MyOp("xyop")(x2, y2)
new_x = x.clone(name="x_new")
new_y2 = y2.clone(name="y2_new")
oc = graph_replace([o], {x: new_x, y2: new_y2})[0]
assert oc.owner.inputs[1] is new_y2
assert oc.owner.inputs[0].owner.inputs[0] is new_x
# the old reference is still kept
assert oc.owner.inputs[0].owner.inputs[1] is w
def test_graph_replace_advanced(self):
x = MyVariable("x")
y = MyVariable("y")
z = MyVariable("z")
w = MyVariable("w")
z2 = MyOp("zop")(z)
x2 = MyOp("xop")(x, w)
x2.name = "x2"
y2 = MyOp("yop")(y)
y2.name = "y2"
o = MyOp("xyop")(x2, y2)
new_x = x.clone(name="x_new")
new_y2 = y2.clone(name="y2_new")
new_y21 = MyOp("ny2op")(new_y2)
# now yet another replacement that could only appear after new_y2: z
# show we can do that after the prev clone
# the case where new variable is referenced during the replacements
new_y21 = MyOp("ny2op")(new_y2)
# the reference new_y2: z2 is not a part of the original graph so the replacement is unsafe
oc = graph_replace([o], {x: new_x, y2: new_y21})
oc = graph_replace(oc, {new_y2: z2})[0]
assert oc.owner.inputs[1].owner.inputs[0] is z2
assert oc.owner.inputs[0].owner.inputs[0] is new_x
# the old reference is still kept
assert oc.owner.inputs[0].owner.inputs[1] is w
new_z = z.clone(name="z_new")
oc = graph_replace([oc], {z: new_z})[0]
# new reference appear
assert oc.owner.inputs[1].owner.inputs[0] is not z2
assert oc.owner.inputs[1].owner.inputs[0].owner.inputs[0] is new_z
# the old reference is still kept
assert oc.owner.inputs[0].owner.inputs[0] is new_x
assert oc.owner.inputs[0].owner.inputs[1] is w
def test_graph_replace_disconnected(self):
x = MyVariable("x")
fake = MyOp("fake")(x)
o = MyOp("o")(x)
oc = graph_replace([o], {fake: x.clone()}, strict=False)
assert oc[0] is o
with pytest.raises(ValueError, match="Some replacements were not used"):
oc = graph_replace([o], {fake: x.clone()}, strict=True)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论