提交 bcb1a525 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

documented and added a lot of tests

上级 e2db78da
...@@ -169,6 +169,14 @@ class _test_CLinker(unittest.TestCase): ...@@ -169,6 +169,14 @@ class _test_CLinker(unittest.TestCase):
self.failUnless(fn(2.0, 2.0) == 4) self.failUnless(fn(2.0, 2.0) == 4)
# note: for now the behavior of fn(2.0, 7.0) is undefined # note: for now the behavior of fn(2.0, 7.0) is undefined
def test_dups_inner(self):
# Testing that duplicates are allowed inside the graph
x, y, z = inputs()
e = add(mul(y, y), add(x, z))
lnk = CLinker(env([x, y, z], [e]))
fn = lnk.make_function()
self.failUnless(fn(1.0, 2.0, 3.0) == 8.0)
class _test_OpWiseCLinker(unittest.TestCase): class _test_OpWiseCLinker(unittest.TestCase):
...@@ -180,9 +188,4 @@ class _test_OpWiseCLinker(unittest.TestCase): ...@@ -180,9 +188,4 @@ class _test_OpWiseCLinker(unittest.TestCase):
self.failUnless(fn(2.0, 2.0, 2.0) == 2.0) self.failUnless(fn(2.0, 2.0, 2.0) == 2.0)
if __name__ == '__main__': if __name__ == '__main__':
# unittest.main() unittest.main()
x, y, z = inputs()
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = CLinker(env([x, y, z], [e]))
fn = lnk.make_function()
fn(2.0, 0.0, 2.0)
差异被折叠。
...@@ -39,31 +39,23 @@ class MyOp(Op): ...@@ -39,31 +39,23 @@ class MyOp(Op):
class _test_inputs(unittest.TestCase): class _test_inputs(unittest.TestCase):
def test_0(self): def test_straightforward(self):
r1, r2 = MyResult(1), MyResult(2) r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2) op = MyOp(r1, r2)
assert inputs(op.outputs) == set([r1, r2]) assert inputs(op.outputs) == set([r1, r2])
def test_1(self): def test_deep(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5) op2 = MyOp(op.outputs[0], r5)
assert inputs(op2.outputs) == set([r1, r2, r5]) assert inputs(op2.outputs) == set([r1, r2, r5])
def test_unreached_inputs(self):
class _test_orphans(unittest.TestCase):
def test_0(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5)
assert orphans([r1, r2], op2.outputs) == set([r5])
def test_1(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5) op2 = MyOp(op.outputs[0], r5)
try: try:
# function doesn't raise if we put False instead of True
ro = results_and_orphans([r1, r2, op2.outputs[0]], op.outputs, True) ro = results_and_orphans([r1, r2, op2.outputs[0]], op.outputs, True)
self.fail() self.fail()
except Exception, e: except Exception, e:
...@@ -72,30 +64,39 @@ class _test_orphans(unittest.TestCase): ...@@ -72,30 +64,39 @@ class _test_orphans(unittest.TestCase):
raise raise
class _test_orphans(unittest.TestCase):
def test_straightforward(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5)
assert orphans([r1, r2], op2.outputs) == set([r5])
class _test_as_string(unittest.TestCase): class _test_as_string(unittest.TestCase):
leaf_formatter = str leaf_formatter = str
node_formatter = lambda op, argstrings: "%s(%s)" % (op.__class__.__name__, node_formatter = lambda op, argstrings: "%s(%s)" % (op.__class__.__name__,
", ".join(argstrings)) ", ".join(argstrings))
def test_0(self): def test_straightforward(self):
r1, r2 = MyResult(1), MyResult(2) r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2) op = MyOp(r1, r2)
assert as_string([r1, r2], op.outputs) == ["MyOp(1, 2)"] assert as_string([r1, r2], op.outputs) == ["MyOp(1, 2)"]
def test_1(self): def test_deep(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5) op2 = MyOp(op.outputs[0], r5)
assert as_string([r1, r2, r5], op2.outputs) == ["MyOp(MyOp(1, 2), 5)"] assert as_string([r1, r2, r5], op2.outputs) == ["MyOp(MyOp(1, 2), 5)"]
def test_2(self): def test_multiple_references(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], op.outputs[0]) op2 = MyOp(op.outputs[0], op.outputs[0])
assert as_string([r1, r2, r5], op2.outputs) == ["MyOp(*1 -> MyOp(1, 2), *1)"] assert as_string([r1, r2, r5], op2.outputs) == ["MyOp(*1 -> MyOp(1, 2), *1)"]
def test_3(self): def test_cutoff(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], op.outputs[0]) op2 = MyOp(op.outputs[0], op.outputs[0])
...@@ -105,24 +106,24 @@ class _test_as_string(unittest.TestCase): ...@@ -105,24 +106,24 @@ class _test_as_string(unittest.TestCase):
class _test_clone(unittest.TestCase): class _test_clone(unittest.TestCase):
def test_0(self): def test_accurate(self):
r1, r2 = MyResult(1), MyResult(2) r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2) op = MyOp(r1, r2)
new = clone([r1, r2], op.outputs) new = clone([r1, r2], op.outputs)
assert as_string([r1, r2], new) == ["MyOp(1, 2)"] assert as_string([r1, r2], new) == ["MyOp(1, 2)"]
def test_1(self): def test_copy(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5) op2 = MyOp(op.outputs[0], r5)
new = clone([r1, r2, r5], op2.outputs) new = clone([r1, r2, r5], op2.outputs)
assert op2.outputs[0] == new[0] and op2.outputs[0] is not new[0] assert op2.outputs[0] == new[0] and op2.outputs[0] is not new[0] # the new output is like the old one but not the same object
assert op2 is not new[0].owner assert op2 is not new[0].owner # the new output has a new owner
assert new[0].owner.inputs[1] is r5 assert new[0].owner.inputs[1] is r5 # the inputs are not copied
assert new[0].owner.inputs[0] == op.outputs[0] and new[0].owner.inputs[0] is not op.outputs[0] assert new[0].owner.inputs[0] == op.outputs[0] and new[0].owner.inputs[0] is not op.outputs[0] # check that we copied deeper too
def test_2(self): def test_not_destructive(self):
"Checks that manipulating a cloned graph leaves the original unchanged." # Checks that manipulating a cloned graph leaves the original unchanged.
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(MyOp(r1, r2).outputs[0], r5) op = MyOp(MyOp(r1, r2).outputs[0], r5)
new = clone([r1, r2, r5], op.outputs) new = clone([r1, r2, r5], op.outputs)
......
...@@ -64,8 +64,6 @@ def inputs(): ...@@ -64,8 +64,6 @@ def inputs():
return x, y, z return x, y, z
def env(inputs, outputs, validate = True, features = []): def env(inputs, outputs, validate = True, features = []):
# inputs = [input.r for input in inputs]
# outputs = [output.r for output in outputs]
return Env(inputs, outputs, features = features, consistency_check = validate) return Env(inputs, outputs, features = features, consistency_check = validate)
def perform_linker(env): def perform_linker(env):
...@@ -75,26 +73,27 @@ def perform_linker(env): ...@@ -75,26 +73,27 @@ def perform_linker(env):
class _test_PerformLinker(unittest.TestCase): class _test_PerformLinker(unittest.TestCase):
def test_0(self): def test_thunk_inplace(self):
x, y, z = inputs() x, y, z = inputs()
e = mul(add(x, y), div(x, y)) e = mul(add(x, y), div(x, y))
fn, i, o = perform_linker(env([x, y, z], [e])).make_thunk(True) fn, i, o = perform_linker(env([x, y, z], [e])).make_thunk(True)
fn() fn()
assert e.data == 1.5 assert e.data == 1.5
def test_1(self): def test_thunk_not_inplace(self):
x, y, z = inputs() x, y, z = inputs()
e = mul(add(x, y), div(x, y)) e = mul(add(x, y), div(x, y))
fn, i, o = perform_linker(env([x, y, z], [e])).make_thunk(False) fn, i, o = perform_linker(env([x, y, z], [e])).make_thunk(False)
fn() fn()
assert o[0].data == 1.5
assert e.data != 1.5 assert e.data != 1.5
def test_2(self): def test_function(self):
x, y, z = inputs() x, y, z = inputs()
e = mul(add(x, y), div(x, y)) e = mul(add(x, y), div(x, y))
fn = perform_linker(env([x, y, z], [e])).make_function() fn = perform_linker(env([x, y, z], [e])).make_function()
assert fn(1.0, 2.0, 3.0) == 1.5 assert fn(1.0, 2.0, 3.0) == 1.5
assert e.data != 1.5 assert e.data != 1.5 # not inplace
def test_input_output_same(self): def test_input_output_same(self):
x, y, z = inputs() x, y, z = inputs()
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import unittest import unittest
from copy import copy from copy import copy
from op import * from op import *
from result import ResultBase #, BrokenLinkError from result import ResultBase
class MyResult(ResultBase): class MyResult(ResultBase):
...@@ -34,27 +34,6 @@ class MyOp(Op): ...@@ -34,27 +34,6 @@ class MyOp(Op):
self.inputs = inputs self.inputs = inputs
self.outputs = [MyResult(sum([input.thingy for input in inputs]))] self.outputs = [MyResult(sum([input.thingy for input in inputs]))]
# def validate_update(self):
# for input in self.inputs:
# if not isinstance(input, MyResult):
# raise Exception("Error 1")
# if self.outputs is None:
# self.outputs = [MyResult(sum([input.thingy for input in self.inputs]))]
# return True
# else:
# old_thingy = self.outputs[0].thingy
# new_thingy = sum([input.thingy for input in self.inputs])
# self.outputs[0].thingy = new_thingy
# return old_thingy != new_thingy
# class MyOp(Op):
# def validate_update(self):
# for input in self.inputs:
# if not isinstance(input, MyResult):
# raise Exception("Error 1")
# self.outputs = [MyResult(sum([input.thingy for input in self.inputs]))]
class _test_Op(unittest.TestCase): class _test_Op(unittest.TestCase):
...@@ -75,100 +54,6 @@ class _test_Op(unittest.TestCase): ...@@ -75,100 +54,6 @@ class _test_Op(unittest.TestCase):
else: else:
raise Exception("Expected an exception") raise Exception("Expected an exception")
# # Setting inputs and outputs
# def test_set_inputs(self):
# r1, r2 = MyResult(1), MyResult(2)
# op = MyOp(r1, r2)
# r3 = op.outputs[0]
# op.inputs = MyResult(4), MyResult(5)
# op.validate_update()
# assert op.outputs == [MyResult(9)] # check if the output changed to what I expect
# # assert r3.data is op.outputs[0].data # check if the data was properly transferred by set_output
# def test_set_bad_inputs(self):
# op = MyOp(MyResult(1), MyResult(2))
# try:
# op.inputs = MyResult(4), ResultBase()
# op.validate_update()
# except Exception, e:
# assert str(e) == "Error 1"
# else:
# raise Exception("Expected an exception")
# def test_set_outputs(self):
# r1, r2 = MyResult(1), MyResult(2)
# op = MyOp(r1, r2) # here we only make one output
# try:
# op.outputs = MyResult(10), MyResult(11) # setting two outputs should fail
# except TypeError, e:
# assert str(e) == "The new outputs must be exactly as many as the previous outputs."
# else:
# raise Exception("Expected an exception")
# # Tests about broken links
# def test_create_broken_link(self):
# r1, r2 = MyResult(1), MyResult(2)
# op = MyOp(r1, r2)
# r3 = op.out
# op.inputs = MyResult(3), MyResult(4)
# assert r3 not in op.outputs
# assert r3.replaced
# def test_cannot_copy_when_input_is_broken_link(self):
# r1, r2 = MyResult(1), MyResult(2)
# op = MyOp(r1, r2)
# r3 = op.out
# op2 = MyOp(r3)
# op.inputs = MyResult(3), MyResult(4)
# try:
# copy(op2)
# except BrokenLinkError:
# pass
# else:
# raise Exception("Expected an exception")
# def test_get_input_broken_link(self):
# r1, r2 = MyResult(1), MyResult(2)
# op = MyOp(r1, r2)
# r3 = op.out
# op2 = MyOp(r3)
# op.inputs = MyResult(3), MyResult(4)
# try:
# op2.get_input(0)
# except BrokenLinkError:
# pass
# else:
# raise Exception("Expected an exception")
# def test_get_inputs_broken_link(self):
# r1, r2 = MyResult(1), MyResult(2)
# op = MyOp(r1, r2)
# r3 = op.out
# op2 = MyOp(r3)
# op.inputs = MyResult(3), MyResult(4)
# try:
# op2.get_inputs()
# except BrokenLinkError:
# pass
# else:
# raise Exception("Expected an exception")
# def test_repair_broken_link(self):
# r1, r2 = MyResult(1), MyResult(2)
# op = MyOp(r1, r2)
# r3 = op.out
# op2 = MyOp(r3, MyResult(10))
# op.inputs = MyResult(3), MyResult(4)
# op2.repair()
# assert op2.outputs == [MyResult(17)]
# # Tests about string representation
# def test_create_broken_link(self):
# r1, r2 = MyResult(1), MyResult(2)
# op = MyOp(r1, r2)
# assert str(op) == "MyOp(1, 2)"
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -260,7 +260,7 @@ class _test_MergeOptimizer(unittest.TestCase): ...@@ -260,7 +260,7 @@ class _test_MergeOptimizer(unittest.TestCase):
class _test_ConstantFinder(unittest.TestCase): class _test_ConstantFinder(unittest.TestCase):
def test_0(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
y.data = 2 y.data = 2
z.data = 2 z.data = 2
...@@ -272,7 +272,7 @@ class _test_ConstantFinder(unittest.TestCase): ...@@ -272,7 +272,7 @@ class _test_ConstantFinder(unittest.TestCase):
assert str(g) == "[Op1(x, y, y)]" \ assert str(g) == "[Op1(x, y, y)]" \
or str(g) == "[Op1(x, z, z)]" or str(g) == "[Op1(x, z, z)]"
def test_1(self): def test_deep(self):
x, y, z = inputs() x, y, z = inputs()
y.data = 2 y.data = 2
z.data = 2 z.data = 2
...@@ -284,11 +284,11 @@ class _test_ConstantFinder(unittest.TestCase): ...@@ -284,11 +284,11 @@ class _test_ConstantFinder(unittest.TestCase):
assert str(g) == "[Op1(*1 -> Op2(x, y), *1, *1)]" \ assert str(g) == "[Op1(*1 -> Op2(x, y), *1, *1)]" \
or str(g) == "[Op1(*1 -> Op2(x, z), *1, *1)]" or str(g) == "[Op1(*1 -> Op2(x, z), *1, *1)]"
def test_2(self): def test_destroyed_orphan_not_constant(self):
x, y, z = inputs() x, y, z = inputs()
y.data = 2 y.data = 2
z.data = 2 z.data = 2
e = op_d(x, op2(y, z)) e = op_d(x, op2(y, z)) # here x is destroyed by op_d
g = env([y], [e]) g = env([y], [e])
ConstantFinder().optimize(g) ConstantFinder().optimize(g)
assert not getattr(x, 'constant', False) and z.constant assert not getattr(x, 'constant', False) and z.constant
......
...@@ -36,9 +36,10 @@ class MyResult(ResultBase): ...@@ -36,9 +36,10 @@ class MyResult(ResultBase):
class _test_ResultBase(unittest.TestCase): class _test_ResultBase(unittest.TestCase):
def test_0(self): def test_trivial(self):
r = ResultBase() r = ResultBase()
def test_1(self):
def test_state(self):
r = ResultBase() r = ResultBase()
assert r.state is Empty assert r.state is Empty
......
...@@ -54,14 +54,12 @@ def inputs(): ...@@ -54,14 +54,12 @@ def inputs():
return x, y, z return x, y, z
def env(inputs, outputs, validate = True, features = []): def env(inputs, outputs, validate = True, features = []):
# inputs = [input.r for input in inputs]
# outputs = [output.r for output in outputs]
return Env(inputs, outputs, features = features, consistency_check = validate) return Env(inputs, outputs, features = features, consistency_check = validate)
class _test_EquivTool(unittest.TestCase): class _test_EquivTool(unittest.TestCase):
def test_0(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
sx = sigmoid(x) sx = sigmoid(x)
e = add(sx, sigmoid(y)) e = add(sx, sigmoid(y))
...@@ -72,6 +70,35 @@ class _test_EquivTool(unittest.TestCase): ...@@ -72,6 +70,35 @@ class _test_EquivTool(unittest.TestCase):
assert isinstance(g.equiv(sx).owner, Dot) assert isinstance(g.equiv(sx).owner, Dot)
class _test_InstanceFinder(unittest.TestCase):
def test_straightforward(self):
x, y, z = inputs()
e0 = dot(y, z)
e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), e0))
g = env([x, y, z], [e], features = [InstanceFinder])
for type, num in ((Add, 3), (Sigmoid, 3), (Dot, 2)):
if not len([x for x in g.get_instances_of(type)]) == num:
self.fail((type, num))
new_e0 = add(y, z)
assert e0.owner in g.get_instances_of(Dot)
assert new_e0.owner not in g.get_instances_of(Add)
g.replace(e0, new_e0)
assert e0.owner not in g.get_instances_of(Dot)
assert new_e0.owner in g.get_instances_of(Add)
for type, num in ((Add, 4), (Sigmoid, 3), (Dot, 1)):
if not len([x for x in g.get_instances_of(type)]) == num:
self.fail((type, num))
def test_robustness(self):
x, y, z = inputs()
e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), dot(y, z)))
g = env([x, y, z], [e], features = [InstanceFinder])
gen = g.get_instances_of(Sigmoid) # I want to get Sigmoid instances
g.replace(e, add(x, y)) # but here I prune them all
assert len([x for x in gen]) == 0 # the generator should not yield them
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
...@@ -512,6 +512,7 @@ class CLinker(Linker): ...@@ -512,6 +512,7 @@ class CLinker(Linker):
# List of indices that should be ignored when passing the arguments # List of indices that should be ignored when passing the arguments
# (basically, everything that the previous call to uniq eliminated) # (basically, everything that the previous call to uniq eliminated)
self.dupidx = [i for i, x in enumerate(all) if all.count(x) > 1 and all.index(x) != i] self.dupidx = [i for i, x in enumerate(all) if all.count(x) > 1 and all.index(x) != i]
return self.struct_code
def find_task(self, failure_code): def find_task(self, failure_code):
""" """
......
...@@ -113,25 +113,26 @@ class PrintListener(Listener): ...@@ -113,25 +113,26 @@ class PrintListener(Listener):
print "-- moving from %s to %s" % (r, new_r) print "-- moving from %s to %s" % (r, new_r)
### UNUSED AND UNTESTED ###
class ChangeListener(Listener): # class ChangeListener(Listener):
def __init__(self, env): # def __init__(self, env):
self.change = False # self.change = False
def on_import(self, op): # def on_import(self, op):
self.change = True # self.change = True
def on_prune(self, op): # def on_prune(self, op):
self.change = True # self.change = True
def on_rewire(self, clients, r, new_r): # def on_rewire(self, clients, r, new_r):
self.change = True # self.change = True
def __call__(self, value = "get"): # def __call__(self, value = "get"):
if value == "get": # if value == "get":
return self.change # return self.change
else: # else:
self.change = value # self.change = value
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论