提交 eec75e98 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

merge

...@@ -7,6 +7,8 @@ import gradient ...@@ -7,6 +7,8 @@ import gradient
from sparse import _is_dense, _is_sparse, _is_dense_result, _is_sparse_result from sparse import _is_dense, _is_sparse, _is_dense_result, _is_sparse_result
from sparse import _mtypes, _mtype_to_str from sparse import _mtypes, _mtype_to_str
import random
class T_transpose(unittest.TestCase): class T_transpose(unittest.TestCase):
def setUp(self): def setUp(self):
numpy.random.seed(44) numpy.random.seed(44)
......
差异被折叠。
...@@ -25,37 +25,37 @@ class _test_inplace_opt(unittest.TestCase): ...@@ -25,37 +25,37 @@ class _test_inplace_opt(unittest.TestCase):
x, y, z = inputs() x, y, z = inputs()
e = x + y + z e = x + y + z
g = Env([x, y], [e]) g = Env([x, y], [e])
assert str(g) == "[Broadcast{Add}(Broadcast{Add}(x, y), z)]" self.failUnless(str(g) == "[Broadcast{Add}(Broadcast{Add}(x, y), z)]")
inplace_optimizer.optimize(g) inplace_optimizer.optimize(g)
assert str(g) == "[Broadcast{Add}{0: 0}(Broadcast{Add}{0: 0}(x, y), z)]" self.failUnless(str(g) == "[Broadcast{Add}{0: 0}(Broadcast{Add}{0: 0}(x, y), z)]")
def test_multiple_uses(self): def test_multiple_uses(self):
x, y, z = inputs() x, y, z = inputs()
e0 = x + y e0 = x + y
e1 = x * y e1 = x * y
g = Env([x, y], [e0, e1]) g = Env([x, y], [e0, e1])
assert str(g) == "[Broadcast{Add}(x, y), Broadcast{Mul}(x, y)]" self.failUnless(str(g) == "[Broadcast{Add}(x, y), Broadcast{Mul}(x, y)]")
inplace_optimizer.optimize(g) inplace_optimizer.optimize(g)
assert str(g) == "[Broadcast{Add}{0: 0}(x, y), Broadcast{Mul}(x, y)]" \ self.failUnless(str(g) == "[Broadcast{Add}{0: 0}(x, y), Broadcast{Mul}(x, y)]" \
or str(g) == "[Broadcast{Add}(x, y), Broadcast{Mul}{0: 0}(x, y)]" or str(g) == "[Broadcast{Add}(x, y), Broadcast{Mul}{0: 0}(x, y)]")
def test_user_inplace(self): def test_user_inplace(self):
x, y, z = inputs() x, y, z = inputs()
e0 = x + y e0 = x + y
e1 = tensor.mul_inplace(x, y) e1 = tensor.mul_inplace(x, y)
g = Env([x, y], [e0, e1]) g = Env([x, y], [e0, e1])
assert str(g) == "[Broadcast{Add}(x, y), Broadcast{Mul}{0: 0}(x, y)]" self.failUnless(str(g) == "[Broadcast{Add}(x, y), Broadcast{Mul}{0: 0}(x, y)]")
inplace_optimizer.optimize(g) inplace_optimizer.optimize(g)
assert str(g) == "[Broadcast{Add}(x, y), Broadcast{Mul}{0: 0}(x, y)]" self.failUnless(str(g) == "[Broadcast{Add}(x, y), Broadcast{Mul}{0: 0}(x, y)]")
def test_inplace_on_second_argument(self): def test_inplace_on_second_argument(self):
x, y, z = inputs() x, y, z = inputs()
e0 = x + y e0 = x + y
e1 = tensor.mul_inplace(x, z) e1 = tensor.mul_inplace(x, z)
g = Env([x, y], [e0, e1]) g = Env([x, y], [e0, e1])
assert str(g) == "[Broadcast{Add}(x, y), Broadcast{Mul}{0: 0}(x, z)]" self.failUnless(str(g) == "[Broadcast{Add}(x, y), Broadcast{Mul}{0: 0}(x, z)]")
inplace_optimizer.optimize(g) inplace_optimizer.optimize(g)
assert str(g) == "[Broadcast{Add}{0: 1}(x, y), Broadcast{Mul}{0: 0}(x, z)]" self.failUnless(str(g) == "[Broadcast{Add}{0: 1}(x, y), Broadcast{Mul}{0: 0}(x, z)]")
class _test_dimshuffle_lift(unittest.TestCase): class _test_dimshuffle_lift(unittest.TestCase):
...@@ -64,23 +64,23 @@ class _test_dimshuffle_lift(unittest.TestCase): ...@@ -64,23 +64,23 @@ class _test_dimshuffle_lift(unittest.TestCase):
x, y, z = inputs() x, y, z = inputs()
e = ds(ds(x, (1, 0)), (1, 0)) e = ds(ds(x, (1, 0)), (1, 0))
g = Env([x], [e]) g = Env([x], [e])
assert str(g) == "[DimShuffle{10}(DimShuffle{10}(x))]" self.failUnless(str(g) == "[InplaceDimShuffle{1,0}(InplaceDimShuffle{1,0}(x))]")
lift_dimshuffle.optimize(g) lift_dimshuffle.optimize(g)
assert str(g) == "[x]" self.failUnless(str(g) == "[x]")
def test_merge2(self): def test_merge2(self):
x, y, z = inputs() x, y, z = inputs()
e = ds(ds(x, (1, 'x', 0)), (2, 0, 'x', 1)) e = ds(ds(x, (1, 'x', 0)), (2, 0, 'x', 1))
g = Env([x], [e]) g = Env([x], [e])
self.failUnless(str(g) == "[DimShuffle{20x1}(DimShuffle{1x0}(x))]", str(g)) self.failUnless(str(g) == "[InplaceDimShuffle{2,0,x,1}(InplaceDimShuffle{1,x,0}(x))]", str(g))
lift_dimshuffle.optimize(g) lift_dimshuffle.optimize(g)
self.failUnless(str(g) == "[DimShuffle{01xx}(x)]", str(g)) self.failUnless(str(g) == "[InplaceDimShuffle{0,1,x,x}(x)]", str(g))
def test_elim3(self): def test_elim3(self):
x, y, z = inputs() x, y, z = inputs()
e = ds(ds(ds(x, (0, 'x', 1)), (2, 0, 'x', 1)), (1, 0)) e = ds(ds(ds(x, (0, 'x', 1)), (2, 0, 'x', 1)), (1, 0))
g = Env([x], [e]) g = Env([x], [e])
self.failUnless(str(g) == "[DimShuffle{10}(DimShuffle{20x1}(DimShuffle{0x1}(x)))]", str(g)) self.failUnless(str(g) == "[InplaceDimShuffle{1,0}(InplaceDimShuffle{2,0,x,1}(InplaceDimShuffle{0,x,1}(x)))]", str(g))
lift_dimshuffle.optimize(g) lift_dimshuffle.optimize(g)
self.failUnless(str(g) == "[x]", str(g)) self.failUnless(str(g) == "[x]", str(g))
...@@ -88,9 +88,9 @@ class _test_dimshuffle_lift(unittest.TestCase): ...@@ -88,9 +88,9 @@ class _test_dimshuffle_lift(unittest.TestCase):
x, y, z = inputs([0]*1, [0]*2, [0]*3) x, y, z = inputs([0]*1, [0]*2, [0]*3)
e = x + y + z e = x + y + z
g = Env([x, y, z], [e]) g = Env([x, y, z], [e])
self.failUnless(str(g) == "[Broadcast{Add}(DimShuffle{x01}(Broadcast{Add}(DimShuffle{x0}(x), y)), z)]", str(g)) self.failUnless(str(g) == "[Broadcast{Add}(InplaceDimShuffle{x,0,1}(Broadcast{Add}(InplaceDimShuffle{x,0}(x), y)), z)]", str(g))
lift_dimshuffle.optimize(g) lift_dimshuffle.optimize(g)
self.failUnless(str(g) == "[Broadcast{Add}(Broadcast{Add}(DimShuffle{xx0}(x), DimShuffle{x01}(y)), z)]", str(g)) self.failUnless(str(g) == "[Broadcast{Add}(Broadcast{Add}(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z)]", str(g))
class _test_cliques(unittest.TestCase): class _test_cliques(unittest.TestCase):
...@@ -103,10 +103,10 @@ class _test_cliques(unittest.TestCase): ...@@ -103,10 +103,10 @@ class _test_cliques(unittest.TestCase):
e = x + y + d e = x + y + d
g = Env([x, y, z], [e]) g = Env([x, y, z], [e])
cliques = find_cliques(g) cliques = find_cliques(g)
assert len(cliques) == 2 self.failUnless(len(cliques) == 2)
(i1, o1), (i2, o2) = cliques (i1, o1), (i2, o2) = cliques
assert str(Env(i1, o1)) == "[Broadcast{Add}(Broadcast{Add}(x, y), d)]" self.failUnless(str(Env(i1, o1)) == "[Broadcast{Add}(Broadcast{Add}(x, y), d)]")
assert str(Env(i2, o2)) == "[Broadcast{Mul}(y, z)]" self.failUnless(str(Env(i2, o2)) == "[Broadcast{Mul}(y, z)]")
# print g # print g
# for i, o in find_cliques(g): # for i, o in find_cliques(g):
# print "-->", Env(i, [o]) # print "-->", Env(i, [o])
...@@ -116,8 +116,8 @@ class _test_cliques(unittest.TestCase): ...@@ -116,8 +116,8 @@ class _test_cliques(unittest.TestCase):
e = x + y + z e = x + y + z
g = Env([x, y, z], [e]) g = Env([x, y, z], [e])
lift_dimshuffle.optimize(g) lift_dimshuffle.optimize(g)
assert len(find_cliques(g, through_broadcast = True)) == 1 self.failUnless(len(find_cliques(g, through_broadcast = True)) == 1)
assert len(find_cliques(g, through_broadcast = False)) == 2 self.failUnless(len(find_cliques(g, through_broadcast = False)) == 2)
# print g # print g
# for i, o in find_cliques(g, True): # for i, o in find_cliques(g, True):
# print "-->", Env(i, [o]) # print "-->", Env(i, [o])
......
...@@ -9,6 +9,9 @@ import gof ...@@ -9,6 +9,9 @@ import gof
from gof.python25 import all from gof.python25 import all
# tensor depends on elemwise to provide definitions for several ops
# but elemwise needs to make Tensor instances, so we have these as
# placeholders and the tensor module fills them
def as_tensor(data): def as_tensor(data):
raise Exception("Circular dependencies prevent using this here. import tensor before elemwise") raise Exception("Circular dependencies prevent using this here. import tensor before elemwise")
...@@ -30,11 +33,11 @@ class DimShuffle(Op): ...@@ -30,11 +33,11 @@ class DimShuffle(Op):
""" """
Usage: DimShuffle(new_order, inplace = True) Usage: DimShuffle(new_order, inplace = True)
* new_order: a list representing the relationship between the - new_order: a list representing the relationship between the
input's dimensions and the output's dimensions. Each input's dimensions and the output's dimensions. Each
element of the list can either be an index or 'x'. element of the list can either be an index or 'x'.
* inplace: if True, the output will be a view of the input. - inplace: if True, the output will be a view of the input.
If False, the output will be a copy of the input. If False, the output will be a copy of the input.
If j = new_order[i] is an index, the output's ith dimension If j = new_order[i] is an index, the output's ith dimension
will be the input's jth dimension. will be the input's jth dimension.
...@@ -47,6 +50,7 @@ class DimShuffle(Op): ...@@ -47,6 +50,7 @@ class DimShuffle(Op):
Examples: Examples:
# t<n> represents a n-d tensor # t<n> represents a n-d tensor
DimShuffle(t0, ['x']) -> make a 0d (scalar) into a 1d vector
DimShuffle(t2, [0, 1]) -> identity DimShuffle(t2, [0, 1]) -> identity
DimShuffle(t2, [1, 0]) -> inverts the first and second dimensions DimShuffle(t2, [1, 0]) -> inverts the first and second dimensions
DimShuffle(t1, ['x', 0]) -> make a row out of a 1d vector DimShuffle(t1, ['x', 0]) -> make a row out of a 1d vector
...@@ -54,6 +58,8 @@ class DimShuffle(Op): ...@@ -54,6 +58,8 @@ class DimShuffle(Op):
DimShuffle(t3, [2, 0, 1]) -> like doing t3.transpose((2, 0, 1)) in numpy DimShuffle(t3, [2, 0, 1]) -> like doing t3.transpose((2, 0, 1)) in numpy
DimShuffle(t2, [0, 'x', 1]) -> like doing t3.reshape((t3.shape[0], 1, t3.shape[1])) in numpy DimShuffle(t2, [0, 'x', 1]) -> like doing t3.reshape((t3.shape[0], 1, t3.shape[1])) in numpy
DimShuffle(t2, [1, 'x', 0]) -> like doing t3.T.reshape((t3.shape[0], 1, t3.shape[1])) in numpy DimShuffle(t2, [1, 'x', 0]) -> like doing t3.T.reshape((t3.shape[0], 1, t3.shape[1])) in numpy
@todo: Default value for inplace should be False! Unsafe optimizations should be explicitly enabled.
""" """
def __init__(self, input_broadcastable, new_order, inplace = True): def __init__(self, input_broadcastable, new_order, inplace = True):
...@@ -113,7 +119,10 @@ class DimShuffle(Op): ...@@ -113,7 +119,10 @@ class DimShuffle(Op):
return hash(self.inplace) ^ hash(self.new_order) ^ hash(self.input_broadcastable) return hash(self.inplace) ^ hash(self.new_order) ^ hash(self.input_broadcastable)
def __str__(self): def __str__(self):
return "DimShuffle{%s}" % "".join(str(x) for x in self.new_order) if self.inplace:
return "InplaceDimShuffle{%s}" % ",".join(str(x) for x in self.new_order)
else:
return "DimShuffle{%s}" % ",".join(str(x) for x in self.new_order)
def perform(self, node, (input, ), (storage, )): def perform(self, node, (input, ), (storage, )):
# drop # drop
......
from collections import deque
import unittest import unittest
from graph import * from graph import *
...@@ -7,6 +7,30 @@ from op import Op ...@@ -7,6 +7,30 @@ from op import Op
from type import Type from type import Type
from graph import Result from graph import Result
def inputs(result_list):
"""
@type result_list: list of L{Result}
@param result_list: output L{Result}s (from which to search backward through owners)
@returns: the list of L{Result}s with no owner, in the order found by a
left-recursive depth-first search started at the L{Result}s in result_list.
"""
def expand(r):
if r.owner:
l = list(r.owner.inputs)
l.reverse()
return l
dfs_results = stack_search(deque(result_list), expand, 'dfs')
rval = [r for r in dfs_results if r.owner is None]
#print rval, _orig_inputs(o)
return rval
if 1:
testcase = unittest.TestCase
else:
testcase = object
realtestcase = unittest.TestCase
class MyType(Type): class MyType(Type):
...@@ -18,10 +42,10 @@ class MyType(Type): ...@@ -18,10 +42,10 @@ class MyType(Type):
return isinstance(other, MyType) and other.thingy == self.thingy return isinstance(other, MyType) and other.thingy == self.thingy
def __str__(self): def __str__(self):
return str(self.thingy) return 'R%s' % str(self.thingy)
def __repr__(self): def __repr__(self):
return str(self.thingy) return 'R%s' % str(self.thingy)
def MyResult(thingy): def MyResult(thingy):
return Result(MyType(thingy), None, None) return Result(MyType(thingy), None, None)
...@@ -75,43 +99,44 @@ MyOp = MyOp() ...@@ -75,43 +99,44 @@ MyOp = MyOp()
# self.outputs = [MyResult(sum([input.thingy for input in inputs]))] # self.outputs = [MyResult(sum([input.thingy for input in inputs]))]
class _test_inputs(unittest.TestCase): class _test_inputs(testcase):
def test_straightforward(self): def test_straightforward(self):
r1, r2 = MyResult(1), MyResult(2) r1, r2 = MyResult(1), MyResult(2)
node = MyOp.make_node(r1, r2) node = MyOp.make_node(r1, r2)
assert inputs(node.outputs) == set([r1, r2]) assert inputs(node.outputs) == [r1, r2]
def test_deep(self): def test_deep(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
node = MyOp.make_node(r1, r2) node = MyOp.make_node(r1, r2)
node2 = MyOp.make_node(node.outputs[0], r5) node2 = MyOp.make_node(node.outputs[0], r5)
assert inputs(node2.outputs) == set([r1, r2, r5]) i = inputs(node2.outputs)
self.failUnless(i == [r1, r2, r5], i)
# def test_unreached_inputs(self): # def test_unreached_inputs(self):
# r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) # r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
# node = MyOp.make_node(r1, r2) # op = MyOp(r1, r2)
# node2 = MyOp.make_node(node.outputs[0], r5) # op2 = MyOp(op.outputs[0], r5)
# try: # try:
# # function doesn't raise if we put False instead of True # # function doesn't raise if we put False instead of True
# ro = results_and_orphans([r1, r2, node2.outputs[0]], node.outputs, True) # ro = results_and_orphans([r1, r2, op2.outputs[0]], op.outputs, True)
# self.fail()
# except Exception, e: # except Exception, e:
# if e[0] is results_and_orphans.E_unreached: # if e[0] is results_and_orphans.E_unreached:
# return # return
# raise # self.fail()
class _test_orphans(unittest.TestCase): class _test_orphans(testcase):
def test_straightforward(self): def test_straightforward(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
node = MyOp.make_node(r1, r2) node = MyOp.make_node(r1, r2)
node2 = MyOp.make_node(node.outputs[0], r5) node2 = MyOp.make_node(node.outputs[0], r5)
assert orphans([r1, r2], node2.outputs) == set([r5]) orph = orphans([r1, r2], node2.outputs)
self.failUnless(orph == [r5], orph)
class _test_as_string(unittest.TestCase): class _test_as_string(testcase):
leaf_formatter = lambda self, leaf: str(leaf.type) leaf_formatter = lambda self, leaf: str(leaf.type)
node_formatter = lambda self, node, argstrings: "%s(%s)" % (node.op, node_formatter = lambda self, node, argstrings: "%s(%s)" % (node.op,
...@@ -125,29 +150,31 @@ class _test_as_string(unittest.TestCase): ...@@ -125,29 +150,31 @@ class _test_as_string(unittest.TestCase):
def test_straightforward(self): def test_straightforward(self):
r1, r2 = MyResult(1), MyResult(2) r1, r2 = MyResult(1), MyResult(2)
node = MyOp.make_node(r1, r2) node = MyOp.make_node(r1, r2)
assert self.str([r1, r2], node.outputs) == ["MyOp(1, 2)"] s = self.str([r1, r2], node.outputs)
self.failUnless(s == ["MyOp(R1, R2)"], s)
def test_deep(self): def test_deep(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
node = MyOp.make_node(r1, r2) node = MyOp.make_node(r1, r2)
node2 = MyOp.make_node(node.outputs[0], r5) node2 = MyOp.make_node(node.outputs[0], r5)
assert self.str([r1, r2, r5], node2.outputs) == ["MyOp(MyOp(1, 2), 5)"] s = self.str([r1, r2, r5], node2.outputs)
self.failUnless(s == ["MyOp(MyOp(R1, R2), R5)"], s)
def test_multiple_references(self): def test_multiple_references(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
node = MyOp.make_node(r1, r2) node = MyOp.make_node(r1, r2)
node2 = MyOp.make_node(node.outputs[0], node.outputs[0]) node2 = MyOp.make_node(node.outputs[0], node.outputs[0])
assert self.str([r1, r2, r5], node2.outputs) == ["MyOp(*1 -> MyOp(1, 2), *1)"] assert self.str([r1, r2, r5], node2.outputs) == ["MyOp(*1 -> MyOp(R1, R2), *1)"]
def test_cutoff(self): def test_cutoff(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
node = MyOp.make_node(r1, r2) node = MyOp.make_node(r1, r2)
node2 = MyOp.make_node(node.outputs[0], node.outputs[0]) node2 = MyOp.make_node(node.outputs[0], node.outputs[0])
assert self.str(node.outputs, node2.outputs) == ["MyOp(3, 3)"] assert self.str(node.outputs, node2.outputs) == ["MyOp(R3, R3)"]
assert self.str(node2.inputs, node2.outputs) == ["MyOp(3, 3)"] assert self.str(node2.inputs, node2.outputs) == ["MyOp(R3, R3)"]
class _test_clone(unittest.TestCase): class _test_clone(testcase):
leaf_formatter = lambda self, leaf: str(leaf.type) leaf_formatter = lambda self, leaf: str(leaf.type)
node_formatter = lambda self, node, argstrings: "%s(%s)" % (node.op, node_formatter = lambda self, node, argstrings: "%s(%s)" % (node.op,
...@@ -162,7 +189,7 @@ class _test_clone(unittest.TestCase): ...@@ -162,7 +189,7 @@ class _test_clone(unittest.TestCase):
r1, r2 = MyResult(1), MyResult(2) r1, r2 = MyResult(1), MyResult(2)
node = MyOp.make_node(r1, r2) node = MyOp.make_node(r1, r2)
_, new = clone([r1, r2], node.outputs, False) _, new = clone([r1, r2], node.outputs, False)
assert self.str([r1, r2], new) == ["MyOp(1, 2)"] assert self.str([r1, r2], new) == ["MyOp(R1, R2)"]
def test_copy(self): def test_copy(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
...@@ -181,14 +208,89 @@ class _test_clone(unittest.TestCase): ...@@ -181,14 +208,89 @@ class _test_clone(unittest.TestCase):
_, new = clone([r1, r2, r5], node.outputs, False) _, new = clone([r1, r2, r5], node.outputs, False)
new_node = new[0].owner new_node = new[0].owner
new_node.inputs = MyResult(7), MyResult(8) new_node.inputs = MyResult(7), MyResult(8)
assert self.str(inputs(new_node.outputs), new_node.outputs) == ["MyOp(R7, R8)"]
assert self.str(inputs(node.outputs), node.outputs) == ["MyOp(MyOp(R1, R2), R5)"]
def prenode(obj):
if isinstance(obj, Result):
if obj.owner:
return [obj.owner]
if isinstance(obj, Op):
return obj.inputs
class _test_toposort(testcase):
def test0(self):
"""Test a simple graph"""
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
o = MyOp(r1, r2)
o2 = MyOp(o.outputs[0], r5)
assert self.str(inputs(new_node.outputs), new_node.outputs) == ["MyOp(7, 8)"] all = general_toposort(o2.outputs, prenode)
assert self.str(inputs(node.outputs), node.outputs) == ["MyOp(MyOp(1, 2), 5)"] self.failUnless(all == [r5, r2, r1, o, o.outputs[0], o2, o2.outputs[0]], all)
all = io_toposort([r5], o2.outputs)
self.failUnless(all == [o, o2], all)
def test1(self):
"""Test a graph with double dependencies"""
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
o = MyOp(r1, r1)
o2 = MyOp(o.outputs[0], r5)
all = general_toposort(o2.outputs, prenode)
self.failUnless(all == [r5, r1, o, o.outputs[0], o2, o2.outputs[0]], all)
if __name__ == '__main__': def test2(self):
unittest.main() """Test a graph where the inputs have owners"""
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
o = MyOp(r1, r1)
r2b = o.outputs[0]
o2 = MyOp(r2b, r2b)
all = io_toposort([r2b], o2.outputs)
self.failUnless(all == [o2], all)
o2 = MyOp(r2b, r5)
all = io_toposort([r2b], o2.outputs)
self.failUnless(all == [o2], all)
def test3(self):
"""Test a graph which is not connected"""
r1, r2, r3, r4 = MyResult(1), MyResult(2), MyResult(3), MyResult(4)
o0 = MyOp(r1, r2)
o1 = MyOp(r3, r4)
all = io_toposort([r1, r2, r3, r4], o0.outputs + o1.outputs)
self.failUnless(all == [o1,o0], all)
def test4(self):
"""Test inputs and outputs mixed together in a chain graph"""
r1, r2, r3, r4 = MyResult(1), MyResult(2), MyResult(3), MyResult(4)
o0 = MyOp(r1, r2)
o1 = MyOp(o0.outputs[0], r1)
all = io_toposort([r1, o0.outputs[0]], [o0.outputs[0], o1.outputs[0]])
self.failUnless(all == [o1], all)
def test5(self):
"""Test when outputs have clients"""
r1, r2, r3, r4 = MyResult(1), MyResult(2), MyResult(3), MyResult(4)
o0 = MyOp(r1, r2)
o1 = MyOp(o0.outputs[0], r4)
all = io_toposort([], o0.outputs)
self.failUnless(all == [o0], all)
if __name__ == '__main__':
if 1:
#run all tests
unittest.main()
elif 1:
#load some TestCase classes
suite = unittest.TestLoader()
suite = suite.loadTestsFromTestCase(_test_toposort)
#run just some of them
unittest.TextTestRunner(verbosity=2).run(suite)
else:
#run just a single test
_test_toposort('test0').debug()
差异被折叠。
...@@ -35,13 +35,14 @@ class Op(object2): ...@@ -35,13 +35,14 @@ class Op(object2):
# Python implementation # # Python implementation #
######################### #########################
def impl(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
""" """
Calculate the function on the inputs and put the results in the Calculate the function on the inputs and put the results in the
output storage. output storage.
- inputs: sequence of inputs (immutable) - inputs: sequence of inputs (immutable)
- outputs: mutable list - output_storage: list of mutable 1-element lists (do not change
the length of these lists)
The output_storage list might contain data. If an element of The output_storage list might contain data. If an element of
output_storage is not None, it is guaranteed that it was produced output_storage is not None, it is guaranteed that it was produced
...@@ -50,36 +51,10 @@ class Op(object2): ...@@ -50,36 +51,10 @@ class Op(object2):
""" """
raise AbstractFunctionError() raise AbstractFunctionError()
##################### #####################
# C code generation # # C code generation #
##################### #####################
# def c_validate_update(self, inputs, outputs, sub):
# """
# Returns templated C code that checks that the inputs to this
# function can be worked on. If a failure occurs, set an
# Exception and insert "%(fail)s".
# You may use the variable names defined by c_var_names() in
# the template.
# Note: deprecated!!
# @todo: Merge this with c_code.
# """
# raise AbstractFunctionError()
# def c_validate_update_cleanup(self, inputs, outputs, sub):
# """
# Clean up things allocated by L{c_validate}().
# Note: deprecated!!
# @todo: Merge this with c_code.
# """
# raise AbstractFunctionError()
# raise AbstractFunctionError('%s.c_validate_update_cleanup ' \
# % self.__class__.__name__)
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
"""Return the C implementation of an Op. """Return the C implementation of an Op.
...@@ -151,28 +126,3 @@ class PropertiedOp(Op): ...@@ -151,28 +126,3 @@ class PropertiedOp(Op):
return "%s{%s}" % (self.__class__.__name__, ", ".join("%s=%s" % (k, v) for k, v in self.__dict__.items() if k != "name")) return "%s{%s}" % (self.__class__.__name__, ", ".join("%s=%s" % (k, v) for k, v in self.__dict__.items() if k != "name"))
# #TODO: consider adding a flag to the base class that toggles this behaviour
# class GuardedOp(Op):
# """An Op that disallows input properties to change after construction"""
# def set_input(self, i, new):
# old = self._inputs[i]
# if old is new:
# return
# try:
# if not old.same_properties(new):
# raise TypeError("The new input must have the same properties as the previous one.")
# except AbstractFunctionError:
# pass
# Op.set_input(self, i, new)
# def set_inputs(self, new):
# if not hasattr(self, '_inputs') or self_inputs is None:
# Op.set_inputs(self, new)
# else:
# if not len(new) == len(self._inputs):
# raise TypeError("The new inputs are not as many as the previous ones.")
# for i, new in enumerate(new):
# self.set_input(i, new)
...@@ -38,6 +38,31 @@ class scratchpad: ...@@ -38,6 +38,31 @@ class scratchpad:
def deprecated(filename, msg=''):
"""Decorator which will print a warning message on the first call.
Use it like this:
@deprecated('myfile', 'do something different...')
def fn_name(...)
...
And it will print
WARNING myfile.fn_name deprecated. do something different...
"""
def _deprecated(f):
printme = [True]
def g(*args, **kwargs):
if printme[0]:
print 'WARNING: %s.%s deprecated. %s'\
% (filename, f.__name__, msg)
printme[0] = False
return f(*args, **kwargs)
return g
return _deprecated
def uniq(seq): def uniq(seq):
#TODO: consider building a set out of seq so that the if condition is constant time -JB #TODO: consider building a set out of seq so that the if condition is constant time -JB
return [x for i, x in enumerate(seq) if seq.index(x) == i] return [x for i, x in enumerate(seq) if seq.index(x) == i]
...@@ -55,6 +80,7 @@ def difference(seq1, seq2): ...@@ -55,6 +80,7 @@ def difference(seq1, seq2):
# -> use O(len(seq1) * len(seq2)) algo # -> use O(len(seq1) * len(seq2)) algo
return [x for x in seq1 if x not in seq2] return [x for x in seq1 if x not in seq2]
def partition(f, seq): def partition(f, seq):
seqt = [] seqt = []
seqf = [] seqf = []
......
差异被折叠。
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论