提交 d596f2f0 authored 作者: lamblin's avatar lamblin

Merge pull request #593 from nouiz/c_code_cache_version

C code cache version
...@@ -63,6 +63,15 @@ New Features ...@@ -63,6 +63,15 @@ New Features
can't be cached as there is no c_code_cache_version() function to at can't be cached as there is no c_code_cache_version() function to at
least one of those Ops. (Frederic B.) least one of those Ops. (Frederic B.)
* CPU alloc now always generates C code (Pascal L.) * CPU alloc now always generates C code (Pascal L.)
* New Theano flag cmodule.warn_no_version=False. When True, warn when an op
with C code is not versioned. An unversioned op is recompiled every time.
(Frédéric B.)
* Versioned the C code of a few Ops to reduce compilation time.
(Frédéric B, Pascal L.)
* C code now reuses preallocated outputs (only done by Scan) (Pascal L.)
* Garbage-collect intermediate results during a Theano function call for ops
with C code (Pascal L.)
Sparse Sparse
* Implement theano.sparse.mul(sparse1, sparse2) when both inputs don't * Implement theano.sparse.mul(sparse1, sparse2) when both inputs don't
......
...@@ -47,6 +47,9 @@ class BROKEN_ON_PURPOSE_Add(gof.Op): ...@@ -47,6 +47,9 @@ class BROKEN_ON_PURPOSE_Add(gof.Op):
else: else:
out[0] = z out[0] = z
def c_code_cache_version(self):
return (1,)
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
a, b = inp a, b = inp
z, = out z, = out
...@@ -130,6 +133,9 @@ class WeirdBrokenOp(gof.Op): ...@@ -130,6 +133,9 @@ class WeirdBrokenOp(gof.Op):
else: else:
raise ValueError(self.behaviour) raise ValueError(self.behaviour)
def c_code_cache_version(self):
return (1,)
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
a, = inp a, = inp
z, = out z, = out
...@@ -553,9 +559,7 @@ class Test_check_isfinite(unittest.TestCase): ...@@ -553,9 +559,7 @@ class Test_check_isfinite(unittest.TestCase):
return return
class Test_preallocated_output(unittest.TestCase): class BrokenCImplementationAdd(gof.Op):
class BrokenCImplementationAdd(gof.Op):
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) == type(other)
...@@ -579,6 +583,9 @@ class Test_preallocated_output(unittest.TestCase): ...@@ -579,6 +583,9 @@ class Test_preallocated_output(unittest.TestCase):
print 'out[0] was:', out[0] print 'out[0] was:', out[0]
out[0] = z out[0] = z
def c_code_cache_version(self):
return (1,)
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
a, b = inp a, b = inp
z, = out z, = out
...@@ -643,10 +650,13 @@ class Test_preallocated_output(unittest.TestCase): ...@@ -643,10 +650,13 @@ class Test_preallocated_output(unittest.TestCase):
} }
""" % dict(locals(), **sub) """ % dict(locals(), **sub)
class Test_preallocated_output(unittest.TestCase):
def test_f_contiguous(self): def test_f_contiguous(self):
a = theano.tensor.fmatrix('a') a = theano.tensor.fmatrix('a')
b = theano.tensor.fmatrix('b') b = theano.tensor.fmatrix('b')
z = self.BrokenCImplementationAdd()(a, b) z = BrokenCImplementationAdd()(a, b)
# Needed so that z is not the output of the graph # Needed so that z is not the output of the graph
out = theano.tensor.dot(z, numpy.eye(7)) out = theano.tensor.dot(z, numpy.eye(7))
......
...@@ -9,10 +9,12 @@ from theano.gof.op import Op ...@@ -9,10 +9,12 @@ from theano.gof.op import Op
from theano.gof import env from theano.gof import env
from theano.gof import toolbox from theano.gof import toolbox
def as_variable(x): def as_variable(x):
assert isinstance(x, Variable) assert isinstance(x, Variable)
return x return x
class TDouble(Type): class TDouble(Type):
def filter(self, data): def filter(self, data):
return float(data) return float(data)
...@@ -58,12 +60,19 @@ class TDouble(Type): ...@@ -58,12 +60,19 @@ class TDouble(Type):
""" % locals() """ % locals()
def c_code_cache_version(self): def c_code_cache_version(self):
return () return (1,)
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
tdouble = TDouble() tdouble = TDouble()
def double(name): def double(name):
return Variable(tdouble, None, None, name = name) return Variable(tdouble, None, None, name=name)
class MyOp(Op): class MyOp(Op):
...@@ -84,16 +93,26 @@ class MyOp(Op): ...@@ -84,16 +93,26 @@ class MyOp(Op):
def __str__(self): def __str__(self):
return self.name return self.name
def __eq__(self, other):
return (type(self) == type(other) and
self.name == other.name and
self.nin == other.nin)
def __hash__(self):
return hash(type(self)) ^ hash(self.name) ^ hash(self.nin)
def perform(self, node, inputs, out_): def perform(self, node, inputs, out_):
out, = out_ out, = out_
out[0] = self.impl(*inputs) out[0] = self.impl(*inputs)
def c_code_cache_version(self): def c_code_cache_version(self):
return () return ()
class Unary(MyOp): #class Unary(MyOp):
def __init__(self): # def __init__(self):
MyOp.__init__(self, 1, self.__class__.__name__) # MyOp.__init__(self, 1, self.__class__.__name__)
class Binary(MyOp): class Binary(MyOp):
def __init__(self): def __init__(self):
...@@ -105,37 +124,45 @@ class Add(Binary): ...@@ -105,37 +124,45 @@ class Add(Binary):
x, y = inp x, y = inp
z, = out z, = out
return "%(z)s = %(x)s + %(y)s;" % locals() return "%(z)s = %(x)s + %(y)s;" % locals()
def impl(self, x, y): def impl(self, x, y):
return x + y return x + y
add = Add() add = Add()
class Sub(Binary): class Sub(Binary):
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
x, y = inp x, y = inp
z, = out z, = out
return "%(z)s = %(x)s - %(y)s;" % locals() return "%(z)s = %(x)s - %(y)s;" % locals()
def impl(self, x, y): def impl(self, x, y):
return -10 # erroneous (most of the time) return -10 # erroneous (most of the time)
sub = Sub() sub = Sub()
class Mul(Binary): class Mul(Binary):
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
x, y = inp x, y = inp
z, = out z, = out
return "%(z)s = %(x)s * %(y)s;" % locals() return "%(z)s = %(x)s * %(y)s;" % locals()
def impl(self, x, y): def impl(self, x, y):
return x * y return x * y
mul = Mul() mul = Mul()
class Div(Binary): class Div(Binary):
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
x, y = inp x, y = inp
z, = out z, = out
return "%(z)s = %(x)s / %(y)s;" % locals() return "%(z)s = %(x)s / %(y)s;" % locals()
def impl(self, x, y): def impl(self, x, y):
return x / y return x / y
div = Div() div = Div()
def inputs(): def inputs():
x = double('x') x = double('x')
y = double('y') y = double('y')
...@@ -159,6 +186,7 @@ def test_clinker_straightforward(): ...@@ -159,6 +186,7 @@ def test_clinker_straightforward():
fn = lnk.make_function() fn = lnk.make_function()
assert fn(2.0, 2.0, 2.0) == 2.0 assert fn(2.0, 2.0, 2.0) == 2.0
def test_clinker_literal_inlining(): def test_clinker_literal_inlining():
x, y, z = inputs() x, y, z = inputs()
z = Constant(tdouble, 4.12345678) z = Constant(tdouble, 4.12345678)
...@@ -171,6 +199,7 @@ def test_clinker_literal_inlining(): ...@@ -171,6 +199,7 @@ def test_clinker_literal_inlining():
print code print code
assert "4.12345678" in code # we expect the number to be inlined assert "4.12345678" in code # we expect the number to be inlined
def test_clinker_single_node(): def test_clinker_single_node():
x, y, z = inputs() x, y, z = inputs()
node = add.make_node(x, y) node = add.make_node(x, y)
...@@ -178,6 +207,7 @@ def test_clinker_single_node(): ...@@ -178,6 +207,7 @@ def test_clinker_single_node():
fn = lnk.make_function() fn = lnk.make_function()
assert fn(2.0, 7.0) == 9 assert fn(2.0, 7.0) == 9
def test_clinker_dups(): def test_clinker_dups():
# Testing that duplicate inputs are allowed. # Testing that duplicate inputs are allowed.
x, y, z = inputs() x, y, z = inputs()
...@@ -187,6 +217,7 @@ def test_clinker_dups(): ...@@ -187,6 +217,7 @@ def test_clinker_dups():
assert fn(2.0, 2.0) == 4 assert fn(2.0, 2.0) == 4
# note: for now the behavior of fn(2.0, 7.0) is undefined # note: for now the behavior of fn(2.0, 7.0) is undefined
def test_clinker_dups_inner(): def test_clinker_dups_inner():
# Testing that duplicates are allowed inside the graph # Testing that duplicates are allowed inside the graph
x, y, z = inputs() x, y, z = inputs()
...@@ -196,7 +227,6 @@ def test_clinker_dups_inner(): ...@@ -196,7 +227,6 @@ def test_clinker_dups_inner():
assert fn(1.0, 2.0, 3.0) == 8.0 assert fn(1.0, 2.0, 3.0) == 8.0
###################### ######################
# Test OpWiseCLinker # # Test OpWiseCLinker #
###################### ######################
...@@ -208,9 +238,10 @@ def test_opwiseclinker_straightforward(): ...@@ -208,9 +238,10 @@ def test_opwiseclinker_straightforward():
fn = lnk.make_function() fn = lnk.make_function()
assert fn(2.0, 2.0, 2.0) == 2.0 assert fn(2.0, 2.0, 2.0) == 2.0
def test_opwiseclinker_constant(): def test_opwiseclinker_constant():
x, y, z = inputs() x, y, z = inputs()
x = Constant(tdouble, 7.2, name = 'x') x = Constant(tdouble, 7.2, name='x')
e = add(mul(x, y), mul(y, z)) e = add(mul(x, y), mul(y, z))
lnk = OpWiseCLinker().accept(Env([y, z], [e])) lnk = OpWiseCLinker().accept(Env([y, z], [e]))
fn = lnk.make_function() fn = lnk.make_function()
...@@ -218,13 +249,14 @@ def test_opwiseclinker_constant(): ...@@ -218,13 +249,14 @@ def test_opwiseclinker_constant():
assert res == 15.3 assert res == 15.3
class MyExc(Exception): class MyExc(Exception):
pass pass
def _my_checker(x, y): def _my_checker(x, y):
if x[0] != y[0]: if x[0] != y[0]:
raise MyExc("Output mismatch.", {'performlinker': x[0], 'clinker': y[0]}) raise MyExc("Output mismatch.",
{'performlinker': x[0], 'clinker': y[0]})
################### ###################
...@@ -234,21 +266,26 @@ def _my_checker(x, y): ...@@ -234,21 +266,26 @@ def _my_checker(x, y):
def test_duallinker_straightforward(): def test_duallinker_straightforward():
x, y, z = inputs() x, y, z = inputs()
e = add(mul(x, y), mul(y, z)) # add and mul are correct in C and in Python e = add(mul(x, y), mul(y, z)) # add and mul are correct in C and in Python
lnk = DualLinker(checker = _my_checker).accept(Env([x, y, z], [e])) lnk = DualLinker(checker=_my_checker).accept(Env([x, y, z], [e]))
fn = lnk.make_function() fn = lnk.make_function()
res = fn(7.2, 1.5, 3.0) res = fn(7.2, 1.5, 3.0)
assert res == 15.3 assert res == 15.3
def test_duallinker_mismatch(): def test_duallinker_mismatch():
x, y, z = inputs() x, y, z = inputs()
e = sub(mul(x, y), mul(y, z)) # sub is correct in C but erroneous in Python # sub is correct in C but erroneous in Python
e = sub(mul(x, y), mul(y, z))
g = Env([x, y, z], [e]) g = Env([x, y, z], [e])
lnk = DualLinker(checker = _my_checker).accept(g) lnk = DualLinker(checker=_my_checker).accept(g)
fn = lnk.make_function() fn = lnk.make_function()
assert CLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -4.0 # good # good
assert OpWiseCLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -4.0 # good assert CLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -4.0
assert PerformLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -10.0 # (purposely) wrong # good
assert OpWiseCLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -4.0
# (purposely) wrong
assert PerformLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -10.0
try: try:
# this runs OpWiseCLinker and PerformLinker in parallel and feeds # this runs OpWiseCLinker and PerformLinker in parallel and feeds
...@@ -268,17 +305,19 @@ class AddFail(Binary): ...@@ -268,17 +305,19 @@ class AddFail(Binary):
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
x, y = inp x, y = inp
z, = out z, = out
fail=sub['fail'] fail = sub['fail']
return """%(z)s = %(x)s + %(y)s; return """%(z)s = %(x)s + %(y)s;
PyErr_SetString(PyExc_RuntimeError, "failing here"); PyErr_SetString(PyExc_RuntimeError, "failing here");
%(fail)s;""" % locals() %(fail)s;""" % locals()
def impl(self, x, y): def impl(self, x, y):
return x + y return x + y
add_fail = AddFail() add_fail = AddFail()
def test_fail_error(): def test_fail_error():
x, y, z = inputs() x, y, z = inputs()
x = Constant(tdouble, 7.2, name = 'x') x = Constant(tdouble, 7.2, name='x')
e = add_fail(mul(x, y), mul(y, z)) e = add_fail(mul(x, y), mul(y, z))
lnk = OpWiseCLinker().accept(Env([y, z], [e])) lnk = OpWiseCLinker().accept(Env([y, z], [e]))
fn = lnk.make_function() fn = lnk.make_function()
...@@ -286,7 +325,5 @@ def test_fail_error(): ...@@ -286,7 +325,5 @@ def test_fail_error():
res = fn(1.5, 3.0) res = fn(1.5, 3.0)
except RuntimeError: except RuntimeError:
print 'Yay, TEST PASSED' print 'Yay, TEST PASSED'
return #test passed return # test passed
assert 0 #test failed assert 0 # test failed
...@@ -348,6 +348,9 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -348,6 +348,9 @@ class TestComputeTestValue(unittest.TestCase):
output = input.type() output = input.type()
return Apply(self, [input], [output]) return Apply(self, [input], [output])
def c_code_cache_version(self):
return (1,)
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
x, = inputs x, = inputs
z, = outputs z, = outputs
......
...@@ -449,6 +449,9 @@ class Generic(SingletonType): ...@@ -449,6 +449,9 @@ class Generic(SingletonType):
Py_INCREF(py_%(name)s); Py_INCREF(py_%(name)s);
""" % locals() """ % locals()
def c_code_cache_version(self):
return (1,)
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论