提交 1a1dbe60 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add CDataType to return a C pointer.

This handles Capsules, Cobjects and python 3 compat.
上级 c3fd1cb8
......@@ -2,6 +2,8 @@
__docformat__ = "restructuredtext en"
from theano.compat import PY3
from theano.gof import utils
from theano.gof.utils import MethodNotDefined, object2
from theano.gof import graph
......@@ -250,7 +252,7 @@ class PureType(object):
# If filter_inplace is defined, it will be called instead of
# filter() This is to allow reusing the old allocated memory. As
# of this writing this is used only when we transfer new data to a
# shared variable on the gpu.
# shared variable on the gpu.
#def filter_inplace(value, storage, strict=False, allow_downcast=None)
......@@ -470,3 +472,85 @@ class Generic(SingletonType):
return self.__class__.__name__
generic = Generic()
class CDataType(Type):
"""
Represents opaque C data to be passed around.
"""
def __init__(self, ctype, freefunc=None):
self.ctype = ctype
self.freefunc = freefunc
def __eq__(self, other):
return (type(self) == type(other) and
self.ctype == other.ctype,
self.freefunc == other.freefunc)
def __hash__(self):
return hash((type(self), self.ctype, self.freefunc))
def filter(self, data, strict=False, allow_downcast=None):
if data is not None:
raise TypeError("only None is valid")
def is_valid_value(self, a):
return a is None
def c_declare(self, name, sub, check_input=True):
return """
%(ctype)s %(name)s;
""" % dict(ctype=self.ctype, name=name)
def c_init(self, name, sub):
return "%(name)s = NULL;" % dict(name=name)
def c_extract(self, name, sub, check_input=True):
if PY3:
s = """
%(name)s = (%(ctype)s)PyCapsule_GetPointer(py_%(name)s, NULL);
if (%(name)s == NULL) %(fail)s
"""
else:
s = """
%(name)s = (%(ctype)s)PyCObject_AsVoidPtr(py_%(name)s);
"""
return s % dict(name=name, ctype=self.ctype, fail=sub['fail'])
def c_cleanup(self, name, sub):
if self.freefunc is not None:
return "%(freefunc)s(%(name)s);" % dict(freefunc=self.freefunc,
name=name)
else:
return ""
def c_sync(self, name, sub):
freefunc = self.freefunc
if freefunc is None:
freefunc = "NULL"
s = """
Py_XDECREF(py_%(name)s);
if (%(name)s == NULL) {
py_%(name)s = Py_None;
Py_INCREF(py_%(name)s);
} else """
if PY3:
s += """{
py_%(name)s = PyCapsule_New((void *)%(name)s, NULL,
(void (*)(void *))%(freefunc)s);
}"""
else:
s += """{
py_%(name)s = PyCObject_FromVoidPtr((void *)%(name)s,
(void (*)(void *))%(freefunc)s);
}"""
s += """
if (py_%(name)s != NULL) { %(name)s = NULL; }
"""
return s % dict(name=name, freefunc=freefunc)
def c_code_cache_version(self):
return (0,)
def __str__(self):
return self.__class__.__name__
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论