提交 5946b4de authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Improve readability of graph_replace tests

上级 4b812709
...@@ -144,92 +144,94 @@ class TestCloneReplace: ...@@ -144,92 +144,94 @@ class TestCloneReplace:
class TestGraphReplace: class TestGraphReplace:
def test_graph_replace(self): def test_graph_replace(self):
op = MyOp("op")
x = MyVariable("x") x = MyVariable("x")
y = MyVariable("y") y = MyVariable("y")
z = MyVariable("z") z = MyVariable("w")
w = MyVariable("w") out = op(x, z)
MyOp("zop")(z)
x2 = MyOp("xop")(x, w) new_x = op(y)
x2.name = "x2" new_out = graph_replace([out], {x: new_x})[0]
y2 = MyOp("yop")(y) assert new_out.owner.inputs[0] is new_x
y2.name = "y2"
yc = graph_replace([x2], {x: y2})[0]
assert yc.owner.inputs[0] is y2
# the old reference is kept # the old reference is kept
assert yc.owner.inputs[1] is w assert new_out.owner.inputs[1] is z
# test replace itself # test replace itself
yc = graph_replace([x2], {x2: y2})[0] new_out = graph_replace([out], {out: new_x})[0]
assert yc is y2 assert new_out is new_x
assert yc.owner.inputs[0] is y assert new_out.owner.inputs[0] is y
assert len(yc.owner.inputs) == 1 assert len(new_out.owner.inputs) == 1
# the case where inputs have to be replaced in reverse topological order # the case where inputs have to be replaced in reverse topological order
o = MyOp("xyop")(x2, y2) out2 = op(out, new_x)
new_x = x.clone(name="x_new")
new_y2 = y2.clone(name="y2_new")
oc = graph_replace([o], {x: new_x, y2: new_y2})[0] new_x2 = x.clone(name="new_x")
assert oc.owner.inputs[1] is new_y2 new_x22 = new_x.clone(name="new_x2")
assert oc.owner.inputs[0].owner.inputs[0] is new_x new_out2 = graph_replace([out2], {x: new_x2, new_x: new_x22})[0]
assert new_out2.owner.inputs[1] is new_x22
assert new_out2.owner.inputs[0].owner.inputs[0] is new_x2
# the old reference is still kept # the old reference is still kept
assert oc.owner.inputs[0].owner.inputs[1] is w assert new_out2.owner.inputs[0].owner.inputs[1] is z
def test_non_list_input(self): def test_non_list_input(self):
op = MyOp("op")
x = MyVariable("x") x = MyVariable("x")
y = MyVariable("y") y = MyVariable("y")
o = MyOp("xyop")(x, y) out = op(x, y)
new_x = x.clone(name="x_new")
new_y = y.clone(name="y2_new") new_x = x.clone(name="new_x")
new_y = y.clone(name="new_y")
# test non list inputs as well # test non list inputs as well
oc = graph_replace(o, {x: new_x, y: new_y}) oc = graph_replace(out, {x: new_x, y: new_y})
assert oc.owner.inputs[1] is new_y assert oc.owner.inputs[1] is new_y
assert oc.owner.inputs[0] is new_x assert oc.owner.inputs[0] is new_x
def test_graph_replace_advanced(self): def test_graph_replace_advanced(self):
op = MyOp("op")
x = MyVariable("x") x = MyVariable("x")
y = MyVariable("y") y = MyVariable("y")
z = MyVariable("z") z = MyVariable("z")
w = MyVariable("w") w = MyVariable("w")
z2 = MyOp("zop")(z)
x2 = MyOp("xop")(x, w) z_op = op(z)
x2.name = "x2" xw_op = op(x, w)
y2 = MyOp("yop")(y) y_op = op(y)
y2.name = "y2" out = op(xw_op, y_op)
o = MyOp("xyop")(x2, y2)
new_x = x.clone(name="x_new") new_x = x.clone(name="new_x")
new_y2 = y2.clone(name="y2_new") new_yop = y_op.clone(name="new_yop")
new_y21 = MyOp("ny2op")(new_y2)
# now yet another replacement that could only appear after new_y2: z # now yet another replacement that could only appear after new_y2: z
# show we can do that after the prev clone # show we can do that after the prev clone
# the case where new variable is referenced during the replacements # the case where new variable is referenced during the replacements
new_y21 = MyOp("ny2op")(new_y2) new_yop_op = op(new_yop)
# the reference new_y2: z2 is not a part of the original graph so the replacement is unsafe # the reference new_yop: z_op is not a part of the original graph so the replacement is unsafe
oc = graph_replace([o], {x: new_x, y2: new_y21}) new_out = graph_replace([out], {x: new_x, y_op: new_yop_op})
oc = graph_replace(oc, {new_y2: z2})[0] new_out = graph_replace(new_out, {new_yop: z_op})[0]
assert oc.owner.inputs[1].owner.inputs[0] is z2 assert new_out.owner.inputs[1].owner.inputs[0] is z_op
assert oc.owner.inputs[0].owner.inputs[0] is new_x assert new_out.owner.inputs[0].owner.inputs[0] is new_x
# the old reference is still kept # the old reference is still kept
assert oc.owner.inputs[0].owner.inputs[1] is w assert new_out.owner.inputs[0].owner.inputs[1] is w
new_z = z.clone(name="z_new") new_z = z.clone(name="new_z")
oc = graph_replace([oc], {z: new_z})[0] new_out = graph_replace([new_out], {z: new_z})[0]
# new reference appear # new reference appear
assert oc.owner.inputs[1].owner.inputs[0] is not z2 assert new_out.owner.inputs[1].owner.inputs[0] is not z_op
assert oc.owner.inputs[1].owner.inputs[0].owner.inputs[0] is new_z assert new_out.owner.inputs[1].owner.inputs[0].owner.inputs[0] is new_z
# the old reference is still kept # the old reference is still kept
assert oc.owner.inputs[0].owner.inputs[0] is new_x assert new_out.owner.inputs[0].owner.inputs[0] is new_x
assert oc.owner.inputs[0].owner.inputs[1] is w assert new_out.owner.inputs[0].owner.inputs[1] is w
def test_graph_replace_disconnected(self): def test_graph_replace_disconnected(self):
op = MyOp("op")
fake_op = MyOp("fake_op")
x = MyVariable("x") x = MyVariable("x")
fake = MyOp("fake")(x) fake = fake_op(x)
o = MyOp("o")(x) out = op(x)
oc = graph_replace([o], {fake: x.clone()}, strict=False) [new_out] = graph_replace([out], {fake: x.clone()}, strict=False)
assert oc[0] is o assert new_out is out
with pytest.raises(ValueError, match="Some replacements were not used"): with pytest.raises(ValueError, match="Some replacements were not used"):
oc = graph_replace([o], {fake: x.clone()}, strict=True) graph_replace([out], {fake: x.clone()}, strict=True)
class TestVectorizeGraph: class TestVectorizeGraph:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论