提交 b0324504 authored 作者: Frederic's avatar Frederic

Make an Op class picklable by moving it to the top level of the module.

This is needed to cache the C code: a class nested inside another class cannot be pickled by reference.
上级 37dbae29
...@@ -559,103 +559,104 @@ class Test_check_isfinite(unittest.TestCase): ...@@ -559,103 +559,104 @@ class Test_check_isfinite(unittest.TestCase):
return return
class Test_preallocated_output(unittest.TestCase): class BrokenCImplementationAdd(gof.Op):
def __eq__(self, other):
return type(self) == type(other)
class BrokenCImplementationAdd(gof.Op): def __hash__(self):
def __eq__(self, other): return hash(type(self))
return type(self) == type(other)
def __hash__(self): def make_node(self, a, b):
return hash(type(self)) a = theano.tensor.as_tensor_variable(a)
b = theano.tensor.as_tensor_variable(b)
assert a.type.dtype == 'float32'
assert a.type.dtype == b.type.dtype
assert a.type.ndim == 2
r = gof.Apply(self, [a, b], [a.type()])
return r
def make_node(self, a, b): def perform(self, node, inp, out_):
a = theano.tensor.as_tensor_variable(a) print 'executing python perform'
b = theano.tensor.as_tensor_variable(b) a, b = inp
assert a.type.dtype == 'float32' out, = out_
assert a.type.dtype == b.type.dtype z = a + b
assert a.type.ndim == 2 print 'out[0] was:', out[0]
r = gof.Apply(self, [a, b], [a.type()]) out[0] = z
return r
def perform(self, node, inp, out_):
print 'executing python perform'
a, b = inp
out, = out_
z = a + b
print 'out[0] was:', out[0]
out[0] = z
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (1,)
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
a, b = inp a, b = inp
z, = out z, = out
debug = 0 debug = 0
return """ return """
//printf("executing c_code\\n"); //printf("executing c_code\\n");
if (%(a)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(a) != 2"); %(fail)s;} if (%(a)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(a) != 2"); %(fail)s;}
if (%(b)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2"); %(fail)s;} if (%(b)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2"); %(fail)s;}
if (%(a)s->descr->type_num != PyArray_FLOAT) if (%(a)s->descr->type_num != PyArray_FLOAT)
{PyErr_SetString(PyExc_NotImplementedError, "a dtype not NPY_FLOAT"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "a dtype not NPY_FLOAT"); %(fail)s;}
if (%(b)s->descr->type_num != PyArray_FLOAT) if (%(b)s->descr->type_num != PyArray_FLOAT)
{PyErr_SetString(PyExc_NotImplementedError, "b's dtype not NPY_FLOAT"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "b's dtype not NPY_FLOAT"); %(fail)s;}
if (%(a)s->dimensions[0] != %(a)s->dimensions[1]) if (%(a)s->dimensions[0] != %(a)s->dimensions[1])
{PyErr_SetString(PyExc_NotImplementedError, "a is not square"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "a is not square"); %(fail)s;}
if (%(b)s->dimensions[0] != %(b)s->dimensions[1]) if (%(b)s->dimensions[0] != %(b)s->dimensions[1])
{PyErr_SetString(PyExc_NotImplementedError, "b is not square"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "b is not square"); %(fail)s;}
if (%(a)s->dimensions[0] != %(b)s->dimensions[0]) if (%(a)s->dimensions[0] != %(b)s->dimensions[0])
{PyErr_SetString(PyExc_NotImplementedError, "a and b have different dimensions"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "a and b have different dimensions"); %(fail)s;}
// We do not check for c_contiguous property here // We do not check for c_contiguous property here
if (%(debug)s) if (%(debug)s)
{ {
if (!%(z)s) if (!%(z)s)
printf("%(z)s is not there, %%p \\n", %(z)s); printf("%(z)s is not there, %%p \\n", %(z)s);
else if (%(z)s->dimensions[0] != %(b)s->dimensions[0]) else if (%(z)s->dimensions[0] != %(b)s->dimensions[0])
printf("Dimension 0 mismatch for %(z)s and %(b)s\\n"); printf("Dimension 0 mismatch for %(z)s and %(b)s\\n");
else if (%(z)s->dimensions[1] != %(b)s->dimensions[1]) else if (%(z)s->dimensions[1] != %(b)s->dimensions[1])
printf("Dimension 1 mismatch for %(z)s and %(b)s\\n"); printf("Dimension 1 mismatch for %(z)s and %(b)s\\n");
else else
printf("Reusing %(z)s\\n"); printf("Reusing %(z)s\\n");
} }
if ((!%(z)s) if ((!%(z)s)
|| (%(z)s->dimensions[0] != %(b)s->dimensions[0]) || (%(z)s->dimensions[0] != %(b)s->dimensions[0])
|| (%(z)s->dimensions[1] != %(b)s->dimensions[1]) || (%(z)s->dimensions[1] != %(b)s->dimensions[1])
) )
{ {
Py_XDECREF(%(z)s); Py_XDECREF(%(z)s);
npy_intp dims[] = {0, 0}; npy_intp dims[] = {0, 0};
dims[0] = %(b)s->dimensions[0]; dims[0] = %(b)s->dimensions[0];
dims[1] = %(b)s->dimensions[1]; dims[1] = %(b)s->dimensions[1];
%(z)s = (PyArrayObject*) PyArray_SimpleNew(2, dims, %(b)s->descr->type_num); %(z)s = (PyArrayObject*) PyArray_SimpleNew(2, dims, %(b)s->descr->type_num);
} }
// Let us assume that %(z)s is c_contiguous // Let us assume that %(z)s is c_contiguous
{
dtype_%(z)s * z = ((dtype_%(z)s*)(PyArray_GETPTR2(%(z)s,0,0)));
for (int i=0; i<%(b)s->dimensions[0]; i++)
{ {
dtype_%(z)s * z = ((dtype_%(z)s*)(PyArray_GETPTR2(%(z)s,0,0))); for (int j=0; j<%(b)s->dimensions[1]; j++)
for (int i=0; i<%(b)s->dimensions[0]; i++)
{ {
for (int j=0; j<%(b)s->dimensions[1]; j++) *z = ((float*)PyArray_GETPTR2(%(a)s, i, j))[0] +
{ ((float*)PyArray_GETPTR2(%(b)s, i, j))[0] ;
*z = ((float*)PyArray_GETPTR2(%(a)s, i, j))[0] + z++;
((float*)PyArray_GETPTR2(%(b)s, i, j))[0] ;
z++;
}
} }
} }
""" % dict(locals(), **sub) }
""" % dict(locals(), **sub)
class Test_preallocated_output(unittest.TestCase):
def test_f_contiguous(self): def test_f_contiguous(self):
a = theano.tensor.fmatrix('a') a = theano.tensor.fmatrix('a')
b = theano.tensor.fmatrix('b') b = theano.tensor.fmatrix('b')
z = self.BrokenCImplementationAdd()(a, b) z = BrokenCImplementationAdd()(a, b)
# Needed so that z is not the output of the graph # Needed so that z is not the output of the graph
out = theano.tensor.dot(z, numpy.eye(7)) out = theano.tensor.dot(z, numpy.eye(7))
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论