提交 0a3f92f8 authored 作者: Frederic Bastien's avatar Frederic Bastien

implement may_share_memory for CudaNdarray to remove false warning in debugmode.

上级 90789025
"""
Helper function to detect memory sharing for ndarray AND sparse type.
Function to detect memory sharing for ndarray AND sparse type AND CudaNdarray.
numpy version support only ndarray.
"""
......@@ -26,17 +26,30 @@ else:
b_ndarray = isinstance(b, numpy.ndarray)
try:
a_sparse = _is_sparse(a)
except NotImplementedError:
a_sparse = False
try:
b_sparse = _is_sparse(b)
except NotImplementedError:
if raise_other_type:
raise TypeError("may_share_memory support only ndarray and scipy.sparse type")
return False
b_sparse = False
a_cuda=False
b_cuda=False
if a.__class__.__name__ == "CudaNdarray":
a_cuda = True
if b.__class__.__name__ == "CudaNdarray":
b_cuda = True
if not(a_ndarray or a_sparse) or not(b_ndarray or b_sparse):
if not(a_ndarray or a_sparse or a_cuda) or not(b_ndarray or b_sparse or b_cuda):
if raise_other_type:
raise TypeError("may_share_memory support only ndarray and scipy.sparse type")
raise TypeError("may_share_memory support only ndarray and scipy.sparse and CudaNdarray type")
return False
if a_ndarray and b_ndarray:
return TensorType.may_share_memory(a,b)
if a_cuda and b_cuda:
from theano.sandbox.cuda.type import CudaNdarrayType
return CudaNdarrayType.may_share_memory(a,b)
if a_cuda or b_cuda:
return False
return SparseType.may_share_memory(a,b)
......@@ -87,6 +87,36 @@ class CudaNdarrayType(Type):
% (self, self.dtype, data, converted_data, self.dtype),
data)
@staticmethod
def bound(a):
high = a.gpudata
low = a.gpudata
#stride is in the number of element.
#we must convert that to bytes in case we
#will view the element as a different type.
elem_size = numpy.zeros(0,dtype=a.dtype).dtype.itemsize
for stri, shp in zip(a._strides,a.shape):
if stri<0:
low += (stri*elem_size)*(shp-1)
else:
high += (stri*elem_size)*(shp-1)
return low, high
@staticmethod
def may_share_memory(a,b):
#when this is called with a an ndarray and b
#a sparce matrix, numpy.may_share_memory fail.
if a is b:
return True
if a.__class__ is b.__class__:
a_l, a_h = CudaNdarrayType.bound(a)
b_l, b_h = CudaNdarrayType.bound(b)
if b_l>=a_h or a_l >= b_h:
return False
return True
else: return False
@staticmethod
def values_eq(a, b):
#TODO: make the comparaison without transfert.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论