提交 b84a6506 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Correctly cast random values to floatX in some tests.

上级 bb3460fe
...@@ -2280,12 +2280,13 @@ class T_Join_and_Split(unittest.TestCase): ...@@ -2280,12 +2280,13 @@ class T_Join_and_Split(unittest.TestCase):
f = function([a,b], c) f = function([a,b], c)
rng = numpy.random.RandomState(seed=utt.fetch_seed()) rng = numpy.random.RandomState(seed=utt.fetch_seed())
a_val = rng.rand(1, 4, 1) a_val = rng.rand(1, 4, 1).astype(config.floatX)
b_val = rng.rand(1, 3, 1) b_val = rng.rand(1, 3, 1).astype(config.floatX)
f(a_val, b_val) f(a_val, b_val)
utt.verify_grad((lambda a,b: join(1,a,b)), [a_val, b_val], rng=rng) utt.verify_grad((lambda a,b: join(1,a,b)), [a_val, b_val], rng=rng)
# Should raise an error if dimension 0 does not match # Should raise an error if dimension 0 does not match
self.assertRaises(ValueError, f, rng.rand(2,4,1), b_val) bad_a_val = rng.rand(2, 4, 1).astype(config.floatX)
self.assertRaises(ValueError, f, bad_a_val, b_val)
def test_broadcastable_flag_assignment_mixed_thisaxes(self): def test_broadcastable_flag_assignment_mixed_thisaxes(self):
""" """
...@@ -2300,12 +2301,13 @@ class T_Join_and_Split(unittest.TestCase): ...@@ -2300,12 +2301,13 @@ class T_Join_and_Split(unittest.TestCase):
f = function([a,b], c) f = function([a,b], c)
rng = numpy.random.RandomState(seed=utt.fetch_seed()) rng = numpy.random.RandomState(seed=utt.fetch_seed())
a_val = rng.rand(2, 4, 1) a_val = rng.rand(2, 4, 1).astype(config.floatX)
b_val = rng.rand(1, 4, 1) b_val = rng.rand(1, 4, 1).astype(config.floatX)
f(a_val, b_val) f(a_val, b_val)
utt.verify_grad((lambda a,b: join(0,a,b)), [a_val, b_val], rng=rng) utt.verify_grad((lambda a,b: join(0,a,b)), [a_val, b_val], rng=rng)
# Should raise an error if b_val.shape[0] is not 1 # Should raise an error if b_val.shape[0] is not 1
self.assertRaises(TypeError, f, a_val, rng.rand(3,4,1)) bad_b_val = rng.rand(3, 4, 1).astype(config.floatX)
self.assertRaises(TypeError, f, a_val, bad_b_val)
def test_broadcastable_flags_all_broadcastable_on_joinaxis(self): def test_broadcastable_flags_all_broadcastable_on_joinaxis(self):
""" """
...@@ -2320,13 +2322,15 @@ class T_Join_and_Split(unittest.TestCase): ...@@ -2320,13 +2322,15 @@ class T_Join_and_Split(unittest.TestCase):
f = function([a,b], c) f = function([a,b], c)
rng = numpy.random.RandomState(seed=utt.fetch_seed()) rng = numpy.random.RandomState(seed=utt.fetch_seed())
a_val = rng.rand(1, 4, 1) a_val = rng.rand(1, 4, 1).astype(config.floatX)
b_val = rng.rand(1, 4, 1) b_val = rng.rand(1, 4, 1).astype(config.floatX)
f(a_val, b_val) f(a_val, b_val)
utt.verify_grad((lambda a,b: join(0,a,b)), [a_val, b_val], rng=rng) utt.verify_grad((lambda a,b: join(0,a,b)), [a_val, b_val], rng=rng)
# Should raise an error if length of dimension 0 is not 1 # Should raise an error if length of dimension 0 is not 1
self.assertRaises(TypeError, f, rng.rand(2,4,1), b_val) bad_a_val = rng.rand(2, 4, 1).astype(config.floatX)
self.assertRaises(TypeError, f, a_val, rng.rand(3,4,1)) bad_b_val = rng.rand(3, 4, 1).astype(config.floatX)
self.assertRaises(TypeError, f, bad_a_val, b_val)
self.assertRaises(TypeError, f, a_val, bad_b_val)
def test_broadcastable_single_input_broadcastable_dimension(self): def test_broadcastable_single_input_broadcastable_dimension(self):
""" """
...@@ -2341,11 +2345,12 @@ class T_Join_and_Split(unittest.TestCase): ...@@ -2341,11 +2345,12 @@ class T_Join_and_Split(unittest.TestCase):
f = function([a], b) f = function([a], b)
rng = numpy.random.RandomState(seed=utt.fetch_seed()) rng = numpy.random.RandomState(seed=utt.fetch_seed())
a_val = rng.rand(1, 4, 1) a_val = rng.rand(1, 4, 1).astype(config.floatX)
f(a_val) f(a_val)
utt.verify_grad((lambda a: join(0,a)), [a_val], rng=rng) utt.verify_grad((lambda a: join(0,a)), [a_val], rng=rng)
# Should raise an error if length of dimension 0 is not 1 # Should raise an error if length of dimension 0 is not 1
self.assertRaises(TypeError, f, rng.rand(2,4,1)) bad_a_val = rng.rand(2, 4, 1).astype(config.floatX)
self.assertRaises(TypeError, f, bad_a_val)
def test_broadcastable_flags_many_dims_and_inputs(self): def test_broadcastable_flags_many_dims_and_inputs(self):
""" """
...@@ -2369,26 +2374,32 @@ class T_Join_and_Split(unittest.TestCase): ...@@ -2369,26 +2374,32 @@ class T_Join_and_Split(unittest.TestCase):
g = function([a,b,c,d,e], f) g = function([a,b,c,d,e], f)
rng = numpy.random.RandomState(seed=utt.fetch_seed()) rng = numpy.random.RandomState(seed=utt.fetch_seed())
a_val = rng.rand(1, 1, 1, 1, 2, 1) a_val = rng.rand(1, 1, 1, 1, 2, 1).astype(config.floatX)
b_val = rng.rand(1, 1, 1, 1, 2, 1) b_val = rng.rand(1, 1, 1, 1, 2, 1).astype(config.floatX)
c_val = rng.rand(1, 1, 1, 1, 2, 1) c_val = rng.rand(1, 1, 1, 1, 2, 1).astype(config.floatX)
d_val = rng.rand(1, 1, 1, 1, 2, 1) d_val = rng.rand(1, 1, 1, 1, 2, 1).astype(config.floatX)
e_val = rng.rand(1, 1, 1, 1, 2, 1) e_val = rng.rand(1, 1, 1, 1, 2, 1).astype(config.floatX)
g(a_val, b_val, c_val, d_val, e_val) g(a_val, b_val, c_val, d_val, e_val)
utt.verify_grad((lambda a,b,c,d,e: join(0,a,b,c,d,e)), utt.verify_grad((lambda a,b,c,d,e: join(0,a,b,c,d,e)),
[a_val, b_val, c_val, d_val, e_val], rng=rng) [a_val, b_val, c_val, d_val, e_val], rng=rng)
# Should raise an error if length of dimension 0 is not 1 # Should raise an error if length of dimension 0 is not 1
self.assertRaises(TypeError, g, rng.rand(2,1,1,1,2,1), b_val, c_val, d_val, e_val) bad_val = rng.rand(2, 1, 1, 1, 2, 1).astype(config.floatX)
self.assertRaises(TypeError, g, a_val, rng.rand(2,1,1,1,2,1), c_val, d_val, e_val) self.assertRaises(TypeError, g, bad_val, b_val, c_val, d_val, e_val)
self.assertRaises(TypeError, g, a_val, b_val, rng.rand(2,1,1,1,2,1), d_val, e_val) self.assertRaises(TypeError, g, a_val, bad_val, c_val, d_val, e_val)
self.assertRaises(TypeError, g, a_val, b_val, c_val, rng.rand(2,1,1,1,2,1), e_val) self.assertRaises(TypeError, g, a_val, b_val, bad_val, d_val, e_val)
self.assertRaises(TypeError, g, a_val, b_val, c_val, d_val, rng.rand(2,1,1,1,2,1)) self.assertRaises(TypeError, g, a_val, b_val, c_val, bad_val, e_val)
self.assertRaises(TypeError, g, a_val, b_val, c_val, d_val, bad_val)
# Should raise an error if any dimension other than 4 has length != 1 # Should raise an error if any dimension other than 4 has length != 1
self.assertRaises(ValueError, g, rng.rand(1,2,1,1,2,1), b_val, c_val, d_val, e_val) bad_a_val = rng.rand(1, 2, 1, 1, 2, 1).astype(config.floatX)
self.assertRaises(ValueError, g, a_val, rng.rand(1,1,1,1,2,2), c_val, d_val, e_val) bad_b_val = rng.rand(1, 1, 1, 1, 2, 2).astype(config.floatX)
self.assertRaises(ValueError, g, a_val, b_val, rng.rand(1,1,2,1,2,1), d_val, e_val) bad_c_val = rng.rand(1, 1, 2, 1, 2, 1).astype(config.floatX)
self.assertRaises(ValueError, g, a_val, b_val, c_val, rng.rand(1,2,1,1,2,1), e_val) bad_d_val = rng.rand(1, 2, 1, 1, 2, 1).astype(config.floatX)
self.assertRaises(ValueError, g, a_val, b_val, c_val, d_val, rng.rand(1,1,1,2,2,1)) bad_e_val = rng.rand(1, 1, 1, 2, 2, 1).astype(config.floatX)
self.assertRaises(ValueError, g, bad_a_val, b_val, c_val, d_val, e_val)
self.assertRaises(ValueError, g, a_val, bad_b_val, c_val, d_val, e_val)
self.assertRaises(ValueError, g, a_val, b_val, bad_c_val, d_val, e_val)
self.assertRaises(ValueError, g, a_val, b_val, c_val, bad_d_val, e_val)
self.assertRaises(ValueError, g, a_val, b_val, c_val, d_val, bad_e_val)
class test_comparison(unittest.TestCase): class test_comparison(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论