提交 33adfef5 authored 作者: Frederic Bastien's avatar Frederic Bastien

added new theano.misc.may_share_memory that handle more case and can be called…

added new theano.misc.may_share_memory that handle more case and can be called in all case(with and without scipy,...). Test it.
上级 49ab5ac2
"""
Helper function to detect memory sharing for ndarray AND sparse type.
numpy version support only ndarray.
"""
__docformat__ = "restructuredtext en"
import numpy
from theano.tensor.basic import TensorType
try:
import scipy.sparse
except ImportError:
#scipy not imported, their can be only ndarray
def may_share_memory(a, b, raise_other_type=True):
if not isinstance(a, numpy.ndarray) or not isinstance(b, numpy.ndarray):
if raise_other_type:
raise TypeError("may_share_memory support only ndarray when scipy is not available")
return False
return numpy.may_share_memory(a,b)
else:
#scipy imported, their can be ndarray and sparse type
from theano.sparse.basic import _is_sparse, SparseType
def may_share_memory(a, b, raise_other_type=True):
a_ndarray = isinstance(a, numpy.ndarray)
b_ndarray = isinstance(b, numpy.ndarray)
try:
a_sparse = _is_sparse(a)
b_sparse = _is_sparse(b)
except NotImplementedError:
if raise_other_type:
raise TypeError("may_share_memory support only ndarray and scipy.sparse type")
return False
if not(a_ndarray or a_sparse) or not(b_ndarray or b_sparse):
if raise_other_type:
raise TypeError("may_share_memory support only ndarray and scipy.sparse type")
return False
if a_ndarray and b_ndarray:
return TensorType.may_share_memory(a,b)
return SparseType.may_share_memory(a,b)
import numpy
import theano
try:
import scipy.sparse
scipy_imported = True
except ImportError:
scipy_imported = False
from theano.misc.may_share_memory import may_share_memory
def test_may_share_memory():
a=numpy.random.rand(5,4)
b=numpy.random.rand(5,4)
as_ar = lambda a: theano._asarray(a, dtype='int32')
for a_,b_,rep in [(a,a,True),(b,b,True),(a,b,False),
(a,a[0],True),(a,a[:,0],True),
(a,(0,),False),(a,1,False),
]:
assert may_share_memory(a_,b_,False)==rep
if rep == False:
try:
may_share_memory(a_,b_)
raise Exception("An error was expected")
except:
pass
if scipy_imported:
def test_may_share_memory_scipy():
a=scipy.sparse.csc_matrix(scipy.sparse.eye(5,3))
b=scipy.sparse.csc_matrix(scipy.sparse.eye(4,3))
as_ar = lambda a: theano._asarray(a, dtype='int32')
for a_,b_,rep in [(a,a,True),(b,b,True),(a,b,False),
(a,a.data,True),(a,a.indptr,True),(a,a.indices,True),(a,as_ar(a.shape),False),
(a.data,a,True),(a.indptr,a,True),(a.indices,a,True),(as_ar(a.shape),a,False),
(b,b.data,True),(b,b.indptr,True),(b,b.indices,True),(b,as_ar(b.shape),False),
(b.data,b,True),(b.indptr,b,True),(b.indices,b,True),(as_ar(b.shape),b,False),
(b.data,a,False),(b.indptr,a,False),(b.indices,a,False),(as_ar(b.shape),a,False),
]:
assert may_share_memory(a_,b_)==rep
if rep == False:
try:
may_share_memory(a_,b_)
raise Exception("An error was expected")
except:
pass
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论