提交 41976981 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Update tests of Join with broadcastable dimensions.

- Compile a function, and call it on random values (so DebugMode can verify things) - Work on floatX, so gradients can be verified - Verify gradients - Check that some incorrect inputs lead to errors
上级 45a8228b
...@@ -2098,6 +2098,7 @@ class T_Join_and_Split(unittest.TestCase): ...@@ -2098,6 +2098,7 @@ class T_Join_and_Split(unittest.TestCase):
""" """
def setUp(self): def setUp(self):
Join.debug = False Join.debug = False
utt.seed_rng()
def test_join_scalar(self): def test_join_scalar(self):
a = as_tensor_variable(1) a = as_tensor_variable(1)
...@@ -2266,54 +2267,91 @@ class T_Join_and_Split(unittest.TestCase): ...@@ -2266,54 +2267,91 @@ class T_Join_and_Split(unittest.TestCase):
a join operation on non-join axes are True if one or a join operation on non-join axes are True if one or
more inputs is broadcastable on that dimension. more inputs is broadcastable on that dimension.
""" """
a = TensorType(dtype='int8', broadcastable=[0, 0, 1])() a = TensorType(dtype=config.floatX, broadcastable=[0, 0, 1])()
b = TensorType(dtype='int8', broadcastable=[1, 0, 1])() b = TensorType(dtype=config.floatX, broadcastable=[1, 0, 1])()
c = join(1, a, b) c = join(1, a, b)
assert c.type.broadcastable[0] and c.type.broadcastable[2] assert c.type.broadcastable[0] and c.type.broadcastable[2]
assert not c.type.broadcastable[1] assert not c.type.broadcastable[1]
f = function([a,b], c)
rng = numpy.random.RandomState(seed=utt.fetch_seed())
a_val = rng.rand(1, 4, 1)
b_val = rng.rand(1, 3, 1)
f(a_val, b_val)
utt.verify_grad((lambda a,b: join(1,a,b)), [a_val, b_val], rng=rng)
# Should raise an error if dimension 0 does not match
self.assertRaises(ValueError, f, rng.rand(2,4,1), b_val)
def test_broadcastable_flag_assignment_mixed_thisaxes(self): def test_broadcastable_flag_assignment_mixed_thisaxes(self):
""" """
Test that the broadcastable flag of the join axis Test that the broadcastable flag of the join axis
is False when some inputs are broadcastable on that is False when some inputs are broadcastable on that
dimension. dimension.
""" """
a = TensorType(dtype='int8', broadcastable=[0, 0, 1])() a = TensorType(dtype=config.floatX, broadcastable=[0, 0, 1])()
b = TensorType(dtype='int8', broadcastable=[1, 0, 1])() b = TensorType(dtype=config.floatX, broadcastable=[1, 0, 1])()
c = join(0, a, b) c = join(0, a, b)
assert not c.type.broadcastable[0] assert not c.type.broadcastable[0]
f = function([a,b], c)
rng = numpy.random.RandomState(seed=utt.fetch_seed())
a_val = rng.rand(2, 4, 1)
b_val = rng.rand(1, 4, 1)
f(a_val, b_val)
utt.verify_grad((lambda a,b: join(0,a,b)), [a_val, b_val], rng=rng)
# Should raise an error if b_val.shape[0] is not 1
self.assertRaises(TypeError, f, a_val, rng.rand(3,4,1))
def test_broadcastable_flags_all_broadcastable_on_joinaxis(self): def test_broadcastable_flags_all_broadcastable_on_joinaxis(self):
""" """
Test that joining together several inputs which are all Test that joining together several inputs which are all
broadcastable on the join dimension results in the output broadcastable on the join dimension results in the output
being non-broadcastable on the join dimension. being non-broadcastable on the join dimension.
""" """
a = TensorType(dtype='int8', broadcastable=[1, 0, 1])() a = TensorType(dtype=config.floatX, broadcastable=[1, 0, 1])()
b = TensorType(dtype='int8', broadcastable=[1, 0, 1])() b = TensorType(dtype=config.floatX, broadcastable=[1, 0, 1])()
c = join(0, a, b) c = join(0, a, b)
assert not c.type.broadcastable[0] assert not c.type.broadcastable[0]
f = function([a,b], c)
rng = numpy.random.RandomState(seed=utt.fetch_seed())
a_val = rng.rand(1, 4, 1)
b_val = rng.rand(1, 4, 1)
f(a_val, b_val)
utt.verify_grad((lambda a,b: join(0,a,b)), [a_val, b_val], rng=rng)
# Should raise an error if length of dimension 0 is not 1
self.assertRaises(TypeError, f, rng.rand(2,4,1), b_val)
self.assertRaises(TypeError, f, a_val, rng.rand(3,4,1))
def test_broadcastable_single_input_broadcastable_dimension(self): def test_broadcastable_single_input_broadcastable_dimension(self):
""" """
Test that all broadcastable flags are preserved by a Test that all broadcastable flags are preserved by a
single-input join. single-input join.
""" """
a = join(0, TensorType(dtype='int8', broadcastable=[1, 0, 1])()) a = TensorType(dtype=config.floatX, broadcastable=[1, 0, 1])()
assert a.type.broadcastable[0] b = join(0, a)
assert a.type.broadcastable[2] assert b.type.broadcastable[0]
assert not a.type.broadcastable[1] assert b.type.broadcastable[2]
assert not b.type.broadcastable[1]
f = function([a], b)
rng = numpy.random.RandomState(seed=utt.fetch_seed())
a_val = rng.rand(1, 4, 1)
f(a_val)
utt.verify_grad((lambda a: join(0,a)), [a_val], rng=rng)
# Should raise an error if length of dimension 0 is not 1
self.assertRaises(TypeError, f, rng.rand(2,4,1))
def test_broadcastable_flags_many_dims_and_inputs(self): def test_broadcastable_flags_many_dims_and_inputs(self):
""" """
Test that the right broadcastable flags get set for a join Test that the right broadcastable flags get set for a join
with many inputs and many input dimensions. with many inputs and many input dimensions.
""" """
a = TensorType(dtype='int8', broadcastable=[1, 0, 1, 0, 0, 0])() a = TensorType(dtype=config.floatX, broadcastable=[1, 0, 1, 0, 0, 0])()
b = TensorType(dtype='int8', broadcastable=[1, 1, 1, 0, 0, 0])() b = TensorType(dtype=config.floatX, broadcastable=[1, 1, 1, 0, 0, 0])()
c = TensorType(dtype='int8', broadcastable=[1, 0, 0, 0, 0, 0])() c = TensorType(dtype=config.floatX, broadcastable=[1, 0, 0, 0, 0, 0])()
d = TensorType(dtype='int8', broadcastable=[1, 0, 1, 1, 0, 1])() d = TensorType(dtype=config.floatX, broadcastable=[1, 0, 1, 1, 0, 1])()
e = TensorType(dtype='int8', broadcastable=[1, 0, 1, 0, 0, 1])() e = TensorType(dtype=config.floatX, broadcastable=[1, 0, 1, 0, 0, 1])()
f = join(0, a, b, c, d, e) f = join(0, a, b, c, d, e)
fb = f.type.broadcastable fb = f.type.broadcastable
assert not fb[0] and fb[1] and fb[2] and fb[3] and not fb[4] and fb[5] assert not fb[0] and fb[1] and fb[2] and fb[3] and not fb[4] and fb[5]
...@@ -2324,6 +2362,30 @@ class T_Join_and_Split(unittest.TestCase): ...@@ -2324,6 +2362,30 @@ class T_Join_and_Split(unittest.TestCase):
hb = h.type.broadcastable hb = h.type.broadcastable
assert hb[0] and hb[1] and hb[2] and hb[3] and not hb[4] and hb[5] assert hb[0] and hb[1] and hb[2] and hb[3] and not hb[4] and hb[5]
g = function([a,b,c,d,e], f)
rng = numpy.random.RandomState(seed=utt.fetch_seed())
a_val = rng.rand(1, 1, 1, 1, 2, 1)
b_val = rng.rand(1, 1, 1, 1, 2, 1)
c_val = rng.rand(1, 1, 1, 1, 2, 1)
d_val = rng.rand(1, 1, 1, 1, 2, 1)
e_val = rng.rand(1, 1, 1, 1, 2, 1)
g(a_val, b_val, c_val, d_val, e_val)
utt.verify_grad((lambda a,b,c,d,e: join(0,a,b,c,d,e)),
[a_val, b_val, c_val, d_val, e_val], rng=rng)
# Should raise an error if length of dimension 0 is not 1
self.assertRaises(TypeError, g, rng.rand(2,1,1,1,2,1), b_val, c_val, d_val, e_val)
self.assertRaises(TypeError, g, a_val, rng.rand(2,1,1,1,2,1), c_val, d_val, e_val)
self.assertRaises(TypeError, g, a_val, b_val, rng.rand(2,1,1,1,2,1), d_val, e_val)
self.assertRaises(TypeError, g, a_val, b_val, c_val, rng.rand(2,1,1,1,2,1), e_val)
self.assertRaises(TypeError, g, a_val, b_val, c_val, d_val, rng.rand(2,1,1,1,2,1))
# Should raise an error if any dimension other than 4 has length != 1
self.assertRaises(ValueError, g, rng.rand(1,2,1,1,2,1), b_val, c_val, d_val, e_val)
self.assertRaises(ValueError, g, a_val, rng.rand(1,1,1,1,2,2), c_val, d_val, e_val)
self.assertRaises(ValueError, g, a_val, b_val, rng.rand(1,1,2,1,2,1), d_val, e_val)
self.assertRaises(ValueError, g, a_val, b_val, c_val, rng.rand(1,2,1,1,2,1), e_val)
self.assertRaises(ValueError, g, a_val, b_val, c_val, d_val, rng.rand(1,1,1,2,2,1))
class test_comparison(unittest.TestCase): class test_comparison(unittest.TestCase):
def test_gt(self): def test_gt(self):
for dtype in ['float64', 'float32', 'complex64', 'complex128']: for dtype in ['float64', 'float32', 'complex64', 'complex128']:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论