提交 41976981 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Update tests of Join with broadcastable dimensions.

- Compile a function, and call it on random values (so DebugMode can verify things) - Work on floatX, so gradients can be verified - Verify gradients - Check that some incorrect inputs lead to errors
上级 45a8228b
......@@ -2098,6 +2098,7 @@ class T_Join_and_Split(unittest.TestCase):
"""
def setUp(self):
Join.debug = False
utt.seed_rng()
def test_join_scalar(self):
a = as_tensor_variable(1)
......@@ -2266,54 +2267,91 @@ class T_Join_and_Split(unittest.TestCase):
a join operation on non-join axes are True if one or
more inputs is broadcastable on that dimension.
"""
a = TensorType(dtype='int8', broadcastable=[0, 0, 1])()
b = TensorType(dtype='int8', broadcastable=[1, 0, 1])()
a = TensorType(dtype=config.floatX, broadcastable=[0, 0, 1])()
b = TensorType(dtype=config.floatX, broadcastable=[1, 0, 1])()
c = join(1, a, b)
assert c.type.broadcastable[0] and c.type.broadcastable[2]
assert not c.type.broadcastable[1]
f = function([a,b], c)
rng = numpy.random.RandomState(seed=utt.fetch_seed())
a_val = rng.rand(1, 4, 1)
b_val = rng.rand(1, 3, 1)
f(a_val, b_val)
utt.verify_grad((lambda a,b: join(1,a,b)), [a_val, b_val], rng=rng)
# Should raise an error if dimension 0 does not match
self.assertRaises(ValueError, f, rng.rand(2,4,1), b_val)
def test_broadcastable_flag_assignment_mixed_thisaxes(self):
"""
Test that the broadcastable flag of the join axis
is False when some inputs are broadcastable on that
dimension.
"""
a = TensorType(dtype='int8', broadcastable=[0, 0, 1])()
b = TensorType(dtype='int8', broadcastable=[1, 0, 1])()
a = TensorType(dtype=config.floatX, broadcastable=[0, 0, 1])()
b = TensorType(dtype=config.floatX, broadcastable=[1, 0, 1])()
c = join(0, a, b)
assert not c.type.broadcastable[0]
f = function([a,b], c)
rng = numpy.random.RandomState(seed=utt.fetch_seed())
a_val = rng.rand(2, 4, 1)
b_val = rng.rand(1, 4, 1)
f(a_val, b_val)
utt.verify_grad((lambda a,b: join(0,a,b)), [a_val, b_val], rng=rng)
# Should raise an error if b_val.shape[0] is not 1
self.assertRaises(TypeError, f, a_val, rng.rand(3,4,1))
def test_broadcastable_flags_all_broadcastable_on_joinaxis(self):
"""
Test that joining together several inputs which are all
broadcastable on the join dimension results in the output
being non-broadcastable on the join dimension.
"""
a = TensorType(dtype='int8', broadcastable=[1, 0, 1])()
b = TensorType(dtype='int8', broadcastable=[1, 0, 1])()
a = TensorType(dtype=config.floatX, broadcastable=[1, 0, 1])()
b = TensorType(dtype=config.floatX, broadcastable=[1, 0, 1])()
c = join(0, a, b)
assert not c.type.broadcastable[0]
f = function([a,b], c)
rng = numpy.random.RandomState(seed=utt.fetch_seed())
a_val = rng.rand(1, 4, 1)
b_val = rng.rand(1, 4, 1)
f(a_val, b_val)
utt.verify_grad((lambda a,b: join(0,a,b)), [a_val, b_val], rng=rng)
# Should raise an error if length of dimension 0 is not 1
self.assertRaises(TypeError, f, rng.rand(2,4,1), b_val)
self.assertRaises(TypeError, f, a_val, rng.rand(3,4,1))
def test_broadcastable_single_input_broadcastable_dimension(self):
"""
Test that all broadcastable flags are preserved by a
single-input join.
"""
a = join(0, TensorType(dtype='int8', broadcastable=[1, 0, 1])())
assert a.type.broadcastable[0]
assert a.type.broadcastable[2]
assert not a.type.broadcastable[1]
a = TensorType(dtype=config.floatX, broadcastable=[1, 0, 1])()
b = join(0, a)
assert b.type.broadcastable[0]
assert b.type.broadcastable[2]
assert not b.type.broadcastable[1]
f = function([a], b)
rng = numpy.random.RandomState(seed=utt.fetch_seed())
a_val = rng.rand(1, 4, 1)
f(a_val)
utt.verify_grad((lambda a: join(0,a)), [a_val], rng=rng)
# Should raise an error if length of dimension 0 is not 1
self.assertRaises(TypeError, f, rng.rand(2,4,1))
def test_broadcastable_flags_many_dims_and_inputs(self):
"""
Test that the right broadcastable flags get set for a join
with many inputs and many input dimensions.
"""
a = TensorType(dtype='int8', broadcastable=[1, 0, 1, 0, 0, 0])()
b = TensorType(dtype='int8', broadcastable=[1, 1, 1, 0, 0, 0])()
c = TensorType(dtype='int8', broadcastable=[1, 0, 0, 0, 0, 0])()
d = TensorType(dtype='int8', broadcastable=[1, 0, 1, 1, 0, 1])()
e = TensorType(dtype='int8', broadcastable=[1, 0, 1, 0, 0, 1])()
a = TensorType(dtype=config.floatX, broadcastable=[1, 0, 1, 0, 0, 0])()
b = TensorType(dtype=config.floatX, broadcastable=[1, 1, 1, 0, 0, 0])()
c = TensorType(dtype=config.floatX, broadcastable=[1, 0, 0, 0, 0, 0])()
d = TensorType(dtype=config.floatX, broadcastable=[1, 0, 1, 1, 0, 1])()
e = TensorType(dtype=config.floatX, broadcastable=[1, 0, 1, 0, 0, 1])()
f = join(0, a, b, c, d, e)
fb = f.type.broadcastable
assert not fb[0] and fb[1] and fb[2] and fb[3] and not fb[4] and fb[5]
......@@ -2324,6 +2362,30 @@ class T_Join_and_Split(unittest.TestCase):
hb = h.type.broadcastable
assert hb[0] and hb[1] and hb[2] and hb[3] and not hb[4] and hb[5]
g = function([a,b,c,d,e], f)
rng = numpy.random.RandomState(seed=utt.fetch_seed())
a_val = rng.rand(1, 1, 1, 1, 2, 1)
b_val = rng.rand(1, 1, 1, 1, 2, 1)
c_val = rng.rand(1, 1, 1, 1, 2, 1)
d_val = rng.rand(1, 1, 1, 1, 2, 1)
e_val = rng.rand(1, 1, 1, 1, 2, 1)
g(a_val, b_val, c_val, d_val, e_val)
utt.verify_grad((lambda a,b,c,d,e: join(0,a,b,c,d,e)),
[a_val, b_val, c_val, d_val, e_val], rng=rng)
# Should raise an error if length of dimension 0 is not 1
self.assertRaises(TypeError, g, rng.rand(2,1,1,1,2,1), b_val, c_val, d_val, e_val)
self.assertRaises(TypeError, g, a_val, rng.rand(2,1,1,1,2,1), c_val, d_val, e_val)
self.assertRaises(TypeError, g, a_val, b_val, rng.rand(2,1,1,1,2,1), d_val, e_val)
self.assertRaises(TypeError, g, a_val, b_val, c_val, rng.rand(2,1,1,1,2,1), e_val)
self.assertRaises(TypeError, g, a_val, b_val, c_val, d_val, rng.rand(2,1,1,1,2,1))
# Should raise an error if any dimension other than 4 has length != 1
self.assertRaises(ValueError, g, rng.rand(1,2,1,1,2,1), b_val, c_val, d_val, e_val)
self.assertRaises(ValueError, g, a_val, rng.rand(1,1,1,1,2,2), c_val, d_val, e_val)
self.assertRaises(ValueError, g, a_val, b_val, rng.rand(1,1,2,1,2,1), d_val, e_val)
self.assertRaises(ValueError, g, a_val, b_val, c_val, rng.rand(1,2,1,1,2,1), e_val)
self.assertRaises(ValueError, g, a_val, b_val, c_val, d_val, rng.rand(1,1,1,2,2,1))
class test_comparison(unittest.TestCase):
def test_gt(self):
for dtype in ['float64', 'float32', 'complex64', 'complex128']:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论