提交 45cb07f0 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #1542 from lamblin/fix_random_shape

Fix bug with mixed types in shape of random sample
......@@ -301,14 +301,12 @@ def _infer_ndim_bcast(ndim, shape, *args):
else:
args_ndim = 0
# there is a convention that -1 means the corresponding shape of a
# potentially-broadcasted symbolic arg
if (isinstance(shape, (tuple, list))
and numpy.all(numpy.asarray(shape) >= 0)):
bcast = [(s == 1) for s in shape]
v_shape = tensor.TensorConstant(type=tensor.lvector,
data=theano._asarray(shape,
dtype='int64'))
if isinstance(shape, (tuple, list)):
# there is a convention that -1 means the corresponding shape of a
# potentially-broadcasted symbolic arg
#
# This case combines together symbolic and non-symbolic shape
# information
shape_ndim = len(shape)
if ndim is None:
ndim = shape_ndim
......@@ -317,18 +315,7 @@ def _infer_ndim_bcast(ndim, shape, *args):
raise ValueError('ndim should be equal to len(shape), but\n',
'ndim = %s, len(shape) = %s, shape = %s'
% (ndim, shape_ndim, shape))
elif isinstance(shape, (tuple, list)):
# there is a convention that -1 means the corresponding shape of a
# potentially-broadcasted symbolic arg
#
# This case combines together symbolic and non-symbolic shape
# information
if ndim is None:
ndim = args_ndim
else:
ndim = max(args_ndim, ndim)
ndim = max(args_ndim, len(shape))
shape = [-1] * (ndim - len(shape)) + list(shape)
bcast = []
pre_v_shape = []
for i, s in enumerate(shape):
......
......@@ -545,6 +545,42 @@ class T_random_function(utt.InferShapeTester):
self.assertRaises(ValueError, f, rng_state0, [4])
self.assertRaises(ValueError, f, rng_state0, [4, 3, 4, 5])
def test_mixed_shape(self):
# Test when the provided shape is a tuple of ints and scalar vars
rng_R = random_state_type()
shape0 = tensor.lscalar()
shape = (shape0, 3)
post_r, u = uniform(rng_R, size=shape, ndim=2)
f = compile.function([rng_R, shape0], u)
rng_state0 = numpy.random.RandomState(utt.fetch_seed())
assert f(rng_state0, 2).shape == (2, 3)
assert f(rng_state0, 8).shape == (8, 3)
post_r, v = uniform(rng_R, size=shape)
g = compile.function([rng_R, shape0], v)
assert g(rng_state0, 2).shape == (2, 3)
assert g(rng_state0, 8).shape == (8, 3)
def test_mixed_shape_bcastable(self):
# Test when the provided shape is a tuple of ints and scalar vars
rng_R = random_state_type()
shape0 = tensor.lscalar()
shape = (shape0, 1)
post_r, u = uniform(rng_R, size=shape, ndim=2)
assert u.broadcastable == (False, True)
f = compile.function([rng_R, shape0], u)
rng_state0 = numpy.random.RandomState(utt.fetch_seed())
assert f(rng_state0, 2).shape == (2, 1)
assert f(rng_state0, 8).shape == (8, 1)
post_r, v = uniform(rng_R, size=shape)
assert v.broadcastable == (False, True)
g = compile.function([rng_R, shape0], v)
assert g(rng_state0, 2).shape == (2, 1)
assert g(rng_state0, 8).shape == (8, 1)
def test_default_shape(self):
rng_R = random_state_type()
post_r, out = uniform(rng_R)
......
......@@ -333,6 +333,36 @@ class T_SharedRandomStreams(unittest.TestCase):
self.assertRaises(ValueError, f, [4])
self.assertRaises(ValueError, f, [4,3,4,5])
def test_mixed_shape(self):
# Test when the provided shape is a tuple of ints and scalar vars
random = RandomStreams(utt.fetch_seed())
shape0 = tensor.lscalar()
shape = (shape0, 3)
f = function([shape0], random.uniform(size=shape, ndim=2))
assert f(2).shape == (2, 3)
assert f(8).shape == (8, 3)
g = function([shape0], random.uniform(size=shape))
assert g(2).shape == (2, 3)
assert g(8).shape == (8, 3)
def test_mixed_shape_bcastable(self):
# Test when the provided shape is a tuple of ints and scalar vars
random = RandomStreams(utt.fetch_seed())
shape0 = tensor.lscalar()
shape = (shape0, 1)
u = random.uniform(size=shape, ndim=2)
assert u.broadcastable == (False, True)
f = function([shape0], u)
assert f(2).shape == (2, 1)
assert f(8).shape == (8, 1)
v = random.uniform(size=shape)
assert v.broadcastable == (False, True)
g = function([shape0], v)
assert g(2).shape == (2, 1)
assert g(8).shape == (8, 1)
def test_default_shape(self):
random = RandomStreams(utt.fetch_seed())
f = function([], random.uniform())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论