提交 e31c24cd authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #4951 from abergeron/part2

Drop support for python 2.6 in CDataType
......@@ -1024,10 +1024,10 @@ def _lessbroken_deepcopy(a):
"""
# this exists because copy.deepcopy on numpy arrays is broken
# This logic is also in link.py
from theano.gof.type import CDataType
from theano.gof.type import _cdata_type
if type(a) in (numpy.ndarray, numpy.memmap):
rval = a.copy()
elif type(a) is CDataType._cdata_type:
elif type(a) is _cdata_type:
# This is not copyable (and should be used for constant data).
rval = a
else:
......
......@@ -65,7 +65,8 @@ def contains_nan(arr, node=None, var=None):
construction of a boolean array with the same shape as the input array.
"""
if isinstance(arr, theano.gof.type.CDataType._cdata_type):
# This should be a whitelist instead of a blacklist
if isinstance(arr, theano.gof.type._cdata_type):
return False
elif isinstance(arr, np.random.mtrand.RandomState):
return False
......@@ -114,7 +115,7 @@ def contains_inf(arr, node=None, var=None):
boolean array with the same shape as the input array.
"""
if isinstance(arr, theano.gof.type.CDataType._cdata_type):
if isinstance(arr, theano.gof.type._cdata_type):
return False
elif isinstance(arr, np.random.mtrand.RandomState):
return False
......@@ -250,7 +251,7 @@ class NanGuardMode(Mode):
error = True
if big_is_error:
err = False
if isinstance(value, theano.gof.type.CDataType._cdata_type):
if isinstance(value, theano.gof.type._cdata_type):
err = False
elif isinstance(value, np.random.mtrand.RandomState):
err = False
......
......@@ -6,9 +6,11 @@ Defines the `Type` class.
"""
from __future__ import absolute_import, print_function, division
from theano.compat import PY3
import ctypes
from six import string_types
import theano
from theano.gof import utils
from theano.gof.utils import MethodNotDefined, object2
from theano.gof import graph
......@@ -16,7 +18,7 @@ from theano.gof import graph
########
# Type #
########
from theano.gof.op import CLinkerObject
from theano.gof.op import CLinkerObject, Op
__docformat__ = "restructuredtext en"
......@@ -589,6 +591,35 @@ class Generic(SingletonType):
generic = Generic()
_cdata_type = ctypes.py_object.from_address(
ctypes.addressof(ctypes.pythonapi.PyCapsule_Type)).value
class _make_cdata(Op):
__props__ = ('rtype',)
def __init__(self, rtype):
assert isinstance(rtype, CDataType)
self.rtype = rtype
def do_constant_folding(self, node):
return False
def make_node(self, val):
from theano.scalar import as_scalar
from theano import Apply
val = as_scalar(val).astype('uint64')
return Apply(self, [val], [self.rtype()])
def c_code(self, node, name, inputs, outputs, sub):
return """
%(out)s = (%(ctype)s)%(inp)s;
""" % dict(ctype=self.rtype.ctype, out=outputs[0], inp=inputs[0])
def c_code_cache_version(self):
return (0,)
class CDataType(Type):
"""
......@@ -603,26 +634,32 @@ class CDataType(Type):
The type of the pointer (complete with the `*`).
freefunc
A function to call to free the pointer. This function must have a `void`
return and take a single pointer argument.
A function to call to free the pointer. This function must
have a `void` return and take a single pointer argument.
"""
import ctypes
if PY3:
_cdata_type = ctypes.py_object.from_address(
ctypes.addressof(ctypes.pythonapi.PyCapsule_Type)).value
else:
_cdata_type = ctypes.py_object.from_address(
ctypes.addressof(ctypes.pythonapi.PyCObject_Type)).value
del ctypes
def __init__(self, ctype, freefunc=None):
def __init__(self, ctype, freefunc=None, headers=None, header_dirs=None,
libraries=None, lib_dirs=None, extra_support_code=""):
assert isinstance(ctype, string_types)
self.ctype = ctype
if freefunc is not None:
assert isinstance(freefunc, string_types)
self.freefunc = freefunc
if headers is None:
headers = []
self.headers = headers
if header_dirs is None:
header_dirs = []
self.header_dirs = header_dirs
if libraries is None:
libraries = []
self.libraries = libraries
if lib_dirs is None:
lib_dirs = []
self.lib_dirs = lib_dirs
self.extra_support_code = extra_support_code
self._fn = None
def __eq__(self, other):
return (type(self) == type(other) and
......@@ -633,10 +670,21 @@ class CDataType(Type):
return hash((type(self), self.ctype, self.freefunc))
def filter(self, data, strict=False, allow_downcast=None):
if data is not None and not isinstance(data, self._cdata_type):
raise TypeError("expected None or PyCObject/PyCapsule")
if data is not None and not isinstance(data, _cdata_type):
raise TypeError("expected None or a PyCapsule")
return data
def _get_func(self):
from theano.scalar import get_scalar_type
if self._fn is None:
v = get_scalar_type('int64')()
self._fn = theano.function([v], _make_cdata(self)(v), profile=False)
return self._fn
def make_value(self, ptr):
return self._get_func()(ptr)
def c_declare(self, name, sub, check_input=True):
return """
%(ctype)s %(name)s;
......@@ -646,29 +694,20 @@ class CDataType(Type):
return "%(name)s = NULL;" % dict(name=name)
def c_extract(self, name, sub, check_input=True):
if PY3:
s = """
return """
%(name)s = (%(ctype)s)PyCapsule_GetPointer(py_%(name)s, NULL);
if (%(name)s == NULL) %(fail)s
"""
else:
s = """
%(name)s = (%(ctype)s)PyCObject_AsVoidPtr(py_%(name)s);
"""
return s % dict(name=name, ctype=self.ctype, fail=sub['fail'])
""" % dict(name=name, ctype=self.ctype, fail=sub['fail'])
def c_support_code(self):
if PY3:
return """
void _py3_destructor(PyObject *o) {
return """
void _capsule_destructor(PyObject *o) {
void *d = PyCapsule_GetContext(o);
void *p = PyCapsule_GetPointer(o, NULL);
void (*f)(void *) = (void (*)(void *))d;
if (f != NULL) f(p);
}
"""
else:
return ""
""" + self.extra_support_code
def c_sync(self, name, sub):
freefunc = self.freefunc
......@@ -679,11 +718,9 @@ Py_XDECREF(py_%(name)s);
if (%(name)s == NULL) {
py_%(name)s = Py_None;
Py_INCREF(py_%(name)s);
} else """
if PY3:
s += """{
} else {
py_%(name)s = PyCapsule_New((void *)%(name)s, NULL,
_py3_destructor);
_capsule_destructor);
if (py_%(name)s != NULL) {
if (PyCapsule_SetContext(py_%(name)s, (void *)%(freefunc)s) != 0) {
/* This won't trigger a call to freefunc since it could not be
......@@ -693,11 +730,6 @@ if (%(name)s == NULL) {
py_%(name)s = NULL;
}
}
}"""
else:
s += """{
py_%(name)s = PyCObject_FromVoidPtr((void *)%(name)s,
(void (*)(void *))%(freefunc)s);
}"""
if self.freefunc is not None:
s += """
......@@ -710,8 +742,20 @@ if (py_%(name)s == NULL) { %(freefunc)s(%(name)s); }
# free the data for us when released.
return ""
def c_headers(self):
return self.headers
def c_header_dirs(self):
return self.header_dirs
def c_libraries(self):
return self.libraries
def c_lib_dirs(self):
return self.lib_dirs
def c_code_cache_version(self):
return (2, self.ctype, self.freefunc)
return (3,)
def __str__(self):
return "%s{%s}" % (self.__class__.__name__, self.ctype)
......@@ -727,4 +771,5 @@ class CDataTypeConstant(graph.Constant):
# There is no way to put the data in the signature, so we
# don't even try
return (self.type,)
CDataType.Constant = CDataTypeConstant
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论