提交 bcb1a525 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

documented and added a lot of tests

上级 e2db78da
...@@ -169,6 +169,14 @@ class _test_CLinker(unittest.TestCase): ...@@ -169,6 +169,14 @@ class _test_CLinker(unittest.TestCase):
self.failUnless(fn(2.0, 2.0) == 4) self.failUnless(fn(2.0, 2.0) == 4)
# note: for now the behavior of fn(2.0, 7.0) is undefined # note: for now the behavior of fn(2.0, 7.0) is undefined
def test_dups_inner(self):
# Testing that duplicates are allowed inside the graph
x, y, z = inputs()
e = add(mul(y, y), add(x, z))
lnk = CLinker(env([x, y, z], [e]))
fn = lnk.make_function()
self.failUnless(fn(1.0, 2.0, 3.0) == 8.0)
class _test_OpWiseCLinker(unittest.TestCase): class _test_OpWiseCLinker(unittest.TestCase):
...@@ -180,9 +188,4 @@ class _test_OpWiseCLinker(unittest.TestCase): ...@@ -180,9 +188,4 @@ class _test_OpWiseCLinker(unittest.TestCase):
self.failUnless(fn(2.0, 2.0, 2.0) == 2.0) self.failUnless(fn(2.0, 2.0, 2.0) == 2.0)
if __name__ == '__main__': if __name__ == '__main__':
# unittest.main() unittest.main()
x, y, z = inputs()
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = CLinker(env([x, y, z], [e]))
fn = lnk.make_function()
fn(2.0, 0.0, 2.0)
...@@ -42,39 +42,20 @@ class Dot(MyOp): ...@@ -42,39 +42,20 @@ class Dot(MyOp):
nin = 2 nin = 2
dtv_elim = PatternOptimizer((TransposeView, (TransposeView, 'x')), 'x') # dtv_elim = PatternOptimizer((TransposeView, (TransposeView, 'x')), 'x')
a2i = OpSubOptimizer(Add, AddInPlace) # AddCls = Add
i2a = OpSubOptimizer(AddInPlace, Add) # AddInPlaceCls = AddInPlace
t2s = OpSubOptimizer(TransposeView, Sigmoid) # a2i = OpSubOptimizer(Add, AddInPlace)
s2t = OpSubOptimizer(Sigmoid, TransposeView) # i2a = OpSubOptimizer(AddInPlace, Add)
# t2s = OpSubOptimizer(TransposeView, Sigmoid)
# from constructor import Constructor # s2t = OpSubOptimizer(Sigmoid, TransposeView)
# from allocators import BuildAllocator
# c = Constructor(BuildAllocator)
# c.update(globals())
# globals().update(c)
# def inputs():
# x = (MyResult('x'))
# y = (MyResult('y'))
# z = (MyResult('z'))
# return x, y, z
# def env(inputs, outputs, validate = True):
# return Env(inputs, outputs, features = [EquivTool], consistency_check = validate)
import modes import modes
modes.make_constructors(globals(), name_filter = lambda x:x) modes.make_constructors(globals()) #, name_filter = lambda x:x)
# def inputs():
# x = modes.BuildMode(MyResult('x'))
# y = modes.BuildMode(MyResult('y'))
# z = modes.BuildMode(MyResult('z'))
# return x, y, z
def inputs(): def inputs():
x = modes.build(MyResult('x')) x = modes.build(MyResult('x'))
...@@ -83,153 +64,181 @@ def inputs(): ...@@ -83,153 +64,181 @@ def inputs():
return x, y, z return x, y, z
def env(inputs, outputs, validate = True): def env(inputs, outputs, validate = True):
inputs = [input for input in inputs]
outputs = [output for output in outputs]
return Env(inputs, outputs, features = [EquivTool], consistency_check = validate) return Env(inputs, outputs, features = [EquivTool], consistency_check = validate)
class FailureWatch:
# when passed to OpSubOptimizer or PatternOptimizer, counts the number of failures
def __init__(self):
self.failures = 0
def __call__(self, op1, op2, exception):
assert isinstance(exception, InconsistencyError)
self.failures += 1
class _test_all(unittest.TestCase): class _test_all(unittest.TestCase):
def test_0(self): def test_multi_destroyers(self):
x, y, z = inputs() x, y, z = inputs()
e = Add(AddInPlace(x, y), AddInPlace(x, y)) e = add(add_in_place(x, y), add_in_place(x, y))
try: try:
g = env([x,y,z], [e]) g = env([x,y,z], [e])
self.fail()
except InconsistencyError, e: except InconsistencyError, e:
pass pass
else:
raise Exception("Expected an InconsistencyError")
def test_1(self):
# the loop is needed because a2i will optimize in a random order and sometimes
# only one of them fails
for i in xrange(100):
x, y, z = inputs()
e = Add(Add(x, y), Add(y, x))
g = env([x,y,z], [e])
assert g.consistent()
a2i.optimize(g)
assert g.consistent()
assert str(g) != "[AddInPlace(AddInPlace(x, y), AddInPlace(y, x))]"
def test_2(self): def test_multi_destroyers_through_views(self):
x, y, z = inputs()
e = dot(add(transpose_view(z), y), add(z, x))
g = env([x,y,z], [e])
assert g.consistent()
fail = FailureWatch()
OpSubOptimizer(Add, AddInPlace, fail).optimize(g)
assert g.consistent()
assert fail.failures == 1 # should have succeeded once and failed once
def test_destroyers_loop(self):
# AddInPlace(x, y) and AddInPlace(y, x) should not coexist
x, y, z = inputs() x, y, z = inputs()
g = env([x,y,z], [Dot(AddInPlace(x, z), x)], False) e1 = add(x, y)
e2 = add(y, x)
g = env([x,y,z], [e1, e2])
chk = g.checkpoint()
assert g.consistent()
g.replace(e1, add_in_place(x, y))
assert g.consistent()
try:
g.replace(e2, add_in_place(y, x))
self.fail()
except InconsistencyError:
pass
assert g.consistent()
g.revert(chk)
g.replace(e2, add_in_place(y, x))
assert g.consistent()
try:
g.replace(e1, add_in_place(x, y))
self.fail()
except InconsistencyError:
pass
assert g.consistent()
def test_long_destroyers_loop(self):
x, y, z = inputs()
e = dot(dot(add_in_place(x,y), add_in_place(y,z)), add(z,x))
g = env([x,y,z], [e])
assert g.consistent()
OpSubOptimizer(Add, AddInPlace).optimize(g)
assert g.consistent()
assert str(g) != "[Dot(Dot(AddInPlace(x, y), AddInPlace(y, z)), AddInPlace(z, x))]" # we don't want to see that!
e2 = dot(dot(add_in_place(x,y), add_in_place(y,z)), add_in_place(z,x))
try:
g2 = env([x,y,z], [e2])
self.fail()
except InconsistencyError:
pass
def test_usage_loop(self):
x, y, z = inputs()
g = env([x,y,z], [dot(add_in_place(x, z), x)], False)
assert not g.consistent() assert not g.consistent()
i2a.optimize(g) OpSubOptimizer(AddInPlace, Add).optimize(g) # replace AddInPlace with Add
assert g.consistent() assert g.consistent()
def test_3(self):
for i in xrange(100):
x, y, z = inputs()
e = Dot(Add(TransposeView(z), y), Add(z, x))
g = env([x,y,z], [e])
assert g.consistent()
a2i.optimize(g)
assert g.consistent()
assert str(g) != "[Dot(AddInPlace(TransposeView(z), y), AddInPlace(z, x))]"
def test_4(self): def test_usage_loop_through_views(self):
x, y, z = inputs() x, y, z = inputs()
e = Dot(AddInPlace(x,y), TransposeView(x)) aip = add_in_place(x, y)
e = dot(aip, transpose_view(x))
g = env([x,y,z], [e], False) g = env([x,y,z], [e], False)
assert not g.consistent() assert not g.consistent()
g.replace(e.owner.inputs[1], Add(x,z)) g.replace(aip, add(x, z))
assert g.consistent() assert g.consistent()
def test_5(self): def test_usage_loop_through_views_2(self):
x, y, z = inputs() x, y, z = inputs()
e = Dot(AddInPlace(x,y), TransposeView(TransposeView(TransposeView(TransposeView(Sigmoid(x)))))) e0 = transpose_view(transpose_view(transpose_view(sigmoid(x))))
e = dot(add_in_place(x,y), transpose_view(e0))
g = env([x,y,z], [e]) g = env([x,y,z], [e])
assert g.consistent() assert g.consistent() # because sigmoid can do the copy
g.replace(e.owner.inputs[1].owner.inputs[0], x, False) g.replace(e0, x, False)
assert not g.consistent() assert not g.consistent() # we cut off the path to the sigmoid
def test_6(self): def test_usage_loop_insert_views(self):
for i in xrange(100): x, y, z = inputs()
x, y, z = inputs() e = dot(add_in_place(x, add(y, z)), sigmoid(sigmoid(sigmoid(sigmoid(sigmoid(x))))))
e = Dot(AddInPlace(x,Sigmoid(y)), Sigmoid(Sigmoid(Sigmoid(Sigmoid(Sigmoid(x)))))) g = env([x,y,z], [e])
g = env([x,y,z], [e]) assert g.consistent()
assert g.consistent() fail = FailureWatch()
s2t.optimize(g) OpSubOptimizer(Sigmoid, TransposeView, fail).optimize(g)
assert g.consistent() assert g.consistent()
assert str(g) != "[Dot(AddInPlace(x,TransposeView(y)), TransposeView(TransposeView(TransposeView(TransposeView(TransposeView(x))))))]" assert fail.failures == 1 # it must keep one sigmoid in the long sigmoid chain
def test_7(self): def test_misc(self):
x, y, z = inputs() x, y, z = inputs()
e = TransposeView(TransposeView(TransposeView(TransposeView(x)))) e = transpose_view(transpose_view(transpose_view(transpose_view(x))))
g = env([x,y,z], [e]) g = env([x,y,z], [e])
assert g.consistent() assert g.consistent()
chk = g.checkpoint() chk = g.checkpoint()
dtv_elim.optimize(g) PatternOptimizer((TransposeView, (TransposeView, 'x')), 'x').optimize(g)
assert str(g) == "[x]" assert str(g) == "[x]"
g.replace(g.equiv(e), Add(x,y)) g.replace(g.equiv(e), add(x,y))
assert str(g) == "[Add(x, y)]" assert str(g) == "[Add(x, y)]"
g.replace(g.equiv(e), Dot(AddInPlace(x,y), TransposeView(x)), False) g.replace(g.equiv(e), dot(add_in_place(x,y), transpose_view(x)), False)
assert str(g) == "[Dot(AddInPlace(x, y), TransposeView(x))]" assert str(g) == "[Dot(AddInPlace(x, y), TransposeView(x))]"
assert not g.consistent() assert not g.consistent()
g.revert(chk) g.revert(chk)
assert g.consistent() assert g.consistent()
assert str(g) == "[TransposeView(TransposeView(TransposeView(TransposeView(x))))]" assert str(g) == "[TransposeView(TransposeView(TransposeView(TransposeView(x))))]"
def test_8(self): def test_indestructible(self):
x, y, z = inputs()
e = Dot(Dot(AddInPlace(x,y), AddInPlace(y,z)), Add(z,x))
g = env([x,y,z], [e])
assert g.consistent()
a2i.optimize(g)
assert g.consistent()
assert str(g) != "[Dot(Dot(AddInPlace(x, y), AddInPlace(y, z)), AddInPlace(z, x))]" # we don't want to see that!
def test_9(self):
x, y, z = inputs() x, y, z = inputs()
x.indestructible = True x.indestructible = True
e = AddInPlace(x, y) e = add_in_place(x, y)
g = env([x,y,z], [e], False) g = env([x,y,z], [e], False)
assert not g.consistent() assert not g.consistent()
g.replace(e, Add(x, y)) g.replace(e, add(x, y))
assert g.consistent() assert g.consistent()
def test_10(self): def test_indestructible_through_views(self):
x, y, z = inputs() x, y, z = inputs()
x.indestructible = True x.indestructible = True
tv = TransposeView(x) tv = transpose_view(x)
e = AddInPlace(tv, y) e = add_in_place(tv, y)
g = env([x,y,z], [e], False) g = env([x,y,z], [e], False)
assert not g.consistent() assert not g.consistent()
g.replace(tv, Sigmoid(x)) g.replace(tv, sigmoid(x))
assert g.consistent() assert g.consistent()
def test_11(self): def test_repair_destroy_path(self):
x, y, z = inputs() x, y, z = inputs()
e1 = TransposeView(TransposeView(x)) e1 = transpose_view(transpose_view(x))
e2 = TransposeView(TransposeView(e1)) e2 = transpose_view(transpose_view(e1))
e3 = AddInPlace(e2, y) e3 = add_in_place(e2, y)
e4 = AddInPlace(e1, z) e4 = add_in_place(e1, z)
g = env([x,y,z], [e3, e4], False) g = env([x,y,z], [e3, e4], False)
assert not g.consistent() assert not g.consistent()
g.replace(e2, TransposeView(x), False) g.replace(e2, transpose_view(x), False)
assert not g.consistent() assert not g.consistent()
def test_12(self): def test_indirect(self):
x, y, z = inputs() x, y, z = inputs()
e0 = AddInPlace(x, y) e0 = add_in_place(x, y)
e = Dot(Sigmoid(e0), TransposeView(x)) e = dot(sigmoid(e0), transpose_view(x))
g = env([x,y,z], [e], False) g = env([x,y,z], [e], False)
assert not g.consistent() assert not g.consistent()
new_e0 = Add(x, y) new_e0 = add(x, y)
g.replace(e0, new_e0, False) g.replace(e0, new_e0, False)
assert g.consistent() assert g.consistent()
g.replace(new_e0, AddInPlace(x, y), False) g.replace(new_e0, add_in_place(x, y), False)
assert not g.consistent() assert not g.consistent()
def test_13(self): def test_indirect_2(self):
x, y, z = inputs() x, y, z = inputs()
e0 = TransposeView(x) e0 = transpose_view(x)
e = Dot(Sigmoid(AddInPlace(x, y)), e0) e = dot(sigmoid(add_in_place(x, y)), e0)
g = env([x,y,z], [e], False) g = env([x,y,z], [e], False)
assert not g.consistent() assert not g.consistent()
new_e0 = Add(e0, y) new_e0 = add(e0, y)
g.replace(e0, new_e0, False) g.replace(e0, new_e0, False)
assert g.consistent() assert g.consistent()
......
...@@ -39,37 +39,38 @@ class MyOp(Op): ...@@ -39,37 +39,38 @@ class MyOp(Op):
class _test_inputs(unittest.TestCase): class _test_inputs(unittest.TestCase):
def test_0(self): def test_straightforward(self):
r1, r2 = MyResult(1), MyResult(2) r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2) op = MyOp(r1, r2)
assert inputs(op.outputs) == set([r1, r2]) assert inputs(op.outputs) == set([r1, r2])
def test_1(self): def test_deep(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5) op2 = MyOp(op.outputs[0], r5)
assert inputs(op2.outputs) == set([r1, r2, r5]) assert inputs(op2.outputs) == set([r1, r2, r5])
def test_unreached_inputs(self):
class _test_orphans(unittest.TestCase):
def test_0(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5)
assert orphans([r1, r2], op2.outputs) == set([r5])
def test_1(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5) op2 = MyOp(op.outputs[0], r5)
try: try:
# function doesn't raise if we put False instead of True
ro = results_and_orphans([r1, r2, op2.outputs[0]], op.outputs, True) ro = results_and_orphans([r1, r2, op2.outputs[0]], op.outputs, True)
self.fail() self.fail()
except Exception, e: except Exception, e:
if e[0] is results_and_orphans.E_unreached: if e[0] is results_and_orphans.E_unreached:
return return
raise raise
class _test_orphans(unittest.TestCase):
def test_straightforward(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5)
assert orphans([r1, r2], op2.outputs) == set([r5])
class _test_as_string(unittest.TestCase): class _test_as_string(unittest.TestCase):
...@@ -78,24 +79,24 @@ class _test_as_string(unittest.TestCase): ...@@ -78,24 +79,24 @@ class _test_as_string(unittest.TestCase):
node_formatter = lambda op, argstrings: "%s(%s)" % (op.__class__.__name__, node_formatter = lambda op, argstrings: "%s(%s)" % (op.__class__.__name__,
", ".join(argstrings)) ", ".join(argstrings))
def test_0(self): def test_straightforward(self):
r1, r2 = MyResult(1), MyResult(2) r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2) op = MyOp(r1, r2)
assert as_string([r1, r2], op.outputs) == ["MyOp(1, 2)"] assert as_string([r1, r2], op.outputs) == ["MyOp(1, 2)"]
def test_1(self): def test_deep(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5) op2 = MyOp(op.outputs[0], r5)
assert as_string([r1, r2, r5], op2.outputs) == ["MyOp(MyOp(1, 2), 5)"] assert as_string([r1, r2, r5], op2.outputs) == ["MyOp(MyOp(1, 2), 5)"]
def test_2(self): def test_multiple_references(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], op.outputs[0]) op2 = MyOp(op.outputs[0], op.outputs[0])
assert as_string([r1, r2, r5], op2.outputs) == ["MyOp(*1 -> MyOp(1, 2), *1)"] assert as_string([r1, r2, r5], op2.outputs) == ["MyOp(*1 -> MyOp(1, 2), *1)"]
def test_3(self): def test_cutoff(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], op.outputs[0]) op2 = MyOp(op.outputs[0], op.outputs[0])
...@@ -105,24 +106,24 @@ class _test_as_string(unittest.TestCase): ...@@ -105,24 +106,24 @@ class _test_as_string(unittest.TestCase):
class _test_clone(unittest.TestCase): class _test_clone(unittest.TestCase):
def test_0(self): def test_accurate(self):
r1, r2 = MyResult(1), MyResult(2) r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2) op = MyOp(r1, r2)
new = clone([r1, r2], op.outputs) new = clone([r1, r2], op.outputs)
assert as_string([r1, r2], new) == ["MyOp(1, 2)"] assert as_string([r1, r2], new) == ["MyOp(1, 2)"]
def test_1(self): def test_copy(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5) op2 = MyOp(op.outputs[0], r5)
new = clone([r1, r2, r5], op2.outputs) new = clone([r1, r2, r5], op2.outputs)
assert op2.outputs[0] == new[0] and op2.outputs[0] is not new[0] assert op2.outputs[0] == new[0] and op2.outputs[0] is not new[0] # the new output is like the old one but not the same object
assert op2 is not new[0].owner assert op2 is not new[0].owner # the new output has a new owner
assert new[0].owner.inputs[1] is r5 assert new[0].owner.inputs[1] is r5 # the inputs are not copied
assert new[0].owner.inputs[0] == op.outputs[0] and new[0].owner.inputs[0] is not op.outputs[0] assert new[0].owner.inputs[0] == op.outputs[0] and new[0].owner.inputs[0] is not op.outputs[0] # check that we copied deeper too
def test_2(self): def test_not_destructive(self):
"Checks that manipulating a cloned graph leaves the original unchanged." # Checks that manipulating a cloned graph leaves the original unchanged.
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(MyOp(r1, r2).outputs[0], r5) op = MyOp(MyOp(r1, r2).outputs[0], r5)
new = clone([r1, r2, r5], op.outputs) new = clone([r1, r2, r5], op.outputs)
......
...@@ -64,8 +64,6 @@ def inputs(): ...@@ -64,8 +64,6 @@ def inputs():
return x, y, z return x, y, z
def env(inputs, outputs, validate = True, features = []): def env(inputs, outputs, validate = True, features = []):
# inputs = [input.r for input in inputs]
# outputs = [output.r for output in outputs]
return Env(inputs, outputs, features = features, consistency_check = validate) return Env(inputs, outputs, features = features, consistency_check = validate)
def perform_linker(env): def perform_linker(env):
...@@ -75,26 +73,27 @@ def perform_linker(env): ...@@ -75,26 +73,27 @@ def perform_linker(env):
class _test_PerformLinker(unittest.TestCase): class _test_PerformLinker(unittest.TestCase):
def test_0(self): def test_thunk_inplace(self):
x, y, z = inputs() x, y, z = inputs()
e = mul(add(x, y), div(x, y)) e = mul(add(x, y), div(x, y))
fn, i, o = perform_linker(env([x, y, z], [e])).make_thunk(True) fn, i, o = perform_linker(env([x, y, z], [e])).make_thunk(True)
fn() fn()
assert e.data == 1.5 assert e.data == 1.5
def test_1(self): def test_thunk_not_inplace(self):
x, y, z = inputs() x, y, z = inputs()
e = mul(add(x, y), div(x, y)) e = mul(add(x, y), div(x, y))
fn, i, o = perform_linker(env([x, y, z], [e])).make_thunk(False) fn, i, o = perform_linker(env([x, y, z], [e])).make_thunk(False)
fn() fn()
assert o[0].data == 1.5
assert e.data != 1.5 assert e.data != 1.5
def test_2(self): def test_function(self):
x, y, z = inputs() x, y, z = inputs()
e = mul(add(x, y), div(x, y)) e = mul(add(x, y), div(x, y))
fn = perform_linker(env([x, y, z], [e])).make_function() fn = perform_linker(env([x, y, z], [e])).make_function()
assert fn(1.0, 2.0, 3.0) == 1.5 assert fn(1.0, 2.0, 3.0) == 1.5
assert e.data != 1.5 assert e.data != 1.5 # not inplace
def test_input_output_same(self): def test_input_output_same(self):
x, y, z = inputs() x, y, z = inputs()
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import unittest import unittest
from copy import copy from copy import copy
from op import * from op import *
from result import ResultBase #, BrokenLinkError from result import ResultBase
class MyResult(ResultBase): class MyResult(ResultBase):
...@@ -34,27 +34,6 @@ class MyOp(Op): ...@@ -34,27 +34,6 @@ class MyOp(Op):
self.inputs = inputs self.inputs = inputs
self.outputs = [MyResult(sum([input.thingy for input in inputs]))] self.outputs = [MyResult(sum([input.thingy for input in inputs]))]
# def validate_update(self):
# for input in self.inputs:
# if not isinstance(input, MyResult):
# raise Exception("Error 1")
# if self.outputs is None:
# self.outputs = [MyResult(sum([input.thingy for input in self.inputs]))]
# return True
# else:
# old_thingy = self.outputs[0].thingy
# new_thingy = sum([input.thingy for input in self.inputs])
# self.outputs[0].thingy = new_thingy
# return old_thingy != new_thingy
# class MyOp(Op):
# def validate_update(self):
# for input in self.inputs:
# if not isinstance(input, MyResult):
# raise Exception("Error 1")
# self.outputs = [MyResult(sum([input.thingy for input in self.inputs]))]
class _test_Op(unittest.TestCase): class _test_Op(unittest.TestCase):
...@@ -75,100 +54,6 @@ class _test_Op(unittest.TestCase): ...@@ -75,100 +54,6 @@ class _test_Op(unittest.TestCase):
else: else:
raise Exception("Expected an exception") raise Exception("Expected an exception")
# # Setting inputs and outputs
# def test_set_inputs(self):
# r1, r2 = MyResult(1), MyResult(2)
# op = MyOp(r1, r2)
# r3 = op.outputs[0]
# op.inputs = MyResult(4), MyResult(5)
# op.validate_update()
# assert op.outputs == [MyResult(9)] # check if the output changed to what I expect
# # assert r3.data is op.outputs[0].data # check if the data was properly transferred by set_output
# def test_set_bad_inputs(self):
# op = MyOp(MyResult(1), MyResult(2))
# try:
# op.inputs = MyResult(4), ResultBase()
# op.validate_update()
# except Exception, e:
# assert str(e) == "Error 1"
# else:
# raise Exception("Expected an exception")
# def test_set_outputs(self):
# r1, r2 = MyResult(1), MyResult(2)
# op = MyOp(r1, r2) # here we only make one output
# try:
# op.outputs = MyResult(10), MyResult(11) # setting two outputs should fail
# except TypeError, e:
# assert str(e) == "The new outputs must be exactly as many as the previous outputs."
# else:
# raise Exception("Expected an exception")
# # Tests about broken links
# def test_create_broken_link(self):
# r1, r2 = MyResult(1), MyResult(2)
# op = MyOp(r1, r2)
# r3 = op.out
# op.inputs = MyResult(3), MyResult(4)
# assert r3 not in op.outputs
# assert r3.replaced
# def test_cannot_copy_when_input_is_broken_link(self):
# r1, r2 = MyResult(1), MyResult(2)
# op = MyOp(r1, r2)
# r3 = op.out
# op2 = MyOp(r3)
# op.inputs = MyResult(3), MyResult(4)
# try:
# copy(op2)
# except BrokenLinkError:
# pass
# else:
# raise Exception("Expected an exception")
# def test_get_input_broken_link(self):
# r1, r2 = MyResult(1), MyResult(2)
# op = MyOp(r1, r2)
# r3 = op.out
# op2 = MyOp(r3)
# op.inputs = MyResult(3), MyResult(4)
# try:
# op2.get_input(0)
# except BrokenLinkError:
# pass
# else:
# raise Exception("Expected an exception")
# def test_get_inputs_broken_link(self):
# r1, r2 = MyResult(1), MyResult(2)
# op = MyOp(r1, r2)
# r3 = op.out
# op2 = MyOp(r3)
# op.inputs = MyResult(3), MyResult(4)
# try:
# op2.get_inputs()
# except BrokenLinkError:
# pass
# else:
# raise Exception("Expected an exception")
# def test_repair_broken_link(self):
# r1, r2 = MyResult(1), MyResult(2)
# op = MyOp(r1, r2)
# r3 = op.out
# op2 = MyOp(r3, MyResult(10))
# op.inputs = MyResult(3), MyResult(4)
# op2.repair()
# assert op2.outputs == [MyResult(17)]
# # Tests about string representation
# def test_create_broken_link(self):
# r1, r2 = MyResult(1), MyResult(2)
# op = MyOp(r1, r2)
# assert str(op) == "MyOp(1, 2)"
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -260,7 +260,7 @@ class _test_MergeOptimizer(unittest.TestCase): ...@@ -260,7 +260,7 @@ class _test_MergeOptimizer(unittest.TestCase):
class _test_ConstantFinder(unittest.TestCase): class _test_ConstantFinder(unittest.TestCase):
def test_0(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
y.data = 2 y.data = 2
z.data = 2 z.data = 2
...@@ -272,7 +272,7 @@ class _test_ConstantFinder(unittest.TestCase): ...@@ -272,7 +272,7 @@ class _test_ConstantFinder(unittest.TestCase):
assert str(g) == "[Op1(x, y, y)]" \ assert str(g) == "[Op1(x, y, y)]" \
or str(g) == "[Op1(x, z, z)]" or str(g) == "[Op1(x, z, z)]"
def test_1(self): def test_deep(self):
x, y, z = inputs() x, y, z = inputs()
y.data = 2 y.data = 2
z.data = 2 z.data = 2
...@@ -284,11 +284,11 @@ class _test_ConstantFinder(unittest.TestCase): ...@@ -284,11 +284,11 @@ class _test_ConstantFinder(unittest.TestCase):
assert str(g) == "[Op1(*1 -> Op2(x, y), *1, *1)]" \ assert str(g) == "[Op1(*1 -> Op2(x, y), *1, *1)]" \
or str(g) == "[Op1(*1 -> Op2(x, z), *1, *1)]" or str(g) == "[Op1(*1 -> Op2(x, z), *1, *1)]"
def test_2(self): def test_destroyed_orphan_not_constant(self):
x, y, z = inputs() x, y, z = inputs()
y.data = 2 y.data = 2
z.data = 2 z.data = 2
e = op_d(x, op2(y, z)) e = op_d(x, op2(y, z)) # here x is destroyed by op_d
g = env([y], [e]) g = env([y], [e])
ConstantFinder().optimize(g) ConstantFinder().optimize(g)
assert not getattr(x, 'constant', False) and z.constant assert not getattr(x, 'constant', False) and z.constant
......
...@@ -36,9 +36,10 @@ class MyResult(ResultBase): ...@@ -36,9 +36,10 @@ class MyResult(ResultBase):
class _test_ResultBase(unittest.TestCase): class _test_ResultBase(unittest.TestCase):
def test_0(self): def test_trivial(self):
r = ResultBase() r = ResultBase()
def test_1(self):
def test_state(self):
r = ResultBase() r = ResultBase()
assert r.state is Empty assert r.state is Empty
......
...@@ -54,14 +54,12 @@ def inputs(): ...@@ -54,14 +54,12 @@ def inputs():
return x, y, z return x, y, z
def env(inputs, outputs, validate = True, features = []): def env(inputs, outputs, validate = True, features = []):
# inputs = [input.r for input in inputs]
# outputs = [output.r for output in outputs]
return Env(inputs, outputs, features = features, consistency_check = validate) return Env(inputs, outputs, features = features, consistency_check = validate)
class _test_EquivTool(unittest.TestCase): class _test_EquivTool(unittest.TestCase):
def test_0(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
sx = sigmoid(x) sx = sigmoid(x)
e = add(sx, sigmoid(y)) e = add(sx, sigmoid(y))
...@@ -72,6 +70,35 @@ class _test_EquivTool(unittest.TestCase): ...@@ -72,6 +70,35 @@ class _test_EquivTool(unittest.TestCase):
assert isinstance(g.equiv(sx).owner, Dot) assert isinstance(g.equiv(sx).owner, Dot)
class _test_InstanceFinder(unittest.TestCase):
def test_straightforward(self):
x, y, z = inputs()
e0 = dot(y, z)
e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), e0))
g = env([x, y, z], [e], features = [InstanceFinder])
for type, num in ((Add, 3), (Sigmoid, 3), (Dot, 2)):
if not len([x for x in g.get_instances_of(type)]) == num:
self.fail((type, num))
new_e0 = add(y, z)
assert e0.owner in g.get_instances_of(Dot)
assert new_e0.owner not in g.get_instances_of(Add)
g.replace(e0, new_e0)
assert e0.owner not in g.get_instances_of(Dot)
assert new_e0.owner in g.get_instances_of(Add)
for type, num in ((Add, 4), (Sigmoid, 3), (Dot, 1)):
if not len([x for x in g.get_instances_of(type)]) == num:
self.fail((type, num))
def test_robustness(self):
x, y, z = inputs()
e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), dot(y, z)))
g = env([x, y, z], [e], features = [InstanceFinder])
gen = g.get_instances_of(Sigmoid) # I want to get Sigmoid instances
g.replace(e, add(x, y)) # but here I prune them all
assert len([x for x in gen]) == 0 # the generator should not yield them
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
...@@ -512,7 +512,8 @@ class CLinker(Linker): ...@@ -512,7 +512,8 @@ class CLinker(Linker):
# List of indices that should be ignored when passing the arguments # List of indices that should be ignored when passing the arguments
# (basically, everything that the previous call to uniq eliminated) # (basically, everything that the previous call to uniq eliminated)
self.dupidx = [i for i, x in enumerate(all) if all.count(x) > 1 and all.index(x) != i] self.dupidx = [i for i, x in enumerate(all) if all.count(x) > 1 and all.index(x) != i]
return self.struct_code
def find_task(self, failure_code): def find_task(self, failure_code):
""" """
Maps a failure code to the task that is associated to it. Maps a failure code to the task that is associated to it.
......
...@@ -113,25 +113,26 @@ class PrintListener(Listener): ...@@ -113,25 +113,26 @@ class PrintListener(Listener):
print "-- moving from %s to %s" % (r, new_r) print "-- moving from %s to %s" % (r, new_r)
### UNUSED AND UNTESTED ###
class ChangeListener(Listener): # class ChangeListener(Listener):
def __init__(self, env): # def __init__(self, env):
self.change = False # self.change = False
def on_import(self, op): # def on_import(self, op):
self.change = True # self.change = True
def on_prune(self, op): # def on_prune(self, op):
self.change = True # self.change = True
def on_rewire(self, clients, r, new_r): # def on_rewire(self, clients, r, new_r):
self.change = True # self.change = True
def __call__(self, value = "get"): # def __call__(self, value = "get"):
if value == "get": # if value == "get":
return self.change # return self.change
else: # else:
self.change = value # self.change = value
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论