提交 b0324504 authored 作者: Frederic's avatar Frederic

Make a op class pickable by being at the top level of the module.

This is needed to cache the c code.
上级 37dbae29
......@@ -559,103 +559,104 @@ class Test_check_isfinite(unittest.TestCase):
return
class Test_preallocated_output(unittest.TestCase):
class BrokenCImplementationAdd(gof.Op):
def __eq__(self, other):
return type(self) == type(other)
class BrokenCImplementationAdd(gof.Op):
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def __hash__(self):
return hash(type(self))
def make_node(self, a, b):
a = theano.tensor.as_tensor_variable(a)
b = theano.tensor.as_tensor_variable(b)
assert a.type.dtype == 'float32'
assert a.type.dtype == b.type.dtype
assert a.type.ndim == 2
r = gof.Apply(self, [a, b], [a.type()])
return r
def make_node(self, a, b):
a = theano.tensor.as_tensor_variable(a)
b = theano.tensor.as_tensor_variable(b)
assert a.type.dtype == 'float32'
assert a.type.dtype == b.type.dtype
assert a.type.ndim == 2
r = gof.Apply(self, [a, b], [a.type()])
return r
def perform(self, node, inp, out_):
print 'executing python perform'
a, b = inp
out, = out_
z = a + b
print 'out[0] was:', out[0]
out[0] = z
def perform(self, node, inp, out_):
print 'executing python perform'
a, b = inp
out, = out_
z = a + b
print 'out[0] was:', out[0]
out[0] = z
def c_code_cache_version(self):
return (1,)
def c_code_cache_version(self):
return (1,)
def c_code(self, node, name, inp, out, sub):
a, b = inp
z, = out
debug = 0
return """
//printf("executing c_code\\n");
if (%(a)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(a) != 2"); %(fail)s;}
if (%(b)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2"); %(fail)s;}
def c_code(self, node, name, inp, out, sub):
a, b = inp
z, = out
debug = 0
return """
//printf("executing c_code\\n");
if (%(a)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(a) != 2"); %(fail)s;}
if (%(b)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2"); %(fail)s;}
if (%(a)s->descr->type_num != PyArray_FLOAT)
{PyErr_SetString(PyExc_NotImplementedError, "a dtype not NPY_FLOAT"); %(fail)s;}
if (%(a)s->descr->type_num != PyArray_FLOAT)
{PyErr_SetString(PyExc_NotImplementedError, "a dtype not NPY_FLOAT"); %(fail)s;}
if (%(b)s->descr->type_num != PyArray_FLOAT)
{PyErr_SetString(PyExc_NotImplementedError, "b's dtype not NPY_FLOAT"); %(fail)s;}
if (%(b)s->descr->type_num != PyArray_FLOAT)
{PyErr_SetString(PyExc_NotImplementedError, "b's dtype not NPY_FLOAT"); %(fail)s;}
if (%(a)s->dimensions[0] != %(a)s->dimensions[1])
{PyErr_SetString(PyExc_NotImplementedError, "a is not square"); %(fail)s;}
if (%(a)s->dimensions[0] != %(a)s->dimensions[1])
{PyErr_SetString(PyExc_NotImplementedError, "a is not square"); %(fail)s;}
if (%(b)s->dimensions[0] != %(b)s->dimensions[1])
{PyErr_SetString(PyExc_NotImplementedError, "b is not square"); %(fail)s;}
if (%(b)s->dimensions[0] != %(b)s->dimensions[1])
{PyErr_SetString(PyExc_NotImplementedError, "b is not square"); %(fail)s;}
if (%(a)s->dimensions[0] != %(b)s->dimensions[0])
{PyErr_SetString(PyExc_NotImplementedError, "a and b have different dimensions"); %(fail)s;}
if (%(a)s->dimensions[0] != %(b)s->dimensions[0])
{PyErr_SetString(PyExc_NotImplementedError, "a and b have different dimensions"); %(fail)s;}
// We do not check for c_contiguous property here
if (%(debug)s)
{
if (!%(z)s)
printf("%(z)s is not there, %%p \\n", %(z)s);
else if (%(z)s->dimensions[0] != %(b)s->dimensions[0])
printf("Dimension 0 mismatch for %(z)s and %(b)s\\n");
else if (%(z)s->dimensions[1] != %(b)s->dimensions[1])
printf("Dimension 1 mismatch for %(z)s and %(b)s\\n");
else
printf("Reusing %(z)s\\n");
}
// We do not check for c_contiguous property here
if (%(debug)s)
{
if (!%(z)s)
printf("%(z)s is not there, %%p \\n", %(z)s);
else if (%(z)s->dimensions[0] != %(b)s->dimensions[0])
printf("Dimension 0 mismatch for %(z)s and %(b)s\\n");
else if (%(z)s->dimensions[1] != %(b)s->dimensions[1])
printf("Dimension 1 mismatch for %(z)s and %(b)s\\n");
else
printf("Reusing %(z)s\\n");
}
if ((!%(z)s)
|| (%(z)s->dimensions[0] != %(b)s->dimensions[0])
|| (%(z)s->dimensions[1] != %(b)s->dimensions[1])
)
{
Py_XDECREF(%(z)s);
npy_intp dims[] = {0, 0};
dims[0] = %(b)s->dimensions[0];
dims[1] = %(b)s->dimensions[1];
%(z)s = (PyArrayObject*) PyArray_SimpleNew(2, dims, %(b)s->descr->type_num);
}
if ((!%(z)s)
|| (%(z)s->dimensions[0] != %(b)s->dimensions[0])
|| (%(z)s->dimensions[1] != %(b)s->dimensions[1])
)
{
Py_XDECREF(%(z)s);
npy_intp dims[] = {0, 0};
dims[0] = %(b)s->dimensions[0];
dims[1] = %(b)s->dimensions[1];
%(z)s = (PyArrayObject*) PyArray_SimpleNew(2, dims, %(b)s->descr->type_num);
}
// Let us assume that %(z)s is c_contiguous
// Let us assume that %(z)s is c_contiguous
{
dtype_%(z)s * z = ((dtype_%(z)s*)(PyArray_GETPTR2(%(z)s,0,0)));
for (int i=0; i<%(b)s->dimensions[0]; i++)
{
dtype_%(z)s * z = ((dtype_%(z)s*)(PyArray_GETPTR2(%(z)s,0,0)));
for (int i=0; i<%(b)s->dimensions[0]; i++)
for (int j=0; j<%(b)s->dimensions[1]; j++)
{
for (int j=0; j<%(b)s->dimensions[1]; j++)
{
*z = ((float*)PyArray_GETPTR2(%(a)s, i, j))[0] +
((float*)PyArray_GETPTR2(%(b)s, i, j))[0] ;
z++;
}
*z = ((float*)PyArray_GETPTR2(%(a)s, i, j))[0] +
((float*)PyArray_GETPTR2(%(b)s, i, j))[0] ;
z++;
}
}
""" % dict(locals(), **sub)
}
""" % dict(locals(), **sub)
class Test_preallocated_output(unittest.TestCase):
def test_f_contiguous(self):
a = theano.tensor.fmatrix('a')
b = theano.tensor.fmatrix('b')
z = self.BrokenCImplementationAdd()(a, b)
z = BrokenCImplementationAdd()(a, b)
# Needed so that z is not the output of the graph
out = theano.tensor.dot(z, numpy.eye(7))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论