提交 bcb1a525 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

documented and added a lot of tests

上级 e2db78da
......@@ -169,6 +169,14 @@ class _test_CLinker(unittest.TestCase):
self.failUnless(fn(2.0, 2.0) == 4)
# note: for now the behavior of fn(2.0, 7.0) is undefined
def test_dups_inner(self):
# Testing that duplicates are allowed inside the graph
x, y, z = inputs()
e = add(mul(y, y), add(x, z))
lnk = CLinker(env([x, y, z], [e]))
fn = lnk.make_function()
self.failUnless(fn(1.0, 2.0, 3.0) == 8.0)
class _test_OpWiseCLinker(unittest.TestCase):
......@@ -180,9 +188,4 @@ class _test_OpWiseCLinker(unittest.TestCase):
self.failUnless(fn(2.0, 2.0, 2.0) == 2.0)
if __name__ == '__main__':
# unittest.main()
x, y, z = inputs()
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = CLinker(env([x, y, z], [e]))
fn = lnk.make_function()
fn(2.0, 0.0, 2.0)
unittest.main()
差异被折叠。
......@@ -39,37 +39,38 @@ class MyOp(Op):
class _test_inputs(unittest.TestCase):
def test_0(self):
def test_straightforward(self):
r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2)
assert inputs(op.outputs) == set([r1, r2])
def test_1(self):
def test_deep(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5)
assert inputs(op2.outputs) == set([r1, r2, r5])
class _test_orphans(unittest.TestCase):
def test_0(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5)
assert orphans([r1, r2], op2.outputs) == set([r5])
def test_1(self):
def test_unreached_inputs(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5)
try:
# function doesn't raise if we put False instead of True
ro = results_and_orphans([r1, r2, op2.outputs[0]], op.outputs, True)
self.fail()
except Exception, e:
if e[0] is results_and_orphans.E_unreached:
return
raise
class _test_orphans(unittest.TestCase):
def test_straightforward(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5)
assert orphans([r1, r2], op2.outputs) == set([r5])
class _test_as_string(unittest.TestCase):
......@@ -78,24 +79,24 @@ class _test_as_string(unittest.TestCase):
node_formatter = lambda op, argstrings: "%s(%s)" % (op.__class__.__name__,
", ".join(argstrings))
def test_0(self):
def test_straightforward(self):
r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2)
assert as_string([r1, r2], op.outputs) == ["MyOp(1, 2)"]
def test_1(self):
def test_deep(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5)
assert as_string([r1, r2, r5], op2.outputs) == ["MyOp(MyOp(1, 2), 5)"]
def test_2(self):
def test_multiple_references(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], op.outputs[0])
assert as_string([r1, r2, r5], op2.outputs) == ["MyOp(*1 -> MyOp(1, 2), *1)"]
def test_3(self):
def test_cutoff(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], op.outputs[0])
......@@ -105,24 +106,24 @@ class _test_as_string(unittest.TestCase):
class _test_clone(unittest.TestCase):
def test_0(self):
def test_accurate(self):
r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2)
new = clone([r1, r2], op.outputs)
assert as_string([r1, r2], new) == ["MyOp(1, 2)"]
def test_1(self):
def test_copy(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5)
new = clone([r1, r2, r5], op2.outputs)
assert op2.outputs[0] == new[0] and op2.outputs[0] is not new[0]
assert op2 is not new[0].owner
assert new[0].owner.inputs[1] is r5
assert new[0].owner.inputs[0] == op.outputs[0] and new[0].owner.inputs[0] is not op.outputs[0]
assert op2.outputs[0] == new[0] and op2.outputs[0] is not new[0] # the new output is like the old one but not the same object
assert op2 is not new[0].owner # the new output has a new owner
assert new[0].owner.inputs[1] is r5 # the inputs are not copied
assert new[0].owner.inputs[0] == op.outputs[0] and new[0].owner.inputs[0] is not op.outputs[0] # check that we copied deeper too
def test_2(self):
"Checks that manipulating a cloned graph leaves the original unchanged."
def test_not_destructive(self):
# Checks that manipulating a cloned graph leaves the original unchanged.
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(MyOp(r1, r2).outputs[0], r5)
new = clone([r1, r2, r5], op.outputs)
......
......@@ -64,8 +64,6 @@ def inputs():
return x, y, z
def env(inputs, outputs, validate = True, features = []):
# inputs = [input.r for input in inputs]
# outputs = [output.r for output in outputs]
return Env(inputs, outputs, features = features, consistency_check = validate)
def perform_linker(env):
......@@ -75,26 +73,27 @@ def perform_linker(env):
class _test_PerformLinker(unittest.TestCase):
def test_0(self):
def test_thunk_inplace(self):
x, y, z = inputs()
e = mul(add(x, y), div(x, y))
fn, i, o = perform_linker(env([x, y, z], [e])).make_thunk(True)
fn()
assert e.data == 1.5
def test_1(self):
def test_thunk_not_inplace(self):
x, y, z = inputs()
e = mul(add(x, y), div(x, y))
fn, i, o = perform_linker(env([x, y, z], [e])).make_thunk(False)
fn()
assert o[0].data == 1.5
assert e.data != 1.5
def test_2(self):
def test_function(self):
x, y, z = inputs()
e = mul(add(x, y), div(x, y))
fn = perform_linker(env([x, y, z], [e])).make_function()
assert fn(1.0, 2.0, 3.0) == 1.5
assert e.data != 1.5
assert e.data != 1.5 # not inplace
def test_input_output_same(self):
x, y, z = inputs()
......
......@@ -2,7 +2,7 @@
import unittest
from copy import copy
from op import *
from result import ResultBase #, BrokenLinkError
from result import ResultBase
class MyResult(ResultBase):
......@@ -34,27 +34,6 @@ class MyOp(Op):
self.inputs = inputs
self.outputs = [MyResult(sum([input.thingy for input in inputs]))]
# def validate_update(self):
# for input in self.inputs:
# if not isinstance(input, MyResult):
# raise Exception("Error 1")
# if self.outputs is None:
# self.outputs = [MyResult(sum([input.thingy for input in self.inputs]))]
# return True
# else:
# old_thingy = self.outputs[0].thingy
# new_thingy = sum([input.thingy for input in self.inputs])
# self.outputs[0].thingy = new_thingy
# return old_thingy != new_thingy
# class MyOp(Op):
# def validate_update(self):
# for input in self.inputs:
# if not isinstance(input, MyResult):
# raise Exception("Error 1")
# self.outputs = [MyResult(sum([input.thingy for input in self.inputs]))]
class _test_Op(unittest.TestCase):
......@@ -75,100 +54,6 @@ class _test_Op(unittest.TestCase):
else:
raise Exception("Expected an exception")
# # Setting inputs and outputs
# def test_set_inputs(self):
# r1, r2 = MyResult(1), MyResult(2)
# op = MyOp(r1, r2)
# r3 = op.outputs[0]
# op.inputs = MyResult(4), MyResult(5)
# op.validate_update()
# assert op.outputs == [MyResult(9)] # check if the output changed to what I expect
# # assert r3.data is op.outputs[0].data # check if the data was properly transferred by set_output
# def test_set_bad_inputs(self):
# op = MyOp(MyResult(1), MyResult(2))
# try:
# op.inputs = MyResult(4), ResultBase()
# op.validate_update()
# except Exception, e:
# assert str(e) == "Error 1"
# else:
# raise Exception("Expected an exception")
# def test_set_outputs(self):
# r1, r2 = MyResult(1), MyResult(2)
# op = MyOp(r1, r2) # here we only make one output
# try:
# op.outputs = MyResult(10), MyResult(11) # setting two outputs should fail
# except TypeError, e:
# assert str(e) == "The new outputs must be exactly as many as the previous outputs."
# else:
# raise Exception("Expected an exception")
# # Tests about broken links
# def test_create_broken_link(self):
# r1, r2 = MyResult(1), MyResult(2)
# op = MyOp(r1, r2)
# r3 = op.out
# op.inputs = MyResult(3), MyResult(4)
# assert r3 not in op.outputs
# assert r3.replaced
# def test_cannot_copy_when_input_is_broken_link(self):
# r1, r2 = MyResult(1), MyResult(2)
# op = MyOp(r1, r2)
# r3 = op.out
# op2 = MyOp(r3)
# op.inputs = MyResult(3), MyResult(4)
# try:
# copy(op2)
# except BrokenLinkError:
# pass
# else:
# raise Exception("Expected an exception")
# def test_get_input_broken_link(self):
# r1, r2 = MyResult(1), MyResult(2)
# op = MyOp(r1, r2)
# r3 = op.out
# op2 = MyOp(r3)
# op.inputs = MyResult(3), MyResult(4)
# try:
# op2.get_input(0)
# except BrokenLinkError:
# pass
# else:
# raise Exception("Expected an exception")
# def test_get_inputs_broken_link(self):
# r1, r2 = MyResult(1), MyResult(2)
# op = MyOp(r1, r2)
# r3 = op.out
# op2 = MyOp(r3)
# op.inputs = MyResult(3), MyResult(4)
# try:
# op2.get_inputs()
# except BrokenLinkError:
# pass
# else:
# raise Exception("Expected an exception")
# def test_repair_broken_link(self):
# r1, r2 = MyResult(1), MyResult(2)
# op = MyOp(r1, r2)
# r3 = op.out
# op2 = MyOp(r3, MyResult(10))
# op.inputs = MyResult(3), MyResult(4)
# op2.repair()
# assert op2.outputs == [MyResult(17)]
# # Tests about string representation
# def test_create_broken_link(self):
# r1, r2 = MyResult(1), MyResult(2)
# op = MyOp(r1, r2)
# assert str(op) == "MyOp(1, 2)"
if __name__ == '__main__':
......
......@@ -260,7 +260,7 @@ class _test_MergeOptimizer(unittest.TestCase):
class _test_ConstantFinder(unittest.TestCase):
def test_0(self):
def test_straightforward(self):
x, y, z = inputs()
y.data = 2
z.data = 2
......@@ -272,7 +272,7 @@ class _test_ConstantFinder(unittest.TestCase):
assert str(g) == "[Op1(x, y, y)]" \
or str(g) == "[Op1(x, z, z)]"
def test_1(self):
def test_deep(self):
x, y, z = inputs()
y.data = 2
z.data = 2
......@@ -284,11 +284,11 @@ class _test_ConstantFinder(unittest.TestCase):
assert str(g) == "[Op1(*1 -> Op2(x, y), *1, *1)]" \
or str(g) == "[Op1(*1 -> Op2(x, z), *1, *1)]"
def test_2(self):
def test_destroyed_orphan_not_constant(self):
x, y, z = inputs()
y.data = 2
z.data = 2
e = op_d(x, op2(y, z))
e = op_d(x, op2(y, z)) # here x is destroyed by op_d
g = env([y], [e])
ConstantFinder().optimize(g)
assert not getattr(x, 'constant', False) and z.constant
......
......@@ -36,9 +36,10 @@ class MyResult(ResultBase):
class _test_ResultBase(unittest.TestCase):
def test_0(self):
def test_trivial(self):
r = ResultBase()
def test_1(self):
def test_state(self):
r = ResultBase()
assert r.state is Empty
......
......@@ -54,14 +54,12 @@ def inputs():
return x, y, z
def env(inputs, outputs, validate = True, features = []):
# inputs = [input.r for input in inputs]
# outputs = [output.r for output in outputs]
return Env(inputs, outputs, features = features, consistency_check = validate)
class _test_EquivTool(unittest.TestCase):
def test_0(self):
def test_straightforward(self):
x, y, z = inputs()
sx = sigmoid(x)
e = add(sx, sigmoid(y))
......@@ -72,6 +70,35 @@ class _test_EquivTool(unittest.TestCase):
assert isinstance(g.equiv(sx).owner, Dot)
class _test_InstanceFinder(unittest.TestCase):
def test_straightforward(self):
x, y, z = inputs()
e0 = dot(y, z)
e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), e0))
g = env([x, y, z], [e], features = [InstanceFinder])
for type, num in ((Add, 3), (Sigmoid, 3), (Dot, 2)):
if not len([x for x in g.get_instances_of(type)]) == num:
self.fail((type, num))
new_e0 = add(y, z)
assert e0.owner in g.get_instances_of(Dot)
assert new_e0.owner not in g.get_instances_of(Add)
g.replace(e0, new_e0)
assert e0.owner not in g.get_instances_of(Dot)
assert new_e0.owner in g.get_instances_of(Add)
for type, num in ((Add, 4), (Sigmoid, 3), (Dot, 1)):
if not len([x for x in g.get_instances_of(type)]) == num:
self.fail((type, num))
def test_robustness(self):
x, y, z = inputs()
e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), dot(y, z)))
g = env([x, y, z], [e], features = [InstanceFinder])
gen = g.get_instances_of(Sigmoid) # I want to get Sigmoid instances
g.replace(e, add(x, y)) # but here I prune them all
assert len([x for x in gen]) == 0 # the generator should not yield them
if __name__ == '__main__':
unittest.main()
......
......@@ -512,7 +512,8 @@ class CLinker(Linker):
# List of indices that should be ignored when passing the arguments
# (basically, everything that the previous call to uniq eliminated)
self.dupidx = [i for i, x in enumerate(all) if all.count(x) > 1 and all.index(x) != i]
return self.struct_code
def find_task(self, failure_code):
"""
Maps a failure code to the task that is associated to it.
......
......@@ -113,25 +113,26 @@ class PrintListener(Listener):
print "-- moving from %s to %s" % (r, new_r)
### UNUSED AND UNTESTED ###
class ChangeListener(Listener):
# class ChangeListener(Listener):
def __init__(self, env):
self.change = False
# def __init__(self, env):
# self.change = False
def on_import(self, op):
self.change = True
# def on_import(self, op):
# self.change = True
def on_prune(self, op):
self.change = True
# def on_prune(self, op):
# self.change = True
def on_rewire(self, clients, r, new_r):
self.change = True
# def on_rewire(self, clients, r, new_r):
# self.change = True
def __call__(self, value = "get"):
if value == "get":
return self.change
else:
self.change = value
# def __call__(self, value = "get"):
# if value == "get":
# return self.change
# else:
# self.change = value
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论