提交 419768e7 authored 作者: Frederic's avatar Frederic

Allow permutation to return just one permutation.

fix gh-2158
上级 9da67d20
......@@ -183,7 +183,11 @@ class RandomFunction(gof.Op):
draw.
"""
shape = tensor.as_tensor_variable(shape, ndim=1)
shape_ = tensor.as_tensor_variable(shape, ndim=1)
if shape == ():
shape = shape_.astype('int32')
else:
shape = shape_
assert shape.type.ndim == 1
assert (shape.type.dtype == 'int64') or (shape.type.dtype == 'int32')
if not isinstance(r.type, RandomStateType):
......@@ -700,7 +704,15 @@ def permutation(random_state, size=None, n=1, ndim=None, dtype='int64'):
:note:
Note that the output will then be of dimension ndim+1.
"""
ndim, size, bcast = _infer_ndim_bcast(ndim, size)
if size is None or size == ():
if not(ndim is None or ndim == 1):
raise TypeError(
"You asked for just one permutation but asked for more then 1 dimensions.")
ndim = 1
size = ()
bcast = ()
else:
ndim, size, bcast = _infer_ndim_bcast(ndim, size)
#print "NDIM", ndim, size
op = RandomFunction(permutation_helper,
tensor.TensorType(dtype=dtype, broadcastable=bcast + (False,)),
......
......@@ -230,7 +230,7 @@ class T_random_function(utt.InferShapeTester):
rng_R = random_state_type()
# No shape, no args -> TypeError
self.assertRaises(TypeError, permutation, rng_R, size=None, ndim=2)
self.assertRaises(TypeError, poisson, rng_R, size=None, ndim=2)
def test_random_function_ndim_added(self):
"""Test that random_function helper function accepts ndim_added as
......@@ -561,6 +561,19 @@ class T_random_function(utt.InferShapeTester):
self.assertTrue(numpy.all(val0 == numpy_val0))
self.assertTrue(numpy.all(val1 == numpy_val1))
# Test that we can generate a list: have size=None or ().
for ndim in [1, None]:
post_r, out = permutation(rng_R, n=10, size=None, ndim=ndim)
inp = compile.In(rng_R,
value=numpy.random.RandomState(utt.fetch_seed()),
update=post_r, mutable=True)
f = theano.function([inp], out)
o = f()
assert o.shape == (10,)
assert (numpy.sort(o) == numpy.arange(10)).all()
# Wrong number of dimensions asked
self.assertRaises(TypeError, permutation, rng_R, size=None, ndim=2)
def test_multinomial(self):
"""Test that raw_random.multinomial generates the same
results as numpy."""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论