提交 df819993 authored 作者: Frederic's avatar Frederic

speed up misc.may_share_memory()

上级 33de45ca
......@@ -15,12 +15,14 @@ try:
def _is_sparse(a):
return scipy.sparse.issparse(a)
except ImportError:
#scipy not imported, their can be only ndarray and cudandarray
# scipy not imported, their can be only ndarray and cudandarray
def _is_sparse(a):
return False
from theano.sandbox import cuda
if cuda.cuda_available:
from theano.sandbox.cuda.type import CudaNdarrayType
def _is_cuda(a):
return isinstance(a, cuda.CudaNdarray)
else:
......@@ -40,13 +42,19 @@ else:
def may_share_memory(a, b, raise_other_type=True):
a_ndarray = isinstance(a, numpy.ndarray)
b_ndarray = isinstance(b, numpy.ndarray)
a_sparse = _is_sparse(a)
b_sparse = _is_sparse(b)
if a_ndarray and b_ndarray:
return TensorType.may_share_memory(a, b)
a_cuda = _is_cuda(a)
b_cuda = _is_cuda(b)
if a_cuda and b_cuda:
return CudaNdarrayType.may_share_memory(a, b)
a_gpua = _is_gpua(a)
b_gpua = _is_gpua(b)
if a_gpua and b_gpua:
return gpuarray.pygpu.gpuarray.may_share_memory(a, b)
a_sparse = _is_sparse(a)
b_sparse = _is_sparse(b)
if (not(a_ndarray or a_sparse or a_cuda or a_gpua) or
not(b_ndarray or b_sparse or b_cuda or b_gpua)):
if raise_other_type:
......@@ -54,13 +62,6 @@ def may_share_memory(a, b, raise_other_type=True):
" and scipy.sparse, CudaNdarray or GpuArray type")
return False
if a_ndarray and b_ndarray:
return TensorType.may_share_memory(a, b)
if a_cuda and b_cuda:
from theano.sandbox.cuda.type import CudaNdarrayType
return CudaNdarrayType.may_share_memory(a, b)
if a_gpua and b_gpua:
return gpuarray.pygpu.gpuarray.may_share_memory(a, b)
if a_cuda or b_cuda or a_gpua or b_gpua:
return False
return SparseType.may_share_memory(a, b)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论