提交 5946b4de authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Improve readability of graph_replace tests

上级 4b812709
......@@ -144,92 +144,94 @@ class TestCloneReplace:
class TestGraphReplace:
def test_graph_replace(self):
op = MyOp("op")
x = MyVariable("x")
y = MyVariable("y")
z = MyVariable("z")
w = MyVariable("w")
MyOp("zop")(z)
x2 = MyOp("xop")(x, w)
x2.name = "x2"
y2 = MyOp("yop")(y)
y2.name = "y2"
yc = graph_replace([x2], {x: y2})[0]
assert yc.owner.inputs[0] is y2
z = MyVariable("w")
out = op(x, z)
new_x = op(y)
new_out = graph_replace([out], {x: new_x})[0]
assert new_out.owner.inputs[0] is new_x
# the old reference is kept
assert yc.owner.inputs[1] is w
assert new_out.owner.inputs[1] is z
# test replace itself
yc = graph_replace([x2], {x2: y2})[0]
assert yc is y2
assert yc.owner.inputs[0] is y
assert len(yc.owner.inputs) == 1
new_out = graph_replace([out], {out: new_x})[0]
assert new_out is new_x
assert new_out.owner.inputs[0] is y
assert len(new_out.owner.inputs) == 1
# the case where inputs have to be replaced in reverse topological order
o = MyOp("xyop")(x2, y2)
new_x = x.clone(name="x_new")
new_y2 = y2.clone(name="y2_new")
out2 = op(out, new_x)
oc = graph_replace([o], {x: new_x, y2: new_y2})[0]
assert oc.owner.inputs[1] is new_y2
assert oc.owner.inputs[0].owner.inputs[0] is new_x
new_x2 = x.clone(name="new_x")
new_x22 = new_x.clone(name="new_x2")
new_out2 = graph_replace([out2], {x: new_x2, new_x: new_x22})[0]
assert new_out2.owner.inputs[1] is new_x22
assert new_out2.owner.inputs[0].owner.inputs[0] is new_x2
# the old reference is still kept
assert oc.owner.inputs[0].owner.inputs[1] is w
assert new_out2.owner.inputs[0].owner.inputs[1] is z
def test_non_list_input(self):
op = MyOp("op")
x = MyVariable("x")
y = MyVariable("y")
o = MyOp("xyop")(x, y)
new_x = x.clone(name="x_new")
new_y = y.clone(name="y2_new")
out = op(x, y)
new_x = x.clone(name="new_x")
new_y = y.clone(name="new_y")
# test non list inputs as well
oc = graph_replace(o, {x: new_x, y: new_y})
oc = graph_replace(out, {x: new_x, y: new_y})
assert oc.owner.inputs[1] is new_y
assert oc.owner.inputs[0] is new_x
def test_graph_replace_advanced(self):
op = MyOp("op")
x = MyVariable("x")
y = MyVariable("y")
z = MyVariable("z")
w = MyVariable("w")
z2 = MyOp("zop")(z)
x2 = MyOp("xop")(x, w)
x2.name = "x2"
y2 = MyOp("yop")(y)
y2.name = "y2"
o = MyOp("xyop")(x2, y2)
new_x = x.clone(name="x_new")
new_y2 = y2.clone(name="y2_new")
new_y21 = MyOp("ny2op")(new_y2)
z_op = op(z)
xw_op = op(x, w)
y_op = op(y)
out = op(xw_op, y_op)
new_x = x.clone(name="new_x")
new_yop = y_op.clone(name="new_yop")
# now yet another replacement that could only appear after new_y2: z
# show we can do that after the prev clone
# the case where new variable is referenced during the replacements
new_y21 = MyOp("ny2op")(new_y2)
# the reference new_y2: z2 is not a part of the original graph so the replacement is unsafe
oc = graph_replace([o], {x: new_x, y2: new_y21})
oc = graph_replace(oc, {new_y2: z2})[0]
assert oc.owner.inputs[1].owner.inputs[0] is z2
assert oc.owner.inputs[0].owner.inputs[0] is new_x
new_yop_op = op(new_yop)
# the reference new_yop: z_op is not a part of the original graph so the replacement is unsafe
new_out = graph_replace([out], {x: new_x, y_op: new_yop_op})
new_out = graph_replace(new_out, {new_yop: z_op})[0]
assert new_out.owner.inputs[1].owner.inputs[0] is z_op
assert new_out.owner.inputs[0].owner.inputs[0] is new_x
# the old reference is still kept
assert oc.owner.inputs[0].owner.inputs[1] is w
assert new_out.owner.inputs[0].owner.inputs[1] is w
new_z = z.clone(name="z_new")
oc = graph_replace([oc], {z: new_z})[0]
new_z = z.clone(name="new_z")
new_out = graph_replace([new_out], {z: new_z})[0]
# new reference appear
assert oc.owner.inputs[1].owner.inputs[0] is not z2
assert oc.owner.inputs[1].owner.inputs[0].owner.inputs[0] is new_z
assert new_out.owner.inputs[1].owner.inputs[0] is not z_op
assert new_out.owner.inputs[1].owner.inputs[0].owner.inputs[0] is new_z
# the old reference is still kept
assert oc.owner.inputs[0].owner.inputs[0] is new_x
assert oc.owner.inputs[0].owner.inputs[1] is w
assert new_out.owner.inputs[0].owner.inputs[0] is new_x
assert new_out.owner.inputs[0].owner.inputs[1] is w
def test_graph_replace_disconnected(self):
op = MyOp("op")
fake_op = MyOp("fake_op")
x = MyVariable("x")
fake = MyOp("fake")(x)
o = MyOp("o")(x)
oc = graph_replace([o], {fake: x.clone()}, strict=False)
assert oc[0] is o
fake = fake_op(x)
out = op(x)
[new_out] = graph_replace([out], {fake: x.clone()}, strict=False)
assert new_out is out
with pytest.raises(ValueError, match="Some replacements were not used"):
oc = graph_replace([o], {fake: x.clone()}, strict=True)
graph_replace([out], {fake: x.clone()}, strict=True)
class TestVectorizeGraph:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论