提交 1a1dbe60 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add CDataType to return a C pointer.

This handles Capsules, Cobjects and python 3 compat.
上级 c3fd1cb8
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
__docformat__ = "restructuredtext en" __docformat__ = "restructuredtext en"
from theano.compat import PY3
from theano.gof import utils from theano.gof import utils
from theano.gof.utils import MethodNotDefined, object2 from theano.gof.utils import MethodNotDefined, object2
from theano.gof import graph from theano.gof import graph
...@@ -470,3 +472,85 @@ class Generic(SingletonType): ...@@ -470,3 +472,85 @@ class Generic(SingletonType):
return self.__class__.__name__ return self.__class__.__name__
generic = Generic() generic = Generic()
class CDataType(Type):
"""
Represents opaque C data to be passed around.
"""
def __init__(self, ctype, freefunc=None):
self.ctype = ctype
self.freefunc = freefunc
def __eq__(self, other):
return (type(self) == type(other) and
self.ctype == other.ctype,
self.freefunc == other.freefunc)
def __hash__(self):
return hash((type(self), self.ctype, self.freefunc))
def filter(self, data, strict=False, allow_downcast=None):
if data is not None:
raise TypeError("only None is valid")
def is_valid_value(self, a):
return a is None
def c_declare(self, name, sub, check_input=True):
return """
%(ctype)s %(name)s;
""" % dict(ctype=self.ctype, name=name)
def c_init(self, name, sub):
return "%(name)s = NULL;" % dict(name=name)
def c_extract(self, name, sub, check_input=True):
if PY3:
s = """
%(name)s = (%(ctype)s)PyCapsule_GetPointer(py_%(name)s, NULL);
if (%(name)s == NULL) %(fail)s
"""
else:
s = """
%(name)s = (%(ctype)s)PyCObject_AsVoidPtr(py_%(name)s);
"""
return s % dict(name=name, ctype=self.ctype, fail=sub['fail'])
def c_cleanup(self, name, sub):
if self.freefunc is not None:
return "%(freefunc)s(%(name)s);" % dict(freefunc=self.freefunc,
name=name)
else:
return ""
def c_sync(self, name, sub):
freefunc = self.freefunc
if freefunc is None:
freefunc = "NULL"
s = """
Py_XDECREF(py_%(name)s);
if (%(name)s == NULL) {
py_%(name)s = Py_None;
Py_INCREF(py_%(name)s);
} else """
if PY3:
s += """{
py_%(name)s = PyCapsule_New((void *)%(name)s, NULL,
(void (*)(void *))%(freefunc)s);
}"""
else:
s += """{
py_%(name)s = PyCObject_FromVoidPtr((void *)%(name)s,
(void (*)(void *))%(freefunc)s);
}"""
s += """
if (py_%(name)s != NULL) { %(name)s = NULL; }
"""
return s % dict(name=name, freefunc=freefunc)
def c_code_cache_version(self):
return (0,)
def __str__(self):
return self.__class__.__name__
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论