提交 5301d18d authored 作者: James Bergstra's avatar James Bergstra

merged

...@@ -124,9 +124,12 @@ class T_Function(unittest.TestCase): ...@@ -124,9 +124,12 @@ class T_Function(unittest.TestCase):
def test_orphans(self): def test_orphans(self):
gi, go = graph1() gi, go = graph1()
opt = None opt = None
p0 = Function(gi[0:0], go) p0 = Function(gi[0:0], go, keep_locals=True)
self.failUnless(p0() == 1.0) self.failUnless(p0.orphans == go, p0.orphans)
vp0 = p0()
self.failUnless(vp0 == 1.0, vp0)
p3 = Function(gi,go) p3 = Function(gi,go)
p2 = Function(gi[0:2], go) p2 = Function(gi[0:2], go)
......
from collections import deque
import unittest import unittest
from graph import * from graph import *
from op import Op from op import Op
from result import Result from result import Result
def inputs(result_list):
"""
@type result_list: list of L{Result}
@param result_list: output L{Result}s (from which to search backward through owners)
@returns: the list of L{Result}s with no owner, in the order found by a
left-recursive depth-first search started at the L{Result}s in result_list.
"""
def expand(r):
if r.owner:
l = list(r.owner.inputs)
l.reverse()
return l
dfs_results = stack_search(deque(result_list), expand, 'dfs')
rval = [r for r in dfs_results if r.owner is None]
#print rval, _orig_inputs(o)
return rval
if 1:
testcase = unittest.TestCase
else:
testcase = object
realtestcase = unittest.TestCase
class MyResult(Result): class MyResult(Result):
...@@ -21,10 +45,10 @@ class MyResult(Result): ...@@ -21,10 +45,10 @@ class MyResult(Result):
return isinstance(other, MyResult) and other.thingy == self.thingy return isinstance(other, MyResult) and other.thingy == self.thingy
def __str__(self): def __str__(self):
return str(self.thingy) return 'R%s' % str(self.thingy)
def __repr__(self): def __repr__(self):
return str(self.thingy) return 'R%s' % str(self.thingy)
class MyOp(Op): class MyOp(Op):
...@@ -37,18 +61,19 @@ class MyOp(Op): ...@@ -37,18 +61,19 @@ class MyOp(Op):
self.outputs = [MyResult(sum([input.thingy for input in inputs]))] self.outputs = [MyResult(sum([input.thingy for input in inputs]))]
class _test_inputs(unittest.TestCase): class _test_inputs(testcase):
def test_straightforward(self): def test_straightforward(self):
r1, r2 = MyResult(1), MyResult(2) r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2) op = MyOp(r1, r2)
assert inputs(op.outputs) == set([r1, r2]) assert inputs(op.outputs) == [r1, r2]
def test_deep(self): def test_deep(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5) op2 = MyOp(op.outputs[0], r5)
assert inputs(op2.outputs) == set([r1, r2, r5]) i = inputs(op2.outputs)
self.failUnless(i == [r1, r2, r5], i)
def test_unreached_inputs(self): def test_unreached_inputs(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
...@@ -57,23 +82,26 @@ class _test_inputs(unittest.TestCase): ...@@ -57,23 +82,26 @@ class _test_inputs(unittest.TestCase):
try: try:
# function doesn't raise if we put False instead of True # function doesn't raise if we put False instead of True
ro = results_and_orphans([r1, r2, op2.outputs[0]], op.outputs, True) ro = results_and_orphans([r1, r2, op2.outputs[0]], op.outputs, True)
self.fail()
except Exception, e: except Exception, e:
if e[0] is results_and_orphans.E_unreached: if e[0] is results_and_orphans.E_unreached:
return return
raise self.fail()
def test_uz(self):
pass
class _test_orphans(unittest.TestCase):
class _test_orphans(testcase):
def test_straightforward(self): def test_straightforward(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5) op2 = MyOp(op.outputs[0], r5)
assert orphans([r1, r2], op2.outputs) == set([r5]) orph = orphans([r1, r2], op2.outputs)
self.failUnless(orph == [r5], orph)
class _test_as_string(unittest.TestCase): class _test_as_string(testcase):
leaf_formatter = str leaf_formatter = str
node_formatter = lambda op, argstrings: "%s(%s)" % (op.__class__.__name__, node_formatter = lambda op, argstrings: "%s(%s)" % (op.__class__.__name__,
...@@ -82,35 +110,37 @@ class _test_as_string(unittest.TestCase): ...@@ -82,35 +110,37 @@ class _test_as_string(unittest.TestCase):
def test_straightforward(self): def test_straightforward(self):
r1, r2 = MyResult(1), MyResult(2) r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2) op = MyOp(r1, r2)
assert as_string([r1, r2], op.outputs) == ["MyOp(1, 2)"] s = as_string([r1, r2], op.outputs)
self.failUnless(s == ["MyOp(R1, R2)"], s)
def test_deep(self): def test_deep(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5) op2 = MyOp(op.outputs[0], r5)
assert as_string([r1, r2, r5], op2.outputs) == ["MyOp(MyOp(1, 2), 5)"] s = as_string([r1, r2, r5], op2.outputs)
self.failUnless(s == ["MyOp(MyOp(R1, R2), R5)"], s)
def test_multiple_references(self): def test_multiple_references(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], op.outputs[0]) op2 = MyOp(op.outputs[0], op.outputs[0])
assert as_string([r1, r2, r5], op2.outputs) == ["MyOp(*1 -> MyOp(1, 2), *1)"] assert as_string([r1, r2, r5], op2.outputs) == ["MyOp(*1 -> MyOp(R1, R2), *1)"]
def test_cutoff(self): def test_cutoff(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], op.outputs[0]) op2 = MyOp(op.outputs[0], op.outputs[0])
assert as_string(op.outputs, op2.outputs) == ["MyOp(3, 3)"] assert as_string(op.outputs, op2.outputs) == ["MyOp(R3, R3)"]
assert as_string(op2.inputs, op2.outputs) == ["MyOp(3, 3)"] assert as_string(op2.inputs, op2.outputs) == ["MyOp(R3, R3)"]
class _test_clone(unittest.TestCase): class _test_clone(testcase):
def test_accurate(self): def test_accurate(self):
r1, r2 = MyResult(1), MyResult(2) r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2) op = MyOp(r1, r2)
new = clone([r1, r2], op.outputs) new = clone([r1, r2], op.outputs)
assert as_string([r1, r2], new) == ["MyOp(1, 2)"] assert as_string([r1, r2], new) == ["MyOp(R1, R2)"]
def test_copy(self): def test_copy(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
...@@ -130,13 +160,91 @@ class _test_clone(unittest.TestCase): ...@@ -130,13 +160,91 @@ class _test_clone(unittest.TestCase):
new_op = new[0].owner new_op = new[0].owner
new_op.inputs = MyResult(7), MyResult(8) new_op.inputs = MyResult(7), MyResult(8)
assert as_string(inputs(new_op.outputs), new_op.outputs) == ["MyOp(7, 8)"] s = as_string(inputs(new_op.outputs), new_op.outputs)
assert as_string(inputs(op.outputs), op.outputs) == ["MyOp(MyOp(1, 2), 5)"] self.failUnless( s == ["MyOp(R7, R8)"], s)
assert as_string(inputs(op.outputs), op.outputs) == ["MyOp(MyOp(R1, R2), R5)"]
def prenode(obj):
if isinstance(obj, Result):
if obj.owner:
return [obj.owner]
if isinstance(obj, Op):
return obj.inputs
class _test_toposort(testcase):
def test0(self):
"""Test a simple graph"""
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
o = MyOp(r1, r2)
o2 = MyOp(o.outputs[0], r5)
all = general_toposort(o2.outputs, prenode)
self.failUnless(all == [r5, r2, r1, o, o.outputs[0], o2, o2.outputs[0]], all)
all = io_toposort([r5], o2.outputs)
self.failUnless(all == [o, o2], all)
def test1(self):
"""Test a graph with double dependencies"""
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
o = MyOp(r1, r1)
o2 = MyOp(o.outputs[0], r5)
all = general_toposort(o2.outputs, prenode)
self.failUnless(all == [r5, r1, o, o.outputs[0], o2, o2.outputs[0]], all)
def test2(self):
"""Test a graph where the inputs have owners"""
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
o = MyOp(r1, r1)
r2b = o.outputs[0]
o2 = MyOp(r2b, r2b)
all = io_toposort([r2b], o2.outputs)
self.failUnless(all == [o2], all)
o2 = MyOp(r2b, r5)
all = io_toposort([r2b], o2.outputs)
self.failUnless(all == [o2], all)
def test3(self):
"""Test a graph which is not connected"""
r1, r2, r3, r4 = MyResult(1), MyResult(2), MyResult(3), MyResult(4)
o0 = MyOp(r1, r2)
o1 = MyOp(r3, r4)
all = io_toposort([r1, r2, r3, r4], o0.outputs + o1.outputs)
self.failUnless(all == [o1,o0], all)
def test4(self):
"""Test inputs and outputs mixed together in a chain graph"""
r1, r2, r3, r4 = MyResult(1), MyResult(2), MyResult(3), MyResult(4)
o0 = MyOp(r1, r2)
o1 = MyOp(o0.outputs[0], r1)
all = io_toposort([r1, o0.outputs[0]], [o0.outputs[0], o1.outputs[0]])
self.failUnless(all == [o1], all)
def test5(self):
"""Test when outputs have clients"""
r1, r2, r3, r4 = MyResult(1), MyResult(2), MyResult(3), MyResult(4)
o0 = MyOp(r1, r2)
o1 = MyOp(o0.outputs[0], r4)
all = io_toposort([], o0.outputs)
self.failUnless(all == [o0], all)
if __name__ == '__main__': if __name__ == '__main__':
if 1:
#run all tests
unittest.main() unittest.main()
elif 1:
#load some TestCase classes
suite = unittest.TestLoader()
suite = suite.loadTestsFromTestCase(_test_toposort)
#run just some of them
unittest.TextTestRunner(verbosity=2).run(suite)
else:
#run just a single test
_test_toposort('test0').debug()
差异被折叠。
...@@ -6,7 +6,6 @@ compatible with gof's graph manipulation routines. ...@@ -6,7 +6,6 @@ compatible with gof's graph manipulation routines.
import utils import utils
from utils import ClsInit, all_bases, all_bases_collect, AbstractFunctionError from utils import ClsInit, all_bases, all_bases_collect, AbstractFunctionError
import graph
from copy import copy from copy import copy
...@@ -170,9 +169,9 @@ class Op(object): ...@@ -170,9 +169,9 @@ class Op(object):
# String representation # String representation
# #
if 0:
def __str__(self): def __str__(self):
return graph.op_as_string(self.inputs, self) return graph.op_as_string(self.inputs, self)
def __repr__(self): def __repr__(self):
return str(self) return str(self)
......
...@@ -19,6 +19,31 @@ class AbstractFunctionError(Exception): ...@@ -19,6 +19,31 @@ class AbstractFunctionError(Exception):
function has been left out of an implementation class. function has been left out of an implementation class.
""" """
def deprecated(filename, msg=''):
"""Decorator which will print a warning message on the first call.
Use it like this:
@deprecated('myfile', 'do something different...')
def fn_name(...)
...
And it will print
WARNING myfile.fn_name deprecated. do something different...
"""
def _deprecated(f):
printme = [True]
def g(*args, **kwargs):
if printme[0]:
print 'WARNING: %s.%s deprecated. %s'\
% (filename, f.__name__, msg)
printme[0] = False
return f(*args, **kwargs)
return g
return _deprecated
def uniq(seq): def uniq(seq):
#TODO: consider building a set out of seq so that the if condition is constant time -JB #TODO: consider building a set out of seq so that the if condition is constant time -JB
return [x for i, x in enumerate(seq) if seq.index(x) == i] return [x for i, x in enumerate(seq) if seq.index(x) == i]
...@@ -36,6 +61,7 @@ def difference(seq1, seq2): ...@@ -36,6 +61,7 @@ def difference(seq1, seq2):
# -> use O(len(seq1) * len(seq2)) algo # -> use O(len(seq1) * len(seq2)) algo
return [x for x in seq1 if x not in seq2] return [x for x in seq1 if x not in seq2]
def partition(f, seq): def partition(f, seq):
seqt = [] seqt = []
seqf = [] seqf = []
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论