提交 614ec6d5 authored 作者: Frederic Bastien's avatar Frederic Bastien

fix may_share_memory for sparse type as it was done for ndarray type. Add…

fix may_share_memory for sparse type as it was done for ndarray type. Add may_share_memory test for CudaNdarray type
上级 f295c786
......@@ -37,17 +37,18 @@ def test_may_share_memory():
#test that it raise error when needed.
for a_,b_,rep in [(a,(0,),False),(a,1,False),(a,None,False),]:
if rep == False:
try:
may_share_memory(a_,b_)
raise Exception("An error was expected")
except TypeError:
pass
try:
may_share_memory(b_,a_)
raise Exception("An error was expected")
except TypeError:
pass
assert may_share_memory(a_,b_,False)==rep
assert may_share_memory(b_,a_,False)==rep
try:
may_share_memory(a_,b_)
raise Exception("An error was expected")
except TypeError:
pass
try:
may_share_memory(b_,a_)
raise Exception("An error was expected")
except TypeError:
pass
if scipy_imported:
def test_may_share_memory_scipy():
......@@ -64,14 +65,18 @@ if scipy_imported:
assert may_share_memory(a_,b_)==rep
assert may_share_memory(b_,a_)==rep
if rep == False:
try:
may_share_memory(a_,b_)
raise Exception("An error was expected")
except:
pass
try:
may_share_memory(b_,a_)
raise Exception("An error was expected")
except:
pass
#test that it raise error when needed.
for a_,b_,rep in [(a,(0,),False),(a,1,False),(a,None,False)]:
assert may_share_memory(a_,b_,False)==rep
assert may_share_memory(b_,a_,False)==rep
try:
may_share_memory(a_,b_)
raise Exception("An error was expected")
except TypeError:
pass
try:
may_share_memory(b_,a_)
raise Exception("An error was expected")
except TypeError:
pass
......@@ -65,6 +65,48 @@ def test_softmax_optimizations():
assert env.outputs[0].owner.inputs[0].owner.op == cuda.host_from_gpu
assert env.outputs[0].owner.inputs[0].owner.inputs[0].owner.op == cuda.nnet.gpu_crossentropy_softmax_argmax_1hot_with_bias
def test_may_share_memory_cuda():
from theano.misc.may_share_memory import may_share_memory
a = cuda.CudaNdarray(numpy.zeros((3,4),dtype='float32'))
b = cuda.CudaNdarray(numpy.zeros((3,4),dtype='float32'))
na = numpy.zeros((3,4))
nb = numpy.zeros((3,4))
va = a.view()
vb = b.view()
ra = a.reshape((4,3))
rb = b.reshape((4,3))
#can't test the transpose as ta._strides = is not implemented
#manual transpose of a
#ta = a.reshape((4,3))
#ta._strides = (ta._strides[1],ta._strides[0])#not implemented
#elem_size=elem_size = numpy.zeros(0,dtype=a.dtype).dtype.itemsize
#ta.gpudata += ta.size*elem_size
for a_,b_,rep in [(a,a,True),(b,b,True),(a,b,False),
(a,na,False),(b,nb,False),(na,b,False),(nb,a,False),
(a,va,True),(b,vb,True),(va,b,False),(a,vb,False),
(a,ra,True),(b,rb,True),(ra,b,False),(a,rb,False),
]:
assert may_share_memory(a_,b_)==rep
assert may_share_memory(b_,a_)==rep
#test that it raise error when needed.
for a_,b_,rep in [(a,(0,),False),(a,1,False),(a,None,False)]:
assert may_share_memory(a_,b_,False)==rep
assert may_share_memory(b_,a_,False)==rep
try:
may_share_memory(a_,b_)
raise Exception("An error was expected")
except TypeError:
pass
try:
may_share_memory(b_,a_)
raise Exception("An error was expected")
except TypeError:
pass
def test_grad_sqrt_sum():
"""
This trigered a bug in the past.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论