提交 b84a6506 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Correctly cast random values to floatX in some tests.

上级 bb3460fe
......@@ -2280,12 +2280,13 @@ class T_Join_and_Split(unittest.TestCase):
f = function([a,b], c)
rng = numpy.random.RandomState(seed=utt.fetch_seed())
a_val = rng.rand(1, 4, 1)
b_val = rng.rand(1, 3, 1)
a_val = rng.rand(1, 4, 1).astype(config.floatX)
b_val = rng.rand(1, 3, 1).astype(config.floatX)
f(a_val, b_val)
utt.verify_grad((lambda a,b: join(1,a,b)), [a_val, b_val], rng=rng)
# Should raise an error if dimension 0 does not match
self.assertRaises(ValueError, f, rng.rand(2,4,1), b_val)
bad_a_val = rng.rand(2, 4, 1).astype(config.floatX)
self.assertRaises(ValueError, f, bad_a_val, b_val)
def test_broadcastable_flag_assignment_mixed_thisaxes(self):
"""
......@@ -2300,12 +2301,13 @@ class T_Join_and_Split(unittest.TestCase):
f = function([a,b], c)
rng = numpy.random.RandomState(seed=utt.fetch_seed())
a_val = rng.rand(2, 4, 1)
b_val = rng.rand(1, 4, 1)
a_val = rng.rand(2, 4, 1).astype(config.floatX)
b_val = rng.rand(1, 4, 1).astype(config.floatX)
f(a_val, b_val)
utt.verify_grad((lambda a,b: join(0,a,b)), [a_val, b_val], rng=rng)
# Should raise an error if b_val.shape[0] is not 1
self.assertRaises(TypeError, f, a_val, rng.rand(3,4,1))
bad_b_val = rng.rand(3, 4, 1).astype(config.floatX)
self.assertRaises(TypeError, f, a_val, bad_b_val)
def test_broadcastable_flags_all_broadcastable_on_joinaxis(self):
"""
......@@ -2320,13 +2322,15 @@ class T_Join_and_Split(unittest.TestCase):
f = function([a,b], c)
rng = numpy.random.RandomState(seed=utt.fetch_seed())
a_val = rng.rand(1, 4, 1)
b_val = rng.rand(1, 4, 1)
a_val = rng.rand(1, 4, 1).astype(config.floatX)
b_val = rng.rand(1, 4, 1).astype(config.floatX)
f(a_val, b_val)
utt.verify_grad((lambda a,b: join(0,a,b)), [a_val, b_val], rng=rng)
# Should raise an error if length of dimension 0 is not 1
self.assertRaises(TypeError, f, rng.rand(2,4,1), b_val)
self.assertRaises(TypeError, f, a_val, rng.rand(3,4,1))
bad_a_val = rng.rand(2, 4, 1).astype(config.floatX)
bad_b_val = rng.rand(3, 4, 1).astype(config.floatX)
self.assertRaises(TypeError, f, bad_a_val, b_val)
self.assertRaises(TypeError, f, a_val, bad_b_val)
def test_broadcastable_single_input_broadcastable_dimension(self):
"""
......@@ -2341,11 +2345,12 @@ class T_Join_and_Split(unittest.TestCase):
f = function([a], b)
rng = numpy.random.RandomState(seed=utt.fetch_seed())
a_val = rng.rand(1, 4, 1)
a_val = rng.rand(1, 4, 1).astype(config.floatX)
f(a_val)
utt.verify_grad((lambda a: join(0,a)), [a_val], rng=rng)
# Should raise an error if length of dimension 0 is not 1
self.assertRaises(TypeError, f, rng.rand(2,4,1))
bad_a_val = rng.rand(2, 4, 1).astype(config.floatX)
self.assertRaises(TypeError, f, bad_a_val)
def test_broadcastable_flags_many_dims_and_inputs(self):
"""
......@@ -2369,26 +2374,32 @@ class T_Join_and_Split(unittest.TestCase):
g = function([a,b,c,d,e], f)
rng = numpy.random.RandomState(seed=utt.fetch_seed())
a_val = rng.rand(1, 1, 1, 1, 2, 1)
b_val = rng.rand(1, 1, 1, 1, 2, 1)
c_val = rng.rand(1, 1, 1, 1, 2, 1)
d_val = rng.rand(1, 1, 1, 1, 2, 1)
e_val = rng.rand(1, 1, 1, 1, 2, 1)
a_val = rng.rand(1, 1, 1, 1, 2, 1).astype(config.floatX)
b_val = rng.rand(1, 1, 1, 1, 2, 1).astype(config.floatX)
c_val = rng.rand(1, 1, 1, 1, 2, 1).astype(config.floatX)
d_val = rng.rand(1, 1, 1, 1, 2, 1).astype(config.floatX)
e_val = rng.rand(1, 1, 1, 1, 2, 1).astype(config.floatX)
g(a_val, b_val, c_val, d_val, e_val)
utt.verify_grad((lambda a,b,c,d,e: join(0,a,b,c,d,e)),
[a_val, b_val, c_val, d_val, e_val], rng=rng)
# Should raise an error if length of dimension 0 is not 1
self.assertRaises(TypeError, g, rng.rand(2,1,1,1,2,1), b_val, c_val, d_val, e_val)
self.assertRaises(TypeError, g, a_val, rng.rand(2,1,1,1,2,1), c_val, d_val, e_val)
self.assertRaises(TypeError, g, a_val, b_val, rng.rand(2,1,1,1,2,1), d_val, e_val)
self.assertRaises(TypeError, g, a_val, b_val, c_val, rng.rand(2,1,1,1,2,1), e_val)
self.assertRaises(TypeError, g, a_val, b_val, c_val, d_val, rng.rand(2,1,1,1,2,1))
bad_val = rng.rand(2, 1, 1, 1, 2, 1).astype(config.floatX)
self.assertRaises(TypeError, g, bad_val, b_val, c_val, d_val, e_val)
self.assertRaises(TypeError, g, a_val, bad_val, c_val, d_val, e_val)
self.assertRaises(TypeError, g, a_val, b_val, bad_val, d_val, e_val)
self.assertRaises(TypeError, g, a_val, b_val, c_val, bad_val, e_val)
self.assertRaises(TypeError, g, a_val, b_val, c_val, d_val, bad_val)
# Should raise an error if any dimension other than 4 has length != 1
self.assertRaises(ValueError, g, rng.rand(1,2,1,1,2,1), b_val, c_val, d_val, e_val)
self.assertRaises(ValueError, g, a_val, rng.rand(1,1,1,1,2,2), c_val, d_val, e_val)
self.assertRaises(ValueError, g, a_val, b_val, rng.rand(1,1,2,1,2,1), d_val, e_val)
self.assertRaises(ValueError, g, a_val, b_val, c_val, rng.rand(1,2,1,1,2,1), e_val)
self.assertRaises(ValueError, g, a_val, b_val, c_val, d_val, rng.rand(1,1,1,2,2,1))
bad_a_val = rng.rand(1, 2, 1, 1, 2, 1).astype(config.floatX)
bad_b_val = rng.rand(1, 1, 1, 1, 2, 2).astype(config.floatX)
bad_c_val = rng.rand(1, 1, 2, 1, 2, 1).astype(config.floatX)
bad_d_val = rng.rand(1, 2, 1, 1, 2, 1).astype(config.floatX)
bad_e_val = rng.rand(1, 1, 1, 2, 2, 1).astype(config.floatX)
self.assertRaises(ValueError, g, bad_a_val, b_val, c_val, d_val, e_val)
self.assertRaises(ValueError, g, a_val, bad_b_val, c_val, d_val, e_val)
self.assertRaises(ValueError, g, a_val, b_val, bad_c_val, d_val, e_val)
self.assertRaises(ValueError, g, a_val, b_val, c_val, bad_d_val, e_val)
self.assertRaises(ValueError, g, a_val, b_val, c_val, d_val, bad_e_val)
class test_comparison(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论