提交 b90abd33 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add a test for CDataType and fix a bug that was calling freefunc too often.

上级 4f1edd17
import numpy
import theano
from theano.gof.type import *
from theano import Op, Apply
from theano.tensor import TensorType
from theano.gof.type import CDataType
# todo: test generic
class ProdOp(Op):
__props__ = ()
def make_node(self, i):
return Apply(self, [i], [CDataType('void *', 'py_decref')()])
def c_support_code(self):
return """
void py_decref(void *p) {
Py_XDECREF((PyObject *)p);
}
"""
def c_code(self, node, name, inps, outs, sub):
return """
Py_XDECREF(%(out)s);
%(out)s = (void *)%(inp)s;
Py_INCREF(%(inp)s);
""" % dict(out=outs[0], inp=inps[0])
def c_code_cache_version(self):
return (0,)
class GetOp(Op):
__props__ = ()
def make_node(self, c):
return Apply(self, [c], [TensorType('float32', (False,))()])
def c_support_code(self):
return """
void py_decref(void *p) {
Py_XDECREF((PyObject *)p);
}
"""
def c_code(self, node, name, inps, outs, sub):
return """
Py_XDECREF(%(out)s);
%(out)s = (PyArrayObject *)%(inp)s;
Py_INCREF(%(out)s);
""" % dict(out=outs[0], inp=inps[0])
def c_code_cache_version(self):
return (0,)
def test_cdata():
i = TensorType('float32', (False,))()
c = ProdOp()(i)
i2 = GetOp()(c)
# This should be a passthrough function for vectors
f = theano.function([i], i2)
v = numpy.random.randn(9).astype('float32')
v2 = f(v)
assert (v2 == v).all()
......@@ -530,13 +530,6 @@ class CDataType(Type):
"""
return s % dict(name=name, ctype=self.ctype, fail=sub['fail'])
def c_cleanup(self, name, sub):
if self.freefunc is not None:
return "%(freefunc)s(%(name)s);" % dict(freefunc=self.freefunc,
name=name)
else:
return ""
def c_sync(self, name, sub):
freefunc = self.freefunc
if freefunc is None:
......@@ -554,16 +547,20 @@ if (%(name)s == NULL) {
}"""
else:
s += """{
py_%(name)s = PyCObject_FromVoidPtr((void *)%(name)s,
(void (*)(void *))%(freefunc)s);
py_%(name)s = PyCObject_FromVoidPtr((void *)%(name)s,
(void (*)(void *))%(freefunc)s);
}"""
s += """
if (py_%(name)s != NULL) { %(name)s = NULL; }
if self.freefunc is not None:
s += """
if (py_%(name)s == NULL) { %(freefunc)s(%(name)s); }
"""
return s % dict(name=name, freefunc=freefunc)
def c_cleanup(self, name, sub):
return ""
def c_code_cache_version(self):
return (0,)
return (1,)
def __str__(self):
return "%s{%s}" % (self.__class__.__name__, self.ctype)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论