提交 ae403379 authored 作者: Frederic's avatar Frederic

pep8

上级 ff12f37d
...@@ -32,10 +32,11 @@ class MyType(Type): ...@@ -32,10 +32,11 @@ class MyType(Type):
def MyVariable(name): def MyVariable(name):
return Variable(MyType(), None, None, name = name) return Variable(MyType(), None, None, name=name)
def MyConstant(data): def MyConstant(data):
return graph.Constant(MyType(), data = data) return graph.Constant(MyType(), data=data)
class MyOp(Op): class MyOp(Op):
...@@ -74,13 +75,13 @@ class MyOp(Op): ...@@ -74,13 +75,13 @@ class MyOp(Op):
sigmoid = MyOp(1, 'Sigmoid') sigmoid = MyOp(1, 'Sigmoid')
transpose_view = MyOp(1, 'TransposeView', vmap = {0: [0]}) transpose_view = MyOp(1, 'TransposeView', vmap={0: [0]})
add = MyOp(2, 'Add') add = MyOp(2, 'Add')
add_in_place = MyOp(2, 'AddInPlace', dmap = {0: [0]}) add_in_place = MyOp(2, 'AddInPlace', dmap={0: [0]})
add_in_place_2 = MyOp(2, 'AddInPlace', dmap = {0: [0]}, add_in_place_2 = MyOp(2, 'AddInPlace', dmap={0: [0]},
destroyhandler_tolerate_same = [(0, 1)]) destroyhandler_tolerate_same=[(0, 1)])
add_in_place_3 = MyOp(2, 'AddInPlace', dmap = {0: [0]}, add_in_place_3 = MyOp(2, 'AddInPlace', dmap={0: [0]},
destroyhandler_tolerate_aliased = [(0, 1)]) destroyhandler_tolerate_aliased=[(0, 1)])
dot = MyOp(2, 'Dot') dot = MyOp(2, 'Dot')
...@@ -91,7 +92,7 @@ def inputs(): ...@@ -91,7 +92,7 @@ def inputs():
return x, y, z return x, y, z
_Env = Env _Env = Env
def Env(inputs, outputs, validate = True): def Env(inputs, outputs, validate=True):
e = _Env(inputs, outputs) e = _Env(inputs, outputs)
e.attach_feature(destroyhandler.DestroyHandler()) e.attach_feature(destroyhandler.DestroyHandler())
e.attach_feature(ReplaceValidate()) e.attach_feature(ReplaceValidate())
...@@ -101,9 +102,11 @@ def Env(inputs, outputs, validate = True): ...@@ -101,9 +102,11 @@ def Env(inputs, outputs, validate = True):
class FailureWatch: class FailureWatch:
# when passed to OpSubOptimizer or PatternOptimizer, counts the number of failures # when passed to OpSubOptimizer or PatternOptimizer, counts the
# number of failures
def __init__(self): def __init__(self):
self.failures = 0 self.failures = 0
def __call__(self, exc, nav, pairs, lopt): def __call__(self, exc, nav, pairs, lopt):
assert isinstance(exc, InconsistencyError) assert isinstance(exc, InconsistencyError)
self.failures += 1 self.failures += 1
...@@ -118,6 +121,7 @@ def consistent(g): ...@@ -118,6 +121,7 @@ def consistent(g):
raise raise
#print "Test OK" #print "Test OK"
def inconsistent(g): def inconsistent(g):
#print "Testing NOT consistent:", g #print "Testing NOT consistent:", g
try: try:
...@@ -127,24 +131,23 @@ def inconsistent(g): ...@@ -127,24 +131,23 @@ def inconsistent(g):
raise raise
#print "Test OK" #print "Test OK"
################# #################
# Test protocol # # Test protocol #
################# #################
def test_misc(): def test_misc():
x, y, z = inputs() x, y, z = inputs()
e = transpose_view(transpose_view(transpose_view(transpose_view(x)))) e = transpose_view(transpose_view(transpose_view(transpose_view(x))))
g = Env([x,y,z], [e]) g = Env([x, y, z], [e])
consistent(g) consistent(g)
chk = g.checkpoint() chk = g.checkpoint()
PatternOptimizer((transpose_view, (transpose_view, 'x')), 'x').optimize(g) PatternOptimizer((transpose_view, (transpose_view, 'x')), 'x').optimize(g)
assert str(g) == "[x]" assert str(g) == "[x]"
new_e = add(x,y) new_e = add(x, y)
g.replace_validate(x, new_e) g.replace_validate(x, new_e)
assert str(g) == "[Add(x, y)]" assert str(g) == "[Add(x, y)]"
g.replace(new_e, dot(add_in_place(x,y), transpose_view(x))) g.replace(new_e, dot(add_in_place(x, y), transpose_view(x)))
assert str(g) == "[Dot(AddInPlace(x, y), TransposeView(x))]" assert str(g) == "[Dot(AddInPlace(x, y), TransposeView(x))]"
inconsistent(g) inconsistent(g)
g.revert(chk) g.revert(chk)
...@@ -152,12 +155,11 @@ def test_misc(): ...@@ -152,12 +155,11 @@ def test_misc():
assert str(g) == "[TransposeView(TransposeView(TransposeView(TransposeView(x))))]" assert str(g) == "[TransposeView(TransposeView(TransposeView(TransposeView(x))))]"
###################### ######################
# Test protocol skip # # Test protocol skip #
###################### ######################
def test_aliased_inputs_replacement(): def test_aliased_inputs_replacement():
x, y, z = inputs() x, y, z = inputs()
tv = transpose_view(x) tv = transpose_view(x)
...@@ -175,32 +177,36 @@ def test_aliased_inputs_replacement(): ...@@ -175,32 +177,36 @@ def test_aliased_inputs_replacement():
g.replace(tv, sx) g.replace(tv, sx)
consistent(g) consistent(g)
def test_indestructible(): def test_indestructible():
x, y, z = inputs() x, y, z = inputs()
x.tag.indestructible = True x.tag.indestructible = True
x = copy(x) x = copy(x)
assert x.tag.indestructible # checking if indestructible survives the copy! # checking if indestructible survives the copy!
assert x.tag.indestructible
e = add_in_place(x, y) e = add_in_place(x, y)
g = Env([x,y,z], [e], False) g = Env([x, y, z], [e], False)
inconsistent(g) inconsistent(g)
g.replace_validate(e, add(x, y)) g.replace_validate(e, add(x, y))
consistent(g) consistent(g)
def test_usage_loop_through_views_2(): def test_usage_loop_through_views_2():
x, y, z = inputs() x, y, z = inputs()
e0 = transpose_view(transpose_view(sigmoid(x))) e0 = transpose_view(transpose_view(sigmoid(x)))
e = dot(add_in_place(x,y), transpose_view(e0)) e = dot(add_in_place(x, y), transpose_view(e0))
g = Env([x,y,z], [e]) g = Env([x, y, z], [e])
consistent(g) # because sigmoid can do the copy consistent(g) # because sigmoid can do the copy
g.replace(e0, x) g.replace(e0, x)
inconsistent(g) # we cut off the path to the sigmoid inconsistent(g) # we cut off the path to the sigmoid
def test_destroyers_loop(): def test_destroyers_loop():
# AddInPlace(x, y) and AddInPlace(y, x) should not coexist # AddInPlace(x, y) and AddInPlace(y, x) should not coexist
x, y, z = inputs() x, y, z = inputs()
e1 = add(x, y) e1 = add(x, y)
e2 = add(y, x) e2 = add(y, x)
g = Env([x,y,z], [e1, e2]) g = Env([x, y, z], [e1, e2])
chk = g.checkpoint() chk = g.checkpoint()
consistent(g) consistent(g)
g.replace_validate(e1, add_in_place(x, y)) g.replace_validate(e1, add_in_place(x, y))
...@@ -232,30 +238,35 @@ def test_aliased_inputs(): ...@@ -232,30 +238,35 @@ def test_aliased_inputs():
g = Env([x], [e], False) g = Env([x], [e], False)
inconsistent(g) inconsistent(g)
def test_aliased_inputs2(): def test_aliased_inputs2():
x, y, z = inputs() x, y, z = inputs()
e = add_in_place(x, transpose_view(x)) e = add_in_place(x, transpose_view(x))
g = Env([x], [e], False) g = Env([x], [e], False)
inconsistent(g) inconsistent(g)
def test_aliased_inputs_tolerate(): def test_aliased_inputs_tolerate():
x, y, z = inputs() x, y, z = inputs()
e = add_in_place_2(x, x) e = add_in_place_2(x, x)
g = Env([x], [e], False) g = Env([x], [e], False)
consistent(g) consistent(g)
def test_aliased_inputs_tolerate2(): def test_aliased_inputs_tolerate2():
x, y, z = inputs() x, y, z = inputs()
e = add_in_place_2(x, transpose_view(x)) e = add_in_place_2(x, transpose_view(x))
g = Env([x], [e], False) g = Env([x], [e], False)
inconsistent(g) inconsistent(g)
def test_same_aliased_inputs_ignored(): def test_same_aliased_inputs_ignored():
x, y, z = inputs() x, y, z = inputs()
e = add_in_place_3(x, x) e = add_in_place_3(x, x)
g = Env([x], [e], False) g = Env([x], [e], False)
consistent(g) consistent(g)
def test_different_aliased_inputs_ignored(): def test_different_aliased_inputs_ignored():
x, y, z = inputs() x, y, z = inputs()
e = add_in_place_3(x, transpose_view(x)) e = add_in_place_3(x, transpose_view(x))
...@@ -271,16 +282,17 @@ def test_indestructible_through_views(): ...@@ -271,16 +282,17 @@ def test_indestructible_through_views():
x.tag.indestructible = True x.tag.indestructible = True
tv = transpose_view(x) tv = transpose_view(x)
e = add_in_place(tv, y) e = add_in_place(tv, y)
g = Env([x,y,z], [e], False) g = Env([x, y, z], [e], False)
inconsistent(g) inconsistent(g)
g.replace_validate(tv, sigmoid(x)) g.replace_validate(tv, sigmoid(x))
consistent(g) consistent(g)
def test_indirect(): def test_indirect():
x, y, z = inputs() x, y, z = inputs()
e0 = add_in_place(x, y) e0 = add_in_place(x, y)
e = dot(sigmoid(e0), transpose_view(x)) e = dot(sigmoid(e0), transpose_view(x))
g = Env([x,y,z], [e], False) g = Env([x, y, z], [e], False)
inconsistent(g) inconsistent(g)
new_e0 = add(x, y) new_e0 = add(x, y)
g.replace(e0, new_e0) g.replace(e0, new_e0)
...@@ -288,53 +300,63 @@ def test_indirect(): ...@@ -288,53 +300,63 @@ def test_indirect():
g.replace(new_e0, add_in_place(x, y)) g.replace(new_e0, add_in_place(x, y))
inconsistent(g) inconsistent(g)
def test_indirect_2(): def test_indirect_2():
x, y, z = inputs() x, y, z = inputs()
e0 = transpose_view(x) e0 = transpose_view(x)
e = dot(sigmoid(add_in_place(x, y)), e0) e = dot(sigmoid(add_in_place(x, y)), e0)
g = Env([x,y,z], [e], False) g = Env([x, y, z], [e], False)
inconsistent(g) inconsistent(g)
new_e0 = add(e0, y) new_e0 = add(e0, y)
g.replace(e0, new_e0) g.replace(e0, new_e0)
consistent(g) consistent(g)
def test_long_destroyers_loop(): def test_long_destroyers_loop():
x, y, z = inputs() x, y, z = inputs()
e = dot(dot(add_in_place(x,y), add_in_place(y,z)), add(z,x)) e = dot(dot(add_in_place(x, y),
g = Env([x,y,z], [e]) add_in_place(y, z)),
add(z, x))
g = Env([x, y, z], [e])
consistent(g) consistent(g)
OpSubOptimizer(add, add_in_place).optimize(g) OpSubOptimizer(add, add_in_place).optimize(g)
consistent(g) consistent(g)
assert str(g) != "[Dot(Dot(AddInPlace(x, y), AddInPlace(y, z)), AddInPlace(z, x))]" # we don't want to see that! # we don't want to see that!
e2 = dot(dot(add_in_place(x,y), add_in_place(y,z)), add_in_place(z,x)) assert str(g) != "[Dot(Dot(AddInPlace(x, y), AddInPlace(y, z)), AddInPlace(z, x))]"
e2 = dot(dot(add_in_place(x, y),
add_in_place(y, z)),
add_in_place(z, x))
try: try:
g2 = Env(*graph.clone([x,y,z], [e2])) g2 = Env(*graph.clone([x, y, z], [e2]))
raise Exception("Shouldn't have reached this point.") raise Exception("Shouldn't have reached this point.")
except InconsistencyError: except InconsistencyError:
pass pass
def test_misc_2(): def test_misc_2():
x, y, z = inputs() x, y, z = inputs()
tv = transpose_view(x) tv = transpose_view(x)
e = add_in_place(x, tv) e = add_in_place(x, tv)
g = Env([x,y], [e], False) g = Env([x, y], [e], False)
inconsistent(g) inconsistent(g)
g.replace(tv, x) g.replace(tv, x)
inconsistent(g) inconsistent(g)
def test_multi_destroyers(): def test_multi_destroyers():
x, y, z = inputs() x, y, z = inputs()
e = add(add_in_place(x, y), add_in_place(x, y)) e = add(add_in_place(x, y), add_in_place(x, y))
try: try:
g = Env([x,y,z], [e]) g = Env([x, y, z], [e])
raise Exception("Shouldn't have reached this point.") raise Exception("Shouldn't have reached this point.")
except InconsistencyError, e: except InconsistencyError, e:
pass pass
def test_multi_destroyers_through_views(): def test_multi_destroyers_through_views():
x, y, z = inputs() x, y, z = inputs()
e = dot(add(transpose_view(z), y), add(z, x)) e = dot(add(transpose_view(z), y), add(z, x))
g = Env([x,y,z], [e]) g = Env([x, y, z], [e])
consistent(g) consistent(g)
fail = FailureWatch() fail = FailureWatch()
OpSubOptimizer(add, add_in_place, fail).optimize(g) OpSubOptimizer(add, add_in_place, fail).optimize(g)
...@@ -348,54 +370,59 @@ def test_repair_destroy_path(): ...@@ -348,54 +370,59 @@ def test_repair_destroy_path():
e2 = transpose_view(transpose_view(e1)) e2 = transpose_view(transpose_view(e1))
e3 = add_in_place(e2, y) e3 = add_in_place(e2, y)
e4 = add_in_place(e1, z) e4 = add_in_place(e1, z)
g = Env([x,y,z], [e3, e4], False) g = Env([x, y, z], [e3, e4], False)
inconsistent(g) inconsistent(g)
g.replace(e2, transpose_view(x)) g.replace(e2, transpose_view(x))
inconsistent(g) inconsistent(g)
def test_usage_loop(): def test_usage_loop():
x, y, z = inputs() x, y, z = inputs()
g = Env([x,y,z], [dot(add_in_place(x, z), x)], False) g = Env([x, y, z], [dot(add_in_place(x, z), x)], False)
inconsistent(g) inconsistent(g)
OpSubOptimizer(add_in_place, add).optimize(g) # replace add_in_place with add # replace add_in_place with add
OpSubOptimizer(add_in_place, add).optimize(g)
consistent(g) consistent(g)
def test_usage_loop_through_views(): def test_usage_loop_through_views():
x, y, z = inputs() x, y, z = inputs()
aip = add_in_place(x, y) aip = add_in_place(x, y)
e = dot(aip, transpose_view(x)) e = dot(aip, transpose_view(x))
g = Env([x,y,z], [e], False) g = Env([x, y, z], [e], False)
inconsistent(g) inconsistent(g)
g.replace_validate(aip, add(x, z)) g.replace_validate(aip, add(x, z))
consistent(g) consistent(g)
def test_usage_loop_insert_views(): def test_usage_loop_insert_views():
x, y, z = inputs() x, y, z = inputs()
e = dot(add_in_place(x, add(y, z)), sigmoid(sigmoid(sigmoid(sigmoid(sigmoid(x)))))) e = dot(add_in_place(x, add(y, z)),
g = Env([x,y,z], [e]) sigmoid(sigmoid(sigmoid(sigmoid(sigmoid(x))))))
g = Env([x, y, z], [e])
consistent(g) consistent(g)
fail = FailureWatch() fail = FailureWatch()
OpSubOptimizer(sigmoid, transpose_view, fail).optimize(g) OpSubOptimizer(sigmoid, transpose_view, fail).optimize(g)
consistent(g) consistent(g)
assert fail.failures == 1 # it must keep one sigmoid in the long sigmoid chain # it must keep one sigmoid in the long sigmoid chain
assert fail.failures == 1
def test_value_repl(): def test_value_repl():
x, y, z = inputs() x, y, z = inputs()
sy = sigmoid(y) sy = sigmoid(y)
e = add_in_place(x, sy) e = add_in_place(x, sy)
g = Env([x,y], [e], False) g = Env([x, y], [e], False)
consistent(g) consistent(g)
g.replace(sy, MyConstant("abc")) g.replace(sy, MyConstant("abc"))
consistent(g) consistent(g)
def test_value_repl_2(): def test_value_repl_2():
x, y, z = inputs() x, y, z = inputs()
sy = sigmoid(y) sy = sigmoid(y)
e = add_in_place(x, sy) e = add_in_place(x, sy)
g = Env([x,y], [e], False) g = Env([x, y], [e], False)
consistent(g) consistent(g)
g.replace(sy, transpose_view(MyConstant("abc"))) g.replace(sy, transpose_view(MyConstant("abc")))
consistent(g) consistent(g)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论