提交 1dd5f5fc authored 作者: Frederic's avatar Frederic

Remove duplicate infer_shape tests and add a bad shape test for vector.

上级 e686253f
......@@ -4547,22 +4547,26 @@ class T_reshape(unittest.TestCase):
assert numpy.all(f_sub(a_val, b_val) == [2, 3])
def test_infer_shape(self):
def test_bad_shape(self):
a = matrix('a')
shapes = ivector('shapes')
ndim = 2
rng = numpy.random.RandomState(seed=utt.fetch_seed())
a_val = rng.uniform(size=(3, 4)).astype(config.floatX)
r = a.reshape(shapes, ndim=2)
#Test reshape to 1 dim
r = a.reshape(shapes, ndim=1)
z = zeros_like(r)
f = function([a, shapes], z.shape)
self.assertRaises(ValueError, f, a_val, [13])
rng = numpy.random.RandomState(seed=utt.fetch_seed())
a_val = rng.uniform(size=(3, 4)).astype(config.floatX)
#Test reshape to 1 dim
r = a.reshape(shapes, ndim=2)
z = zeros_like(r)
f = function([a, shapes], z.shape)
self.assertTrue((f(a_val, [4, 3]) == [4, 3]).all())
self.assertTrue((f(a_val, [-1, 3]) == [4, 3]).all())
self.assertTrue((f(a_val, [4, -1]) == [4, 3]).all())
self.assertRaises(ValueError, f, a_val, [-1, 5])
self.assertRaises(ValueError, f, a_val, [7, -1])
self.assertRaises(ValueError, f, a_val, [7, 5])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论