提交 e31c24cd authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #4951 from abergeron/part2

Drop support for python 2.6 in CDataType
...@@ -1024,10 +1024,10 @@ def _lessbroken_deepcopy(a): ...@@ -1024,10 +1024,10 @@ def _lessbroken_deepcopy(a):
""" """
# this exists because copy.deepcopy on numpy arrays is broken # this exists because copy.deepcopy on numpy arrays is broken
# This logic is also in link.py # This logic is also in link.py
from theano.gof.type import CDataType from theano.gof.type import _cdata_type
if type(a) in (numpy.ndarray, numpy.memmap): if type(a) in (numpy.ndarray, numpy.memmap):
rval = a.copy() rval = a.copy()
elif type(a) is CDataType._cdata_type: elif type(a) is _cdata_type:
# This is not copyable (and should be used for constant data). # This is not copyable (and should be used for constant data).
rval = a rval = a
else: else:
......
...@@ -65,7 +65,8 @@ def contains_nan(arr, node=None, var=None): ...@@ -65,7 +65,8 @@ def contains_nan(arr, node=None, var=None):
construction of a boolean array with the same shape as the input array. construction of a boolean array with the same shape as the input array.
""" """
if isinstance(arr, theano.gof.type.CDataType._cdata_type): # This should be a whitelist instead of a blacklist
if isinstance(arr, theano.gof.type._cdata_type):
return False return False
elif isinstance(arr, np.random.mtrand.RandomState): elif isinstance(arr, np.random.mtrand.RandomState):
return False return False
...@@ -114,7 +115,7 @@ def contains_inf(arr, node=None, var=None): ...@@ -114,7 +115,7 @@ def contains_inf(arr, node=None, var=None):
boolean array with the same shape as the input array. boolean array with the same shape as the input array.
""" """
if isinstance(arr, theano.gof.type.CDataType._cdata_type): if isinstance(arr, theano.gof.type._cdata_type):
return False return False
elif isinstance(arr, np.random.mtrand.RandomState): elif isinstance(arr, np.random.mtrand.RandomState):
return False return False
...@@ -250,7 +251,7 @@ class NanGuardMode(Mode): ...@@ -250,7 +251,7 @@ class NanGuardMode(Mode):
error = True error = True
if big_is_error: if big_is_error:
err = False err = False
if isinstance(value, theano.gof.type.CDataType._cdata_type): if isinstance(value, theano.gof.type._cdata_type):
err = False err = False
elif isinstance(value, np.random.mtrand.RandomState): elif isinstance(value, np.random.mtrand.RandomState):
err = False err = False
......
...@@ -6,9 +6,11 @@ Defines the `Type` class. ...@@ -6,9 +6,11 @@ Defines the `Type` class.
""" """
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
from theano.compat import PY3 import ctypes
from six import string_types from six import string_types
import theano
from theano.gof import utils from theano.gof import utils
from theano.gof.utils import MethodNotDefined, object2 from theano.gof.utils import MethodNotDefined, object2
from theano.gof import graph from theano.gof import graph
...@@ -16,7 +18,7 @@ from theano.gof import graph ...@@ -16,7 +18,7 @@ from theano.gof import graph
######## ########
# Type # # Type #
######## ########
from theano.gof.op import CLinkerObject from theano.gof.op import CLinkerObject, Op
__docformat__ = "restructuredtext en" __docformat__ = "restructuredtext en"
...@@ -589,6 +591,35 @@ class Generic(SingletonType): ...@@ -589,6 +591,35 @@ class Generic(SingletonType):
generic = Generic() generic = Generic()
_cdata_type = ctypes.py_object.from_address(
ctypes.addressof(ctypes.pythonapi.PyCapsule_Type)).value
class _make_cdata(Op):
__props__ = ('rtype',)
def __init__(self, rtype):
assert isinstance(rtype, CDataType)
self.rtype = rtype
def do_constant_folding(self, node):
return False
def make_node(self, val):
from theano.scalar import as_scalar
from theano import Apply
val = as_scalar(val).astype('uint64')
return Apply(self, [val], [self.rtype()])
def c_code(self, node, name, inputs, outputs, sub):
return """
%(out)s = (%(ctype)s)%(inp)s;
""" % dict(ctype=self.rtype.ctype, out=outputs[0], inp=inputs[0])
def c_code_cache_version(self):
return (0,)
class CDataType(Type): class CDataType(Type):
""" """
...@@ -603,26 +634,32 @@ class CDataType(Type): ...@@ -603,26 +634,32 @@ class CDataType(Type):
The type of the pointer (complete with the `*`). The type of the pointer (complete with the `*`).
freefunc freefunc
A function to call to free the pointer. This function must have a `void` A function to call to free the pointer. This function must
return and take a single pointer argument. have a `void` return and take a single pointer argument.
""" """
import ctypes def __init__(self, ctype, freefunc=None, headers=None, header_dirs=None,
if PY3: libraries=None, lib_dirs=None, extra_support_code=""):
_cdata_type = ctypes.py_object.from_address(
ctypes.addressof(ctypes.pythonapi.PyCapsule_Type)).value
else:
_cdata_type = ctypes.py_object.from_address(
ctypes.addressof(ctypes.pythonapi.PyCObject_Type)).value
del ctypes
def __init__(self, ctype, freefunc=None):
assert isinstance(ctype, string_types) assert isinstance(ctype, string_types)
self.ctype = ctype self.ctype = ctype
if freefunc is not None: if freefunc is not None:
assert isinstance(freefunc, string_types) assert isinstance(freefunc, string_types)
self.freefunc = freefunc self.freefunc = freefunc
if headers is None:
headers = []
self.headers = headers
if header_dirs is None:
header_dirs = []
self.header_dirs = header_dirs
if libraries is None:
libraries = []
self.libraries = libraries
if lib_dirs is None:
lib_dirs = []
self.lib_dirs = lib_dirs
self.extra_support_code = extra_support_code
self._fn = None
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other) and return (type(self) == type(other) and
...@@ -633,10 +670,21 @@ class CDataType(Type): ...@@ -633,10 +670,21 @@ class CDataType(Type):
return hash((type(self), self.ctype, self.freefunc)) return hash((type(self), self.ctype, self.freefunc))
def filter(self, data, strict=False, allow_downcast=None): def filter(self, data, strict=False, allow_downcast=None):
if data is not None and not isinstance(data, self._cdata_type): if data is not None and not isinstance(data, _cdata_type):
raise TypeError("expected None or PyCObject/PyCapsule") raise TypeError("expected None or a PyCapsule")
return data return data
def _get_func(self):
from theano.scalar import get_scalar_type
if self._fn is None:
v = get_scalar_type('int64')()
self._fn = theano.function([v], _make_cdata(self)(v), profile=False)
return self._fn
def make_value(self, ptr):
return self._get_func()(ptr)
def c_declare(self, name, sub, check_input=True): def c_declare(self, name, sub, check_input=True):
return """ return """
%(ctype)s %(name)s; %(ctype)s %(name)s;
...@@ -646,29 +694,20 @@ class CDataType(Type): ...@@ -646,29 +694,20 @@ class CDataType(Type):
return "%(name)s = NULL;" % dict(name=name) return "%(name)s = NULL;" % dict(name=name)
def c_extract(self, name, sub, check_input=True): def c_extract(self, name, sub, check_input=True):
if PY3: return """
s = """
%(name)s = (%(ctype)s)PyCapsule_GetPointer(py_%(name)s, NULL); %(name)s = (%(ctype)s)PyCapsule_GetPointer(py_%(name)s, NULL);
if (%(name)s == NULL) %(fail)s if (%(name)s == NULL) %(fail)s
""" """ % dict(name=name, ctype=self.ctype, fail=sub['fail'])
else:
s = """
%(name)s = (%(ctype)s)PyCObject_AsVoidPtr(py_%(name)s);
"""
return s % dict(name=name, ctype=self.ctype, fail=sub['fail'])
def c_support_code(self): def c_support_code(self):
if PY3: return """
return """ void _capsule_destructor(PyObject *o) {
void _py3_destructor(PyObject *o) {
void *d = PyCapsule_GetContext(o); void *d = PyCapsule_GetContext(o);
void *p = PyCapsule_GetPointer(o, NULL); void *p = PyCapsule_GetPointer(o, NULL);
void (*f)(void *) = (void (*)(void *))d; void (*f)(void *) = (void (*)(void *))d;
if (f != NULL) f(p); if (f != NULL) f(p);
} }
""" """ + self.extra_support_code
else:
return ""
def c_sync(self, name, sub): def c_sync(self, name, sub):
freefunc = self.freefunc freefunc = self.freefunc
...@@ -679,11 +718,9 @@ Py_XDECREF(py_%(name)s); ...@@ -679,11 +718,9 @@ Py_XDECREF(py_%(name)s);
if (%(name)s == NULL) { if (%(name)s == NULL) {
py_%(name)s = Py_None; py_%(name)s = Py_None;
Py_INCREF(py_%(name)s); Py_INCREF(py_%(name)s);
} else """ } else {
if PY3:
s += """{
py_%(name)s = PyCapsule_New((void *)%(name)s, NULL, py_%(name)s = PyCapsule_New((void *)%(name)s, NULL,
_py3_destructor); _capsule_destructor);
if (py_%(name)s != NULL) { if (py_%(name)s != NULL) {
if (PyCapsule_SetContext(py_%(name)s, (void *)%(freefunc)s) != 0) { if (PyCapsule_SetContext(py_%(name)s, (void *)%(freefunc)s) != 0) {
/* This won't trigger a call to freefunc since it could not be /* This won't trigger a call to freefunc since it could not be
...@@ -693,11 +730,6 @@ if (%(name)s == NULL) { ...@@ -693,11 +730,6 @@ if (%(name)s == NULL) {
py_%(name)s = NULL; py_%(name)s = NULL;
} }
} }
}"""
else:
s += """{
py_%(name)s = PyCObject_FromVoidPtr((void *)%(name)s,
(void (*)(void *))%(freefunc)s);
}""" }"""
if self.freefunc is not None: if self.freefunc is not None:
s += """ s += """
...@@ -710,8 +742,20 @@ if (py_%(name)s == NULL) { %(freefunc)s(%(name)s); } ...@@ -710,8 +742,20 @@ if (py_%(name)s == NULL) { %(freefunc)s(%(name)s); }
# free the data for us when released. # free the data for us when released.
return "" return ""
def c_headers(self):
return self.headers
def c_header_dirs(self):
return self.header_dirs
def c_libraries(self):
return self.libraries
def c_lib_dirs(self):
return self.lib_dirs
def c_code_cache_version(self): def c_code_cache_version(self):
return (2, self.ctype, self.freefunc) return (3,)
def __str__(self): def __str__(self):
return "%s{%s}" % (self.__class__.__name__, self.ctype) return "%s{%s}" % (self.__class__.__name__, self.ctype)
...@@ -727,4 +771,5 @@ class CDataTypeConstant(graph.Constant): ...@@ -727,4 +771,5 @@ class CDataTypeConstant(graph.Constant):
# There is no way to put the data in the signature, so we # There is no way to put the data in the signature, so we
# don't even try # don't even try
return (self.type,) return (self.type,)
CDataType.Constant = CDataTypeConstant CDataType.Constant = CDataTypeConstant
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论