提交 80ac390e authored 作者: Olivier Breuleux's avatar Olivier Breuleux

changed linker interface (accept method)

上级 e284625c
...@@ -25,11 +25,11 @@ class _test_DimShuffle(unittest.TestCase): ...@@ -25,11 +25,11 @@ class _test_DimShuffle(unittest.TestCase):
ib = [(entry == 1) for entry in xsh] ib = [(entry == 1) for entry in xsh]
x = Tensor('float64', ib)('x') x = Tensor('float64', ib)('x')
e = DimShuffle(ib, shuffle)(x) e = DimShuffle(ib, shuffle)(x)
f = linker(Env([x], [e])).make_function() f = copy(linker).accept(Env([x], [e])).make_function()
assert f(numpy.ones(xsh)).shape == zsh assert f(numpy.ones(xsh)).shape == zsh
def test_perform(self): def test_perform(self):
self.with_linker(gof.PerformLinker) self.with_linker(gof.PerformLinker())
class _test_Broadcast(unittest.TestCase): class _test_Broadcast(unittest.TestCase):
...@@ -47,7 +47,7 @@ class _test_Broadcast(unittest.TestCase): ...@@ -47,7 +47,7 @@ class _test_Broadcast(unittest.TestCase):
x = Tensor('float64', [(entry == 1) for entry in xsh])('x') x = Tensor('float64', [(entry == 1) for entry in xsh])('x')
y = Tensor('float64', [(entry == 1) for entry in ysh])('y') y = Tensor('float64', [(entry == 1) for entry in ysh])('y')
e = Elemwise(add)(x, y) e = Elemwise(add)(x, y)
f = linker(Env([x, y], [e])).make_function() f = copy(linker).accept(Env([x, y], [e])).make_function()
xv = numpy.asarray(numpy.random.rand(*xsh)) xv = numpy.asarray(numpy.random.rand(*xsh))
yv = numpy.asarray(numpy.random.rand(*ysh)) yv = numpy.asarray(numpy.random.rand(*ysh))
zv = xv + yv zv = xv + yv
...@@ -66,7 +66,7 @@ class _test_Broadcast(unittest.TestCase): ...@@ -66,7 +66,7 @@ class _test_Broadcast(unittest.TestCase):
x = Tensor('float64', [(entry == 1) for entry in xsh])('x') x = Tensor('float64', [(entry == 1) for entry in xsh])('x')
y = Tensor('float64', [(entry == 1) for entry in ysh])('y') y = Tensor('float64', [(entry == 1) for entry in ysh])('y')
e = Elemwise(Add(transfer_type(0)), {0:0})(x, y) e = Elemwise(Add(transfer_type(0)), {0:0})(x, y)
f = linker(Env([x, y], [e])).make_function() f = copy(linker).accept(Env([x, y], [e])).make_function()
xv = numpy.asarray(numpy.random.rand(*xsh)) xv = numpy.asarray(numpy.random.rand(*xsh))
yv = numpy.asarray(numpy.random.rand(*ysh)) yv = numpy.asarray(numpy.random.rand(*ysh))
zv = xv + yv zv = xv + yv
...@@ -76,22 +76,22 @@ class _test_Broadcast(unittest.TestCase): ...@@ -76,22 +76,22 @@ class _test_Broadcast(unittest.TestCase):
self.failUnless((xv == zv).all()) self.failUnless((xv == zv).all())
def test_perform(self): def test_perform(self):
self.with_linker(gof.PerformLinker) self.with_linker(gof.PerformLinker())
def test_c(self): def test_c(self):
self.with_linker(gof.CLinker) self.with_linker(gof.CLinker())
def test_perform_inplace(self): def test_perform_inplace(self):
self.with_linker_inplace(gof.PerformLinker) self.with_linker_inplace(gof.PerformLinker())
def test_c_inplace(self): def test_c_inplace(self):
self.with_linker_inplace(gof.CLinker) self.with_linker_inplace(gof.CLinker())
def test_fill(self): def test_fill(self):
x = Tensor('float64', [0, 0])('x') x = Tensor('float64', [0, 0])('x')
y = Tensor('float64', [1, 1])('y') y = Tensor('float64', [1, 1])('y')
e = Elemwise(Second(transfer_type(0)), {0:0})(x, y) e = Elemwise(Second(transfer_type(0)), {0:0})(x, y)
f = gof.CLinker(Env([x, y], [e])).make_function() f = gof.CLinker().accept(Env([x, y], [e])).make_function()
xv = numpy.ones((5, 5)) xv = numpy.ones((5, 5))
yv = numpy.random.rand(1, 1) yv = numpy.random.rand(1, 1)
f(xv, yv) f(xv, yv)
...@@ -101,7 +101,7 @@ class _test_Broadcast(unittest.TestCase): ...@@ -101,7 +101,7 @@ class _test_Broadcast(unittest.TestCase):
x = Tensor('float64', [0, 0, 0, 0, 0])('x') x = Tensor('float64', [0, 0, 0, 0, 0])('x')
y = Tensor('float64', [0, 0, 0, 0, 0])('y') y = Tensor('float64', [0, 0, 0, 0, 0])('y')
e = Elemwise(add)(x, y) e = Elemwise(add)(x, y)
f = gof.CLinker(Env([x, y], [e])).make_function() f = gof.CLinker().accept(Env([x, y], [e])).make_function()
xv = numpy.random.rand(2, 2, 2, 2, 2) xv = numpy.random.rand(2, 2, 2, 2, 2)
yv = numpy.random.rand(2, 2, 2, 2, 2).transpose(4, 0, 3, 1, 2) yv = numpy.random.rand(2, 2, 2, 2, 2).transpose(4, 0, 3, 1, 2)
zv = xv + yv zv = xv + yv
...@@ -110,7 +110,7 @@ class _test_Broadcast(unittest.TestCase): ...@@ -110,7 +110,7 @@ class _test_Broadcast(unittest.TestCase):
def test_same_inputs(self): def test_same_inputs(self):
x = Tensor('float64', [0, 0])('x') x = Tensor('float64', [0, 0])('x')
e = Elemwise(add)(x, x) e = Elemwise(add)(x, x)
f = gof.CLinker(Env([x], [e])).make_function() f = gof.CLinker().accept(Env([x], [e])).make_function()
xv = numpy.random.rand(2, 2) xv = numpy.random.rand(2, 2)
zv = xv + xv zv = xv + xv
assert (f(xv) == zv).all() assert (f(xv) == zv).all()
...@@ -129,7 +129,7 @@ class _test_CAReduce(unittest.TestCase): ...@@ -129,7 +129,7 @@ class _test_CAReduce(unittest.TestCase):
x = Tensor('float64', [(entry == 1) for entry in xsh])('x') x = Tensor('float64', [(entry == 1) for entry in xsh])('x')
e = CAReduce(add, axis = tosum)(x) e = CAReduce(add, axis = tosum)(x)
if tosum is None: tosum = range(len(xsh)) if tosum is None: tosum = range(len(xsh))
f = linker(Env([x], [e])).make_function() f = copy(linker).accept(Env([x], [e])).make_function()
xv = numpy.asarray(numpy.random.rand(*xsh)) xv = numpy.asarray(numpy.random.rand(*xsh))
zv = xv zv = xv
for axis in reversed(sorted(tosum)): for axis in reversed(sorted(tosum)):
...@@ -137,10 +137,10 @@ class _test_CAReduce(unittest.TestCase): ...@@ -137,10 +137,10 @@ class _test_CAReduce(unittest.TestCase):
self.failUnless((numpy.abs(f(xv) - zv) < 1e-10).all()) self.failUnless((numpy.abs(f(xv) - zv) < 1e-10).all())
def test_perform(self): def test_perform(self):
self.with_linker(gof.PerformLinker) self.with_linker(gof.PerformLinker())
def test_c(self): def test_c(self):
self.with_linker(gof.CLinker) self.with_linker(gof.CLinker())
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -17,7 +17,7 @@ class _test_ScalarOps(unittest.TestCase): ...@@ -17,7 +17,7 @@ class _test_ScalarOps(unittest.TestCase):
x, y, z = inputs() x, y, z = inputs()
e = mul(add(x, y), div(x, y)) e = mul(add(x, y), div(x, y))
g = Env([x, y], [e]) g = Env([x, y], [e])
fn = gof.DualLinker(g).make_function() fn = gof.DualLinker().accept(g).make_function()
assert fn(1.0, 2.0) == 1.5 assert fn(1.0, 2.0) == 1.5
...@@ -30,7 +30,7 @@ class _test_composite(unittest.TestCase): ...@@ -30,7 +30,7 @@ class _test_composite(unittest.TestCase):
c = C.make_node(x, y) c = C.make_node(x, y)
# print c.c_code(['x', 'y'], ['z'], dict(id = 0)) # print c.c_code(['x', 'y'], ['z'], dict(id = 0))
g = Env([x, y], [c.out]) g = Env([x, y], [c.out])
fn = gof.DualLinker(g).make_function() fn = gof.DualLinker().accept(g).make_function()
assert fn(1.0, 2.0) == 1.5 assert fn(1.0, 2.0) == 1.5
def test_with_constants(self): def test_with_constants(self):
...@@ -41,7 +41,7 @@ class _test_composite(unittest.TestCase): ...@@ -41,7 +41,7 @@ class _test_composite(unittest.TestCase):
assert "70.0" in c.op.c_code(c, 'dummy', ['x', 'y'], ['z'], dict(id = 0)) assert "70.0" in c.op.c_code(c, 'dummy', ['x', 'y'], ['z'], dict(id = 0))
# print c.c_code(['x', 'y'], ['z'], dict(id = 0)) # print c.c_code(['x', 'y'], ['z'], dict(id = 0))
g = Env([x, y], [c.out]) g = Env([x, y], [c.out])
fn = gof.DualLinker(g).make_function() fn = gof.DualLinker().accept(g).make_function()
assert fn(1.0, 2.0) == 36.0 assert fn(1.0, 2.0) == 36.0
def test_many_outputs(self): def test_many_outputs(self):
...@@ -53,79 +53,79 @@ class _test_composite(unittest.TestCase): ...@@ -53,79 +53,79 @@ class _test_composite(unittest.TestCase):
c = C.make_node(x, y, z) c = C.make_node(x, y, z)
# print c.c_code(['x', 'y', 'z'], ['out0', 'out1', 'out2'], dict(id = 0)) # print c.c_code(['x', 'y', 'z'], ['out0', 'out1', 'out2'], dict(id = 0))
g = Env([x, y, z], c.outputs) g = Env([x, y, z], c.outputs)
fn = gof.DualLinker(g).make_function() fn = gof.DualLinker().accept(g).make_function()
assert fn(1.0, 2.0, 3.0) == [6.0, 7.0, 0.5] assert fn(1.0, 2.0, 3.0) == [6.0, 7.0, 0.5]
class _test_logical(unittest.TestCase): class _test_logical(unittest.TestCase):
def test_gt(self): def test_gt(self):
x, y, z = inputs() x, y, z = inputs()
fn = gof.DualLinker(Env([x,y], [x > y])).make_function() fn = gof.DualLinker().accept(Env([x,y], [x > y])).make_function()
for a,b in ((3.,9), (3,0.9), (3,3)): for a,b in ((3.,9), (3,0.9), (3,3)):
self.failUnless(fn(a,b) == (a>b)) self.failUnless(fn(a,b) == (a>b))
def test_lt(self): def test_lt(self):
x, y, z = inputs() x, y, z = inputs()
fn = gof.DualLinker(Env([x,y], [x < y])).make_function() fn = gof.DualLinker().accept(Env([x,y], [x < y])).make_function()
for a,b in ((3.,9), (3,0.9), (3,3)): for a,b in ((3.,9), (3,0.9), (3,3)):
self.failUnless(fn(a,b) == (a<b)) self.failUnless(fn(a,b) == (a<b))
def test_le(self): def test_le(self):
x, y, z = inputs() x, y, z = inputs()
fn = gof.DualLinker(Env([x,y], [x <= y])).make_function() fn = gof.DualLinker().accept(Env([x,y], [x <= y])).make_function()
for a,b in ((3.,9), (3,0.9), (3,3)): for a,b in ((3.,9), (3,0.9), (3,3)):
self.failUnless(fn(a,b) == (a<=b)) self.failUnless(fn(a,b) == (a<=b))
def test_ge(self): def test_ge(self):
x, y, z = inputs() x, y, z = inputs()
fn = gof.DualLinker(Env([x,y], [x >= y])).make_function() fn = gof.DualLinker().accept(Env([x,y], [x >= y])).make_function()
for a,b in ((3.,9), (3,0.9), (3,3)): for a,b in ((3.,9), (3,0.9), (3,3)):
self.failUnless(fn(a,b) == (a>=b)) self.failUnless(fn(a,b) == (a>=b))
def test_eq(self): def test_eq(self):
x, y, z = inputs() x, y, z = inputs()
fn = gof.DualLinker(Env([x,y], [eq(x,y)])).make_function() fn = gof.DualLinker().accept(Env([x,y], [eq(x,y)])).make_function()
for a,b in ((3.,9), (3,0.9), (3,3)): for a,b in ((3.,9), (3,0.9), (3,3)):
self.failUnless(fn(a,b) == (a==b)) self.failUnless(fn(a,b) == (a==b))
def test_neq(self): def test_neq(self):
x, y, z = inputs() x, y, z = inputs()
fn = gof.DualLinker(Env([x,y], [neq(x,y)])).make_function() fn = gof.DualLinker().accept(Env([x,y], [neq(x,y)])).make_function()
for a,b in ((3.,9), (3,0.9), (3,3)): for a,b in ((3.,9), (3,0.9), (3,3)):
self.failUnless(fn(a,b) == (a!=b)) self.failUnless(fn(a,b) == (a!=b))
def test_or(self): def test_or(self):
x, y, z = ints('xyz') x, y, z = ints('xyz')
fn = gof.DualLinker(Env([x,y], [x|y])).make_function() fn = gof.DualLinker().accept(Env([x,y], [x|y])).make_function()
for a,b in ((0,1), (0,0), (1,0), (1,1)): for a,b in ((0,1), (0,0), (1,0), (1,1)):
self.failUnless(fn(a,b) == (a|b), (a,b)) self.failUnless(fn(a,b) == (a|b), (a,b))
def test_xor(self): def test_xor(self):
x, y, z = ints('xyz') x, y, z = ints('xyz')
fn = gof.DualLinker(Env([x,y], [x^y])).make_function() fn = gof.DualLinker().accept(Env([x,y], [x^y])).make_function()
for a,b in ((0,1), (0,0), (1,0), (1,1)): for a,b in ((0,1), (0,0), (1,0), (1,1)):
self.failUnless(fn(a,b) == (a ^ b), (a,b)) self.failUnless(fn(a,b) == (a ^ b), (a,b))
def test_and(self): def test_and(self):
x, y, z = ints('xyz') x, y, z = ints('xyz')
fn = gof.DualLinker(Env([x,y], [and_(x, y)])).make_function() fn = gof.DualLinker().accept(Env([x,y], [and_(x, y)])).make_function()
for a,b in ((0,1), (0,0), (1,0), (1,1)): for a,b in ((0,1), (0,0), (1,0), (1,1)):
self.failUnless(fn(a,b) == (a & b), (a,b)) self.failUnless(fn(a,b) == (a & b), (a,b))
x, y, z = ints('xyz') x, y, z = ints('xyz')
fn = gof.DualLinker(Env([x,y], [x & y])).make_function() fn = gof.DualLinker().accept(Env([x,y], [x & y])).make_function()
for a,b in ((0,1), (0,0), (1,0), (1,1)): for a,b in ((0,1), (0,0), (1,0), (1,1)):
self.failUnless(fn(a,b) == (a & b), (a,b)) self.failUnless(fn(a,b) == (a & b), (a,b))
def test_not(self): def test_not(self):
x, y, z = ints('xyz') x, y, z = ints('xyz')
fn = gof.DualLinker(Env([x,y], [invert(x)])).make_function() fn = gof.DualLinker().accept(Env([x,y], [invert(x)])).make_function()
for a,b in ((0,1), (0,0), (1,0), (1,1)): for a,b in ((0,1), (0,0), (1,0), (1,1)):
self.failUnless(fn(a,b) == ~a, (a,)) self.failUnless(fn(a,b) == ~a, (a,))
x, y, z = ints('xyz') x, y, z = ints('xyz')
fn = gof.DualLinker(Env([x,y], [~x])).make_function() fn = gof.DualLinker().accept(Env([x,y], [~x])).make_function()
for a,b in ((0,1), (0,0), (1,0), (1,1)): for a,b in ((0,1), (0,0), (1,0), (1,1)):
self.failUnless(fn(a,b) == ~a, (a,)) self.failUnless(fn(a,b) == ~a, (a,))
......
...@@ -56,7 +56,7 @@ def make_tester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_ ...@@ -56,7 +56,7 @@ def make_tester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_
try: try:
f = function(inputrs, node.outputs, f = function(inputrs, node.outputs,
linker = lambda env, **kwargs: gof.DualLinker(env, checker = _numpy_checker, **kwargs), linker = 'c&py', ##lambda env, **kwargs: gof.DualLinker(env, checker = _numpy_checker, **kwargs),
unpack_single = False, unpack_single = False,
optimizer = None) optimizer = None)
except: except:
...@@ -115,7 +115,7 @@ def make_tester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_ ...@@ -115,7 +115,7 @@ def make_tester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_
try: try:
f = function(inputrs, node.outputs, f = function(inputrs, node.outputs,
linker = lambda env, **kwargs: gof.DualLinker(env, checker = _numpy_checker, **kwargs), linker = 'c&py', #lambda env, **kwargs: gof.DualLinker(env, checker = _numpy_checker, **kwargs),
unpack_single = False, unpack_single = False,
optimizer = None) optimizer = None)
except: except:
...@@ -1045,7 +1045,7 @@ class T_add(unittest.TestCase): ...@@ -1045,7 +1045,7 @@ class T_add(unittest.TestCase):
("*", lambda x,y: x*y), ("*", lambda x,y: x*y),
("/", lambda x,y: x/y)) ("/", lambda x,y: x/y))
for s, fn in tests: for s, fn in tests:
f = function([a,b], [fn(a, b)], linker = gof.CLinker) f = function([a,b], [fn(a, b)], linker = 'c')
self.failUnless(numpy.all(fn(a.data, b.data) == f(a.data, b.data))) self.failUnless(numpy.all(fn(a.data, b.data) == f(a.data, b.data)))
def test_grad_scalar_l(self): def test_grad_scalar_l(self):
...@@ -1354,9 +1354,9 @@ class t_gemm(unittest.TestCase): ...@@ -1354,9 +1354,9 @@ class t_gemm(unittest.TestCase):
else: else:
self.failIf(numpy.all(z_orig == z)) self.failIf(numpy.all(z_orig == z))
cmp_linker(copy(z), a, x, y, b, gof.cc.OpWiseCLinker) cmp_linker(copy(z), a, x, y, b, 'c|py')
cmp_linker(copy(z), a, x, y, b, gof.cc.CLinker) cmp_linker(copy(z), a, x, y, b, 'c')
cmp_linker(copy(z), a, x, y, b, gof.link.PerformLinker) cmp_linker(copy(z), a, x, y, b, 'py')
def test0a(self): def test0a(self):
Gemm.debug = True Gemm.debug = True
...@@ -1456,7 +1456,7 @@ class t_gemm(unittest.TestCase): ...@@ -1456,7 +1456,7 @@ class t_gemm(unittest.TestCase):
B = self.rand(4,5)[:,:4] B = self.rand(4,5)[:,:4]
C = self.rand(4,5)[:,:4] C = self.rand(4,5)[:,:4]
def t(z,x,y,a=1.0, b=0.0,l=gof.cc.OpWiseCLinker,dt='float64'): def t(z,x,y,a=1.0, b=0.0,l='c|py',dt='float64'):
z,a,x,y,b = [numpy.asarray(p,dtype=dt) for p in z,a,x,y,b] z,a,x,y,b = [numpy.asarray(p,dtype=dt) for p in z,a,x,y,b]
z_orig = z.copy() z_orig = z.copy()
z_after = self._gemm(z, a, x, y, b) z_after = self._gemm(z, a, x, y, b)
...@@ -1699,7 +1699,7 @@ class _test_grad(unittest.TestCase): ...@@ -1699,7 +1699,7 @@ class _test_grad(unittest.TestCase):
if __name__ == '__main__': if __name__ == '__main__':
if 0: if 1:
unittest.main() unittest.main()
else: else:
suite = unittest.TestLoader() suite = unittest.TestLoader()
......
...@@ -140,7 +140,7 @@ class _test_CLinker(unittest.TestCase): ...@@ -140,7 +140,7 @@ class _test_CLinker(unittest.TestCase):
def test_straightforward(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z)) e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = CLinker(Env([x, y, z], [e])) lnk = CLinker().accept(Env([x, y, z], [e]))
fn = lnk.make_function() fn = lnk.make_function()
self.failUnless(fn(2.0, 2.0, 2.0) == 2.0) self.failUnless(fn(2.0, 2.0, 2.0) == 2.0)
...@@ -158,7 +158,7 @@ class _test_CLinker(unittest.TestCase): ...@@ -158,7 +158,7 @@ class _test_CLinker(unittest.TestCase):
x, y, z = inputs() x, y, z = inputs()
z = Constant(tdouble, 4.12345678) z = Constant(tdouble, 4.12345678)
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z)) e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = CLinker(Env([x, y], [e])) lnk = CLinker().accept(Env([x, y], [e]))
fn = lnk.make_function() fn = lnk.make_function()
self.failUnless(abs(fn(2.0, 2.0) + 0.12345678) < 1e-9) self.failUnless(abs(fn(2.0, 2.0) + 0.12345678) < 1e-9)
self.failUnless("4.12345678" in lnk.code_gen()) # we expect the number to be inlined self.failUnless("4.12345678" in lnk.code_gen()) # we expect the number to be inlined
...@@ -166,7 +166,7 @@ class _test_CLinker(unittest.TestCase): ...@@ -166,7 +166,7 @@ class _test_CLinker(unittest.TestCase):
def test_single_node(self): def test_single_node(self):
x, y, z = inputs() x, y, z = inputs()
node = add.make_node(x, y) node = add.make_node(x, y)
lnk = CLinker(Env(node.inputs, node.outputs)) lnk = CLinker().accept(Env(node.inputs, node.outputs))
fn = lnk.make_function() fn = lnk.make_function()
self.failUnless(fn(2.0, 7.0) == 9) self.failUnless(fn(2.0, 7.0) == 9)
...@@ -174,7 +174,7 @@ class _test_CLinker(unittest.TestCase): ...@@ -174,7 +174,7 @@ class _test_CLinker(unittest.TestCase):
# Testing that duplicate inputs are allowed. # Testing that duplicate inputs are allowed.
x, y, z = inputs() x, y, z = inputs()
e = add(x, x) e = add(x, x)
lnk = CLinker(Env([x, x], [e])) lnk = CLinker().accept(Env([x, x], [e]))
fn = lnk.make_function() fn = lnk.make_function()
self.failUnless(fn(2.0, 2.0) == 4) self.failUnless(fn(2.0, 2.0) == 4)
# note: for now the behavior of fn(2.0, 7.0) is undefined # note: for now the behavior of fn(2.0, 7.0) is undefined
...@@ -183,7 +183,7 @@ class _test_CLinker(unittest.TestCase): ...@@ -183,7 +183,7 @@ class _test_CLinker(unittest.TestCase):
# Testing that duplicates are allowed inside the graph # Testing that duplicates are allowed inside the graph
x, y, z = inputs() x, y, z = inputs()
e = add(mul(y, y), add(x, z)) e = add(mul(y, y), add(x, z))
lnk = CLinker(Env([x, y, z], [e])) lnk = CLinker().accept(Env([x, y, z], [e]))
fn = lnk.make_function() fn = lnk.make_function()
self.failUnless(fn(1.0, 2.0, 3.0) == 8.0) self.failUnless(fn(1.0, 2.0, 3.0) == 8.0)
...@@ -194,7 +194,7 @@ class _test_OpWiseCLinker(unittest.TestCase): ...@@ -194,7 +194,7 @@ class _test_OpWiseCLinker(unittest.TestCase):
def test_straightforward(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z)) e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = OpWiseCLinker(Env([x, y, z], [e])) lnk = OpWiseCLinker().accept(Env([x, y, z], [e]))
fn = lnk.make_function() fn = lnk.make_function()
self.failUnless(fn(2.0, 2.0, 2.0) == 2.0) self.failUnless(fn(2.0, 2.0, 2.0) == 2.0)
...@@ -202,7 +202,7 @@ class _test_OpWiseCLinker(unittest.TestCase): ...@@ -202,7 +202,7 @@ class _test_OpWiseCLinker(unittest.TestCase):
x, y, z = inputs() x, y, z = inputs()
x = Constant(tdouble, 7.2, name = 'x') x = Constant(tdouble, 7.2, name = 'x')
e = add(mul(x, y), mul(y, z)) e = add(mul(x, y), mul(y, z))
lnk = OpWiseCLinker(Env([y, z], [e])) lnk = OpWiseCLinker().accept(Env([y, z], [e]))
fn = lnk.make_function() fn = lnk.make_function()
res = fn(1.5, 3.0) res = fn(1.5, 3.0)
self.failUnless(res == 15.3, res) self.failUnless(res == 15.3, res)
...@@ -220,7 +220,7 @@ class _test_DualLinker(unittest.TestCase): ...@@ -220,7 +220,7 @@ class _test_DualLinker(unittest.TestCase):
def test_straightforward(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
e = add(mul(x, y), mul(y, z)) # add and mul are correct in C and in Python e = add(mul(x, y), mul(y, z)) # add and mul are correct in C and in Python
lnk = DualLinker(Env([x, y, z], [e]), checker = _my_checker) lnk = DualLinker(checker = _my_checker).accept(Env([x, y, z], [e]))
fn = lnk.make_function() fn = lnk.make_function()
res = fn(7.2, 1.5, 3.0) res = fn(7.2, 1.5, 3.0)
self.failUnless(res == 15.3, res) self.failUnless(res == 15.3, res)
...@@ -229,12 +229,12 @@ class _test_DualLinker(unittest.TestCase): ...@@ -229,12 +229,12 @@ class _test_DualLinker(unittest.TestCase):
x, y, z = inputs() x, y, z = inputs()
e = sub(mul(x, y), mul(y, z)) # sub is correct in C but erroneous in Python e = sub(mul(x, y), mul(y, z)) # sub is correct in C but erroneous in Python
g = Env([x, y, z], [e]) g = Env([x, y, z], [e])
lnk = DualLinker(g, checker = _my_checker) lnk = DualLinker(checker = _my_checker).accept(g)
fn = lnk.make_function() fn = lnk.make_function()
self.failUnless(CLinker(g).make_function()(1.0, 2.0, 3.0) == -4.0) # good self.failUnless(CLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -4.0) # good
self.failUnless(OpWiseCLinker(g).make_function()(1.0, 2.0, 3.0) == -4.0) # good self.failUnless(OpWiseCLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -4.0) # good
self.failUnless(PerformLinker(g).make_function()(1.0, 2.0, 3.0) == -10.0) # (purposely) wrong self.failUnless(PerformLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -10.0) # (purposely) wrong
try: try:
# this runs OpWiseCLinker and PerformLinker in parallel and feeds # this runs OpWiseCLinker and PerformLinker in parallel and feeds
......
...@@ -70,7 +70,7 @@ def inputs(): ...@@ -70,7 +70,7 @@ def inputs():
return x, y, z return x, y, z
def perform_linker(env): def perform_linker(env):
lnk = PerformLinker(env) lnk = PerformLinker().accept(env)
return lnk return lnk
def Env(inputs, outputs): def Env(inputs, outputs):
......
...@@ -3,7 +3,7 @@ import unittest ...@@ -3,7 +3,7 @@ import unittest
from type import Type from type import Type
from graph import Result, Apply, Constant from graph import Result, Apply, Constant
from op import Op from op import Op, Macro
from opt import * from opt import *
from env import Env from env import Env
from toolbox import * from toolbox import *
...@@ -415,6 +415,38 @@ class _test_MergeOptimizer(unittest.TestCase): ...@@ -415,6 +415,38 @@ class _test_MergeOptimizer(unittest.TestCase):
class _test_ExpandMacro(unittest.TestCase):
def test_straightforward(self):
class Macro1(Macro):
def make_node(self, x, y):
return Apply(self, [x, y], [MyType()()])
def expand(self, node):
return [op1(y, x)]
x, y, z = inputs()
e = Macro1()(x, y)
g = Env([x, y], [e])
print g
expand_macros.optimize(g)
print g
def test_loopy(self):
class Macro1(Macro):
def make_node(self, x, y):
return Apply(self, [x, y], [MyType()()])
def expand(self, node):
return [Macro1()(y, x)]
x, y, z = inputs()
e = Macro1()(x, y)
g = Env([x, y], [e])
print g
#expand_macros.optimize(g)
TopDownOptimizer(ExpandMacro(), ignore_newtrees = True).optimize(g)
print g
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
...@@ -339,10 +339,16 @@ class CLinker(link.Linker): ...@@ -339,10 +339,16 @@ class CLinker(link.Linker):
associated to it during the computation (to avoid reusing it). associated to it during the computation (to avoid reusing it).
""" """
def __init__(self, env, no_recycling = []): def __init__(self):
self.env = None
def accept(self, env, no_recycling = []):
if self.env is not None and self.env is not env:
raise Exception("Cannot accept from a Linker that is already tied to another Env.")
self.env = env self.env = env
self.fetch_results() self.fetch_results()
self.no_recycling = no_recycling self.no_recycling = no_recycling
return self
def fetch_results(self): def fetch_results(self):
""" """
...@@ -771,10 +777,16 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -771,10 +777,16 @@ class OpWiseCLinker(link.LocalLinker):
associated to it during the computation (to avoid reusing it). associated to it during the computation (to avoid reusing it).
""" """
def __init__(self, env, fallback_on_perform = True, no_recycling = []): def __init__(self, fallback_on_perform = True):
self.env = env self.env = None
self.fallback_on_perform = fallback_on_perform self.fallback_on_perform = fallback_on_perform
def accept(self, env, no_recycling = []):
if self.env is not None and self.env is not env:
raise Exception("Cannot accept from a Linker that is already tied to another Env.")
self.env = env
self.no_recycling = no_recycling self.no_recycling = no_recycling
return self
def make_thunk(self, profiler = None, input_storage = None, output_storage = None): def make_thunk(self, profiler = None, input_storage = None, output_storage = None):
return self.make_all(profiler = profiler, return self.make_all(profiler = profiler,
...@@ -795,7 +807,7 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -795,7 +807,7 @@ class OpWiseCLinker(link.LocalLinker):
try: try:
e = Env(*graph.clone(node.inputs, node.outputs)) e = Env(*graph.clone(node.inputs, node.outputs))
e.toposort = lambda: e.nodes e.toposort = lambda: e.nodes
cl = CLinker(e, [r for r, r2 in zip(e.outputs, node.outputs) if r2 in no_recycling]) cl = CLinker().accept(e, [r for r, r2 in zip(e.outputs, node.outputs) if r2 in no_recycling])
thunk, node_input_filters, node_output_filters = cl.make_thunk( thunk, node_input_filters, node_output_filters = cl.make_thunk(
input_storage = node_input_storage, input_storage = node_input_storage,
output_storage = node_output_storage) output_storage = node_output_storage)
...@@ -848,7 +860,7 @@ class DualLinker(link.Linker): ...@@ -848,7 +860,7 @@ class DualLinker(link.Linker):
function. function.
""" """
def __init__(self, env, checker = _default_checker, no_recycling = []): def __init__(self, checker = _default_checker):
""" """
Initialize a DualLinker. Initialize a DualLinker.
...@@ -871,17 +883,23 @@ class DualLinker(link.Linker): ...@@ -871,17 +883,23 @@ class DualLinker(link.Linker):
If a Result is in no_recycling, CLinker will clear the output storage If a Result is in no_recycling, CLinker will clear the output storage
associated to it during the computation (to avoid reusing it). associated to it during the computation (to avoid reusing it).
""" """
self.env = env self.env = None
self.checker = checker self.checker = checker
def accept(self, env, no_recycling = []):
if self.env is not None and self.env is not env:
raise Exception("Cannot accept from a Linker that is already tied to another Env.")
self.env = env
self.no_recycling = no_recycling self.no_recycling = no_recycling
return self
def make_thunk(self, **kwargs): def make_thunk(self, **kwargs):
env = self.env env = self.env
no_recycling = self.no_recycling no_recycling = self.no_recycling
_f, i1, o1, thunks1, order1 = link.PerformLinker(env, no_recycling = no_recycling).make_all(**kwargs) _f, i1, o1, thunks1, order1 = link.PerformLinker().accept(env, no_recycling = no_recycling).make_all(**kwargs)
_f, i2, o2, thunks2, order2 = OpWiseCLinker(env, no_recycling = no_recycling).make_all(**kwargs) _f, i2, o2, thunks2, order2 = OpWiseCLinker().accept(env, no_recycling = no_recycling).make_all(**kwargs)
def f(): def f():
for input1, input2 in zip(i1, i2): for input1, input2 in zip(i1, i2):
......
...@@ -196,9 +196,15 @@ class PerformLinker(LocalLinker): ...@@ -196,9 +196,15 @@ class PerformLinker(LocalLinker):
the L{Env} in the order given by L{Env.toposort}. the L{Env} in the order given by L{Env.toposort}.
""" """
def __init__(self, env, no_recycling = []): def __init__(self):
self.env = None
def accept(self, env, no_recycling = []):
if self.env is not None and self.env is not env:
raise Exception("Cannot accept from a Linker that is already tied to another Env.")
self.env = env self.env = env
self.no_recycling = no_recycling self.no_recycling = no_recycling
return self
def make_thunk(self, profiler = None, input_storage = None, output_storage = None): def make_thunk(self, profiler = None, input_storage = None, output_storage = None):
return self.make_all(profiler = profiler, return self.make_all(profiler = profiler,
......
...@@ -32,6 +32,8 @@ class object2(object): ...@@ -32,6 +32,8 @@ class object2(object):
class scratchpad: class scratchpad:
def clear(self): def clear(self):
self.__dict__.clear() self.__dict__.clear()
def __update__(self, other):
self.__dict__.update(other.__dict__)
def __str__(self): def __str__(self):
print "scratch" + str(self.__dict__) print "scratch" + str(self.__dict__)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论