提交 befc177d authored 作者: Maxim Kochurov's avatar Maxim Kochurov 提交者: Ricardo Vieira

add graph_replace function

上级 1b673567
...@@ -74,7 +74,7 @@ __api_version__ = 1 ...@@ -74,7 +74,7 @@ __api_version__ = 1
# isort: off # isort: off
from pytensor.graph.basic import Variable from pytensor.graph.basic import Variable
from pytensor.graph.replace import clone_replace from pytensor.graph.replace import clone_replace, graph_replace
# isort: on # isort: on
......
...@@ -9,7 +9,7 @@ from pytensor.graph.basic import ( ...@@ -9,7 +9,7 @@ from pytensor.graph.basic import (
clone, clone,
ancestors, ancestors,
) )
from pytensor.graph.replace import clone_replace from pytensor.graph.replace import clone_replace, graph_replace
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.type import Type from pytensor.graph.type import Type
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
......
from functools import partial
from typing import ( from typing import (
Collection, Collection,
Dict, Dict,
...@@ -10,7 +11,8 @@ from typing import ( ...@@ -10,7 +11,8 @@ from typing import (
cast, cast,
) )
from pytensor.graph.basic import Constant, Variable from pytensor.graph.basic import Constant, Variable, truncated_graph_inputs
from pytensor.graph.fg import FunctionGraph
def clone_replace( def clone_replace(
...@@ -58,3 +60,92 @@ def clone_replace( ...@@ -58,3 +60,92 @@ def clone_replace(
_, outs, _ = rebuild_collect_shared(_outs, [], new_replace, [], **rebuild_kwds) _, outs, _ = rebuild_collect_shared(_outs, [], new_replace, [], **rebuild_kwds)
return cast(List[Variable], outs) return cast(List[Variable], outs)
def graph_replace(
outputs: Sequence[Variable],
replace: Dict[Variable, Variable],
*,
strict=True,
) -> List[Variable]:
"""Replace variables in ``outputs`` by ``replace``.
Parameters
----------
outputs: Sequence[Variable]
Output graph
replace: Dict[Variable, Variable]
Replace mapping
strict: bool
Raise an error if some replacements were not used
return_unused: bool
Return replacements that were not used
Returns
-------
List[Variable]
Output graph with subgraphs replaced
Raises
------
ValueError
If some replacemens could not be applied and strict is True
"""
# collect minimum graph inputs which is required to compute outputs
# and depend on replacements
# additionally remove constants, they do not matter in clone get equiv
conditions = [
c
for c in truncated_graph_inputs(outputs, replace)
if not isinstance(c, Constant)
]
# for the function graph we need the clean graph where
# inputs do not have owners
# this is exactly the reason to clone conditions
equiv = {c: c.clone(name=f"i-{i}") for i, c in enumerate(conditions)}
# some replace keys may dissapear
# the reason is they are outside the graph
# clone the graph but preserve the equiv mapping
fg = FunctionGraph(
conditions,
outputs,
# clone_get_equiv kwargs
copy_orphans=False,
copy_inputs=False,
memo=equiv,
)
# replace the conditions back
fg_replace = {equiv[c]: c for c in conditions}
# add the replacements on top of input mappings
fg_replace.update({equiv[r]: v for r, v in replace.items() if r in equiv})
# replacements have to be done in reverse topological order so that nested
# expressions get recursively replaced correctly
# some replacements may be initially outside the graph
# but later introduced by a replacement
# So far FunctionGraph does these replacements inplace it is thus unsafe
# apply them using fg.replace, it may change the original graph
if strict:
non_fg_replace = {r: v for r, v in replace.items() if r not in equiv}
if non_fg_replace:
raise ValueError(f"Some replacements were not used: {non_fg_replace}")
toposort = fg.toposort()
def toposort_key(fg: FunctionGraph, ts, pair):
key, _ = pair
if key.owner is not None:
return ts.index(key.owner)
else:
if key in fg.variables:
return -1
else:
raise ValueError(f"{key} is not a part of graph")
sorted_replacements = sorted(
tuple(fg_replace.items()),
# sort based on the fg toposort, if a variable has no owner, it goes first
key=partial(toposort_key, fg, toposort),
reverse=True,
)
fg.replace_all(sorted_replacements, import_missing=True)
return list(fg.outputs)
...@@ -4,7 +4,7 @@ import pytest ...@@ -4,7 +4,7 @@ import pytest
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import config, function, shared from pytensor import config, function, shared
from pytensor.graph.basic import graph_inputs from pytensor.graph.basic import graph_inputs
from pytensor.graph.replace import clone_replace from pytensor.graph.replace import clone_replace, graph_replace
from pytensor.tensor import dvector, fvector, vector from pytensor.tensor import dvector, fvector, vector
from tests import unittest_tools as utt from tests import unittest_tools as utt
from tests.graph.utils import MyOp, MyVariable from tests.graph.utils import MyOp, MyVariable
...@@ -133,3 +133,82 @@ class TestCloneReplace: ...@@ -133,3 +133,82 @@ class TestCloneReplace:
utt.assert_allclose( utt.assert_allclose(
test(x, pt.sum((x + 1) ** 2), mention_y=True), 1.21000003815 test(x, pt.sum((x + 1) ** 2), mention_y=True), 1.21000003815
) )
class TestGraphReplace:
def test_graph_replace(self):
x = MyVariable("x")
y = MyVariable("y")
z = MyVariable("z")
w = MyVariable("w")
MyOp("zop")(z)
x2 = MyOp("xop")(x, w)
x2.name = "x2"
y2 = MyOp("yop")(y)
y2.name = "y2"
yc = graph_replace([x2], {x: y2})[0]
assert yc.owner.inputs[0] is y2
# the old reference is kept
assert yc.owner.inputs[1] is w
# test replace itself
yc = graph_replace([x2], {x2: y2})[0]
assert yc is y2
assert yc.owner.inputs[0] is y
assert len(yc.owner.inputs) == 1
# the case where inputs have to be replaced in reverse topological order
o = MyOp("xyop")(x2, y2)
new_x = x.clone(name="x_new")
new_y2 = y2.clone(name="y2_new")
oc = graph_replace([o], {x: new_x, y2: new_y2})[0]
assert oc.owner.inputs[1] is new_y2
assert oc.owner.inputs[0].owner.inputs[0] is new_x
# the old reference is still kept
assert oc.owner.inputs[0].owner.inputs[1] is w
def test_graph_replace_advanced(self):
x = MyVariable("x")
y = MyVariable("y")
z = MyVariable("z")
w = MyVariable("w")
z2 = MyOp("zop")(z)
x2 = MyOp("xop")(x, w)
x2.name = "x2"
y2 = MyOp("yop")(y)
y2.name = "y2"
o = MyOp("xyop")(x2, y2)
new_x = x.clone(name="x_new")
new_y2 = y2.clone(name="y2_new")
new_y21 = MyOp("ny2op")(new_y2)
# now yet another replacement that could only appear after new_y2: z
# show we can do that after the prev clone
# the case where new variable is referenced during the replacements
new_y21 = MyOp("ny2op")(new_y2)
# the reference new_y2: z2 is not a part of the original graph so the replacement is unsafe
oc = graph_replace([o], {x: new_x, y2: new_y21})
oc = graph_replace(oc, {new_y2: z2})[0]
assert oc.owner.inputs[1].owner.inputs[0] is z2
assert oc.owner.inputs[0].owner.inputs[0] is new_x
# the old reference is still kept
assert oc.owner.inputs[0].owner.inputs[1] is w
new_z = z.clone(name="z_new")
oc = graph_replace([oc], {z: new_z})[0]
# new reference appear
assert oc.owner.inputs[1].owner.inputs[0] is not z2
assert oc.owner.inputs[1].owner.inputs[0].owner.inputs[0] is new_z
# the old reference is still kept
assert oc.owner.inputs[0].owner.inputs[0] is new_x
assert oc.owner.inputs[0].owner.inputs[1] is w
def test_graph_replace_disconnected(self):
x = MyVariable("x")
fake = MyOp("fake")(x)
o = MyOp("o")(x)
oc = graph_replace([o], {fake: x.clone()}, strict=False)
assert oc[0] is o
with pytest.raises(ValueError, match="Some replacements were not used"):
oc = graph_replace([o], {fake: x.clone()}, strict=True)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论