提交 6ca2eecf authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename test helper function Env to create_fgraph in tests.graph.test_destroyhandler

上级 43695bc7
...@@ -116,7 +116,7 @@ def inputs(): ...@@ -116,7 +116,7 @@ def inputs():
return x, y, z return x, y, z
def Env(inputs, outputs, validate=True): def create_fgraph(inputs, outputs, validate=True):
e = FunctionGraph(inputs, outputs, clone=False) e = FunctionGraph(inputs, outputs, clone=False)
e.attach_feature(DestroyHandler()) e.attach_feature(DestroyHandler())
e.attach_feature(ReplaceValidate()) e.attach_feature(ReplaceValidate())
...@@ -144,7 +144,7 @@ class FailureWatch: ...@@ -144,7 +144,7 @@ class FailureWatch:
def test_misc(): def test_misc():
x, y, z = inputs() x, y, z = inputs()
e = transpose_view(transpose_view(transpose_view(transpose_view(x)))) e = transpose_view(transpose_view(transpose_view(transpose_view(x))))
g = Env([x, y, z], [e]) g = create_fgraph([x, y, z], [e])
assert g.consistent() assert g.consistent()
PatternOptimizer((transpose_view, (transpose_view, "x")), "x").optimize(g) PatternOptimizer((transpose_view, (transpose_view, "x")), "x").optimize(g)
assert str(g) == "FunctionGraph(x)" assert str(g) == "FunctionGraph(x)"
...@@ -168,7 +168,7 @@ def test_aliased_inputs_replacement(): ...@@ -168,7 +168,7 @@ def test_aliased_inputs_replacement():
tvv = transpose_view(tv) tvv = transpose_view(tv)
sx = sigmoid(x) sx = sigmoid(x)
e = add_in_place(x, tv) e = add_in_place(x, tv)
g = Env([x, y], [e], False) g = create_fgraph([x, y], [e], False)
assert not g.consistent() assert not g.consistent()
g.replace(tv, sx) g.replace(tv, sx)
assert g.consistent() assert g.consistent()
...@@ -187,7 +187,7 @@ def test_indestructible(): ...@@ -187,7 +187,7 @@ def test_indestructible():
# checking if indestructible survives the copy! # checking if indestructible survives the copy!
assert x.tag.indestructible assert x.tag.indestructible
e = add_in_place(x, y) e = add_in_place(x, y)
g = Env([x, y, z], [e], False) g = create_fgraph([x, y, z], [e], False)
assert not g.consistent() assert not g.consistent()
g.replace_validate(e, add(x, y)) g.replace_validate(e, add(x, y))
assert g.consistent() assert g.consistent()
...@@ -198,7 +198,7 @@ def test_usage_loop_through_views_2(): ...@@ -198,7 +198,7 @@ def test_usage_loop_through_views_2():
x, y, z = inputs() x, y, z = inputs()
e0 = transpose_view(transpose_view(sigmoid(x))) e0 = transpose_view(transpose_view(sigmoid(x)))
e = dot(add_in_place(x, y), transpose_view(e0)) e = dot(add_in_place(x, y), transpose_view(e0))
g = Env([x, y, z], [e]) g = create_fgraph([x, y, z], [e])
assert g.consistent() # because sigmoid can do the copy assert g.consistent() # because sigmoid can do the copy
g.replace(e0, x) g.replace(e0, x)
assert not g.consistent() # we cut off the path to the sigmoid assert not g.consistent() # we cut off the path to the sigmoid
...@@ -210,7 +210,7 @@ def test_destroyers_loop(): ...@@ -210,7 +210,7 @@ def test_destroyers_loop():
x, y, z = inputs() x, y, z = inputs()
e1 = add(x, y) e1 = add(x, y)
e2 = add(y, x) e2 = add(y, x)
g = Env([x, y, z], [e1, e2]) g = create_fgraph([x, y, z], [e1, e2])
assert g.consistent() assert g.consistent()
g.replace_validate(e1, add_in_place(x, y)) g.replace_validate(e1, add_in_place(x, y))
assert g.consistent() assert g.consistent()
...@@ -221,7 +221,7 @@ def test_destroyers_loop(): ...@@ -221,7 +221,7 @@ def test_destroyers_loop():
x, y, z = inputs() x, y, z = inputs()
e1 = add(x, y) e1 = add(x, y)
e2 = add(y, x) e2 = add(y, x)
g = Env([x, y, z], [e1, e2]) g = create_fgraph([x, y, z], [e1, e2])
assert g.consistent() assert g.consistent()
g.replace_validate(e2, add_in_place(y, x)) g.replace_validate(e2, add_in_place(y, x))
assert g.consistent() assert g.consistent()
...@@ -238,14 +238,14 @@ def test_destroyers_loop(): ...@@ -238,14 +238,14 @@ def test_destroyers_loop():
def test_aliased_inputs(): def test_aliased_inputs():
x, y, z = inputs() x, y, z = inputs()
e = add_in_place(x, x) e = add_in_place(x, x)
g = Env([x], [e], False) g = create_fgraph([x], [e], False)
assert not g.consistent() assert not g.consistent()
def test_aliased_inputs2(): def test_aliased_inputs2():
x, y, z = inputs() x, y, z = inputs()
e = add_in_place(x, transpose_view(x)) e = add_in_place(x, transpose_view(x))
g = Env([x], [e], False) g = create_fgraph([x], [e], False)
assert not g.consistent() assert not g.consistent()
...@@ -253,14 +253,14 @@ def test_aliased_inputs2(): ...@@ -253,14 +253,14 @@ def test_aliased_inputs2():
def test_aliased_inputs_tolerate(): def test_aliased_inputs_tolerate():
x, y, z = inputs() x, y, z = inputs()
e = add_in_place_2(x, x) e = add_in_place_2(x, x)
g = Env([x], [e], False) g = create_fgraph([x], [e], False)
assert g.consistent() assert g.consistent()
def test_aliased_inputs_tolerate2(): def test_aliased_inputs_tolerate2():
x, y, z = inputs() x, y, z = inputs()
e = add_in_place_2(x, transpose_view(x)) e = add_in_place_2(x, transpose_view(x))
g = Env([x], [e], False) g = create_fgraph([x], [e], False)
assert not g.consistent() assert not g.consistent()
...@@ -268,7 +268,7 @@ def test_aliased_inputs_tolerate2(): ...@@ -268,7 +268,7 @@ def test_aliased_inputs_tolerate2():
def test_same_aliased_inputs_ignored(): def test_same_aliased_inputs_ignored():
x, y, z = inputs() x, y, z = inputs()
e = add_in_place_3(x, x) e = add_in_place_3(x, x)
g = Env([x], [e], False) g = create_fgraph([x], [e], False)
assert g.consistent() assert g.consistent()
...@@ -276,7 +276,7 @@ def test_same_aliased_inputs_ignored(): ...@@ -276,7 +276,7 @@ def test_same_aliased_inputs_ignored():
def test_different_aliased_inputs_ignored(): def test_different_aliased_inputs_ignored():
x, y, z = inputs() x, y, z = inputs()
e = add_in_place_3(x, transpose_view(x)) e = add_in_place_3(x, transpose_view(x))
g = Env([x], [e], False) g = create_fgraph([x], [e], False)
assert g.consistent() assert g.consistent()
# warning - don't run this because it would produce the wrong answer # warning - don't run this because it would produce the wrong answer
# add_in_place_3 is actually not correct when aliasing of inputs # add_in_place_3 is actually not correct when aliasing of inputs
...@@ -288,7 +288,7 @@ def test_indestructible_through_views(): ...@@ -288,7 +288,7 @@ def test_indestructible_through_views():
x.tag.indestructible = True x.tag.indestructible = True
tv = transpose_view(x) tv = transpose_view(x)
e = add_in_place(tv, y) e = add_in_place(tv, y)
g = Env([x, y, z], [e], False) g = create_fgraph([x, y, z], [e], False)
assert not g.consistent() assert not g.consistent()
g.replace_validate(tv, sigmoid(x)) g.replace_validate(tv, sigmoid(x))
assert g.consistent() assert g.consistent()
...@@ -298,7 +298,7 @@ def test_indirect(): ...@@ -298,7 +298,7 @@ def test_indirect():
x, y, z = inputs() x, y, z = inputs()
e0 = add_in_place(x, y) e0 = add_in_place(x, y)
e = dot(sigmoid(e0), transpose_view(x)) e = dot(sigmoid(e0), transpose_view(x))
g = Env([x, y, z], [e], False) g = create_fgraph([x, y, z], [e], False)
assert not g.consistent() assert not g.consistent()
new_e0 = add(x, y) new_e0 = add(x, y)
g.replace(e0, new_e0) g.replace(e0, new_e0)
...@@ -312,7 +312,7 @@ def test_indirect_2(): ...@@ -312,7 +312,7 @@ def test_indirect_2():
x, y, z = inputs() x, y, z = inputs()
e0 = transpose_view(x) e0 = transpose_view(x)
e = dot(sigmoid(add_in_place(x, y)), e0) e = dot(sigmoid(add_in_place(x, y)), e0)
g = Env([x, y, z], [e], False) g = create_fgraph([x, y, z], [e], False)
assert not g.consistent() assert not g.consistent()
new_e0 = add(e0, y) new_e0 = add(e0, y)
g.replace(e0, new_e0) g.replace(e0, new_e0)
...@@ -323,7 +323,7 @@ def test_indirect_2(): ...@@ -323,7 +323,7 @@ def test_indirect_2():
def test_long_destroyers_loop(): def test_long_destroyers_loop():
x, y, z = inputs() x, y, z = inputs()
e = dot(dot(add_in_place(x, y), add_in_place(y, z)), add(z, x)) e = dot(dot(add_in_place(x, y), add_in_place(y, z)), add(z, x))
g = Env([x, y, z], [e]) g = create_fgraph([x, y, z], [e])
assert g.consistent() assert g.consistent()
OpSubOptimizer(add, add_in_place).optimize(g) OpSubOptimizer(add, add_in_place).optimize(g)
assert g.consistent() assert g.consistent()
...@@ -334,14 +334,14 @@ def test_long_destroyers_loop(): ...@@ -334,14 +334,14 @@ def test_long_destroyers_loop():
) )
e2 = dot(dot(add_in_place(x, y), add_in_place(y, z)), add_in_place(z, x)) e2 = dot(dot(add_in_place(x, y), add_in_place(y, z)), add_in_place(z, x))
with pytest.raises(InconsistencyError): with pytest.raises(InconsistencyError):
Env(*clone([x, y, z], [e2])) create_fgraph(*clone([x, y, z], [e2]))
def test_misc_2(): def test_misc_2():
x, y, z = inputs() x, y, z = inputs()
tv = transpose_view(x) tv = transpose_view(x)
e = add_in_place(x, tv) e = add_in_place(x, tv)
g = Env([x, y], [e], False) g = create_fgraph([x, y], [e], False)
assert not g.consistent() assert not g.consistent()
g.replace(tv, x) g.replace(tv, x)
assert not g.consistent() assert not g.consistent()
...@@ -351,14 +351,14 @@ def test_multi_destroyers(): ...@@ -351,14 +351,14 @@ def test_multi_destroyers():
x, y, z = inputs() x, y, z = inputs()
e = add(add_in_place(x, y), add_in_place(x, y)) e = add(add_in_place(x, y), add_in_place(x, y))
with pytest.raises(InconsistencyError): with pytest.raises(InconsistencyError):
Env([x, y, z], [e]) create_fgraph([x, y, z], [e])
@assertFailure_fast @assertFailure_fast
def test_multi_destroyers_through_views(): def test_multi_destroyers_through_views():
x, y, z = inputs() x, y, z = inputs()
e = dot(add(transpose_view(z), y), add(z, x)) e = dot(add(transpose_view(z), y), add(z, x))
g = Env([x, y, z], [e]) g = create_fgraph([x, y, z], [e])
assert g.consistent() assert g.consistent()
fail = FailureWatch() fail = FailureWatch()
OpSubOptimizer(add, add_in_place, fail).optimize(g) OpSubOptimizer(add, add_in_place, fail).optimize(g)
...@@ -372,7 +372,7 @@ def test_repair_destroy_path(): ...@@ -372,7 +372,7 @@ def test_repair_destroy_path():
e2 = transpose_view(transpose_view(e1)) e2 = transpose_view(transpose_view(e1))
e3 = add_in_place(e2, y) e3 = add_in_place(e2, y)
e4 = add_in_place(e1, z) e4 = add_in_place(e1, z)
g = Env([x, y, z], [e3, e4], False) g = create_fgraph([x, y, z], [e3, e4], False)
assert not g.consistent() assert not g.consistent()
g.replace(e2, transpose_view(x)) g.replace(e2, transpose_view(x))
assert not g.consistent() assert not g.consistent()
...@@ -380,7 +380,7 @@ def test_repair_destroy_path(): ...@@ -380,7 +380,7 @@ def test_repair_destroy_path():
def test_usage_loop(): def test_usage_loop():
x, y, z = inputs() x, y, z = inputs()
g = Env([x, y, z], [dot(add_in_place(x, z), x)], False) g = create_fgraph([x, y, z], [dot(add_in_place(x, z), x)], False)
assert not g.consistent() assert not g.consistent()
# replace add_in_place with add # replace add_in_place with add
OpSubOptimizer(add_in_place, add).optimize(g) OpSubOptimizer(add_in_place, add).optimize(g)
...@@ -391,7 +391,7 @@ def test_usage_loop_through_views(): ...@@ -391,7 +391,7 @@ def test_usage_loop_through_views():
x, y, z = inputs() x, y, z = inputs()
aip = add_in_place(x, y) aip = add_in_place(x, y)
e = dot(aip, transpose_view(x)) e = dot(aip, transpose_view(x))
g = Env([x, y, z], [e], False) g = create_fgraph([x, y, z], [e], False)
assert not g.consistent() assert not g.consistent()
g.replace_validate(aip, add(x, z)) g.replace_validate(aip, add(x, z))
assert g.consistent() assert g.consistent()
...@@ -401,7 +401,7 @@ def test_usage_loop_through_views(): ...@@ -401,7 +401,7 @@ def test_usage_loop_through_views():
def test_usage_loop_insert_views(): def test_usage_loop_insert_views():
x, y, z = inputs() x, y, z = inputs()
e = dot(add_in_place(x, add(y, z)), sigmoid(sigmoid(sigmoid(sigmoid(sigmoid(x)))))) e = dot(add_in_place(x, add(y, z)), sigmoid(sigmoid(sigmoid(sigmoid(sigmoid(x))))))
g = Env([x, y, z], [e]) g = create_fgraph([x, y, z], [e])
assert g.consistent() assert g.consistent()
fail = FailureWatch() fail = FailureWatch()
OpSubOptimizer(sigmoid, transpose_view, fail).optimize(g) OpSubOptimizer(sigmoid, transpose_view, fail).optimize(g)
...@@ -414,7 +414,7 @@ def test_value_repl(): ...@@ -414,7 +414,7 @@ def test_value_repl():
x, y, z = inputs() x, y, z = inputs()
sy = sigmoid(y) sy = sigmoid(y)
e = add_in_place(x, sy) e = add_in_place(x, sy)
g = Env([x, y], [e], False) g = create_fgraph([x, y], [e], False)
assert g.consistent() assert g.consistent()
g.replace(sy, MyConstant("abc")) g.replace(sy, MyConstant("abc"))
assert g.consistent() assert g.consistent()
...@@ -425,7 +425,7 @@ def test_value_repl_2(): ...@@ -425,7 +425,7 @@ def test_value_repl_2():
x, y, z = inputs() x, y, z = inputs()
sy = sigmoid(y) sy = sigmoid(y)
e = add_in_place(x, sy) e = add_in_place(x, sy)
g = Env([x, y], [e], False) g = create_fgraph([x, y], [e], False)
assert g.consistent() assert g.consistent()
g.replace(sy, transpose_view(MyConstant("abc"))) g.replace(sy, transpose_view(MyConstant("abc")))
assert g.consistent() assert g.consistent()
...@@ -444,7 +444,7 @@ def test_multiple_inplace(): ...@@ -444,7 +444,7 @@ def test_multiple_inplace():
# try to confuse the DestroyHandler: this dot Op can run # try to confuse the DestroyHandler: this dot Op can run
# before multiple and then multiple can still run in-place on y # before multiple and then multiple can still run in-place on y
e_2 = dot(y, y) e_2 = dot(y, y)
g = Env([x, y], [e_1, e_2], False) g = create_fgraph([x, y], [e_1, e_2], False)
assert g.consistent() assert g.consistent()
# try to work in-place on x/0 and y/1 (this should fail) # try to work in-place on x/0 and y/1 (this should fail)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论