提交 9415e78e authored 作者: Ian Goodfellow's avatar Ian Goodfellow

completed ticket 618 (added tests of giving invalid shapes to cuda_ndarry.reshape)

fixed an issue where test_reshape in test_cuda_ndarray did not use the correct seed for random number generation
上级 ba320eeb
from theano.tests import unittest_tools as utt
import time, copy, sys import time, copy, sys
import theano import theano
import theano.sandbox.cuda as cuda_ndarray import theano.sandbox.cuda as cuda_ndarray
...@@ -207,12 +208,26 @@ def test_reshape(): ...@@ -207,12 +208,26 @@ def test_reshape():
((1,2,3), (6,)), ((1,2,3), (6,)),
((1,2,3,2), (6,2)), ((1,2,3,2), (6,2)),
((2,3,2), (6,2)), ((2,3,2), (6,2)),
((2,3,2), 12) ((2,3,2), (12,))
] ]
def subtest(shape_1, shape_2): bad_shapelist = [
((1,2,3), (1,2,4)),
((1,), (2,)),
((1,2,3), (2,2,1)),
((1,2,3), (5,)),
((1,2,3,2), (6,3)),
((2,3,2), (5,2)),
((2,3,2), (11,))
]
utt.seed_rng()
rng = numpy.random.RandomState(utt.fetch_seed())
def subtest(shape_1, shape_2, rng):
#print >> sys.stdout, "INFO: shapes", shape_1, shape_2 #print >> sys.stdout, "INFO: shapes", shape_1, shape_2
a = theano._asarray(numpy.random.random(shape_1), dtype='float32') a = theano._asarray(rng.randn(*shape_1), dtype='float32')
b = cuda_ndarray.CudaNdarray(a) b = cuda_ndarray.CudaNdarray(a)
aa = a.reshape(shape_2) aa = a.reshape(shape_2)
...@@ -224,13 +239,27 @@ def test_reshape(): ...@@ -224,13 +239,27 @@ def test_reshape():
assert numpy.all(aa == n_bb) assert numpy.all(aa == n_bb)
def bad_subtest(shape_1, shape_2, rng):
a = theano._asarray(rng.randn(*shape_1), dtype='float32')
b = cuda_ndarray.CudaNdarray(a)
try:
bb = b.reshape(shape_2)
except Exception, ValueError:
return
assert False
# test working shapes # test working shapes
for shape_1, shape_2 in shapelist: for shape_1, shape_2 in shapelist:
subtest(shape_1, shape_2) subtest(shape_1, shape_2, rng)
subtest(shape_2, shape_1) subtest(shape_2, shape_1, rng)
##TODO: see ticket #618 # test shape combinations that should give error
#print >> sys.stderr, "WARN: TODO: test shape combinations that should give error" for shape_1, shape_2 in bad_shapelist:
bad_subtest(shape_1, shape_2, rng)
bad_subtest(shape_2, shape_1, rng)
def test_getshape(): def test_getshape():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论