提交 1dd5f5fc authored 作者: Frederic's avatar Frederic

Remove duplicate infer_shape tests and add a bad shape test for vector.

上级 e686253f
...@@ -4547,22 +4547,26 @@ class T_reshape(unittest.TestCase): ...@@ -4547,22 +4547,26 @@ class T_reshape(unittest.TestCase):
assert numpy.all(f_sub(a_val, b_val) == [2, 3]) assert numpy.all(f_sub(a_val, b_val) == [2, 3])
def test_infer_shape(self): def test_bad_shape(self):
a = matrix('a') a = matrix('a')
shapes = ivector('shapes') shapes = ivector('shapes')
ndim = 2 ndim = 2
rng = numpy.random.RandomState(seed=utt.fetch_seed())
a_val = rng.uniform(size=(3, 4)).astype(config.floatX)
r = a.reshape(shapes, ndim=2) #Test reshape to 1 dim
r = a.reshape(shapes, ndim=1)
z = zeros_like(r) z = zeros_like(r)
f = function([a, shapes], z.shape) f = function([a, shapes], z.shape)
self.assertRaises(ValueError, f, a_val, [13])
rng = numpy.random.RandomState(seed=utt.fetch_seed()) #Test reshape to 1 dim
a_val = rng.uniform(size=(3, 4)).astype(config.floatX) r = a.reshape(shapes, ndim=2)
z = zeros_like(r)
f = function([a, shapes], z.shape)
self.assertTrue((f(a_val, [4, 3]) == [4, 3]).all())
self.assertTrue((f(a_val, [-1, 3]) == [4, 3]).all())
self.assertTrue((f(a_val, [4, -1]) == [4, 3]).all())
self.assertRaises(ValueError, f, a_val, [-1, 5]) self.assertRaises(ValueError, f, a_val, [-1, 5])
self.assertRaises(ValueError, f, a_val, [7, -1]) self.assertRaises(ValueError, f, a_val, [7, -1])
self.assertRaises(ValueError, f, a_val, [7, 5]) self.assertRaises(ValueError, f, a_val, [7, 5])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论