提交 5420caac authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #1722 from bouthilx/poisson

Added numpy wrapper for poisson distribution
...@@ -57,6 +57,28 @@ Reference ...@@ -57,6 +57,28 @@ Reference
dimensions, the first argument may be a plain integer dimensions, the first argument may be a plain integer
to supplement the missing information. to supplement the missing information.
.. method:: choice(self, size=(), a=2, replace=True, p=None, ndim=None, dtype='int64'):
Choose values from `a` with or without replacement. `a` can be a 1-D
array or a positive scalar. If `a` is a scalar, the samples are drawn
from the range 0,...,a-1.
If the size argument is ambiguous on the number of dimensions,
ndim may be a plain integer to supplement the missing
information.
.. method:: poisson(self, size=(), lam=None, ndim=None, dtype='int64'):
Usage: poisson(random_state, size, lam=5)
Draw samples from a Poisson distribution.
The Poisson distribution is the limit of the Binomial distribution for large N.
If the size argument is ambiguous on the number of dimensions,
ndim may be a plain integer to supplement the missing
information.
.. method:: permutation(self, size=(), n=1, ndim=None): .. method:: permutation(self, size=(), n=1, ndim=None):
Returns permutations of the integers between 0 and n-1, as many times Returns permutations of the integers between 0 and n-1, as many times
......
...@@ -622,6 +622,40 @@ def choice(random_state, size=None, a=2, replace=True, p=None, ndim=None, ...@@ -622,6 +622,40 @@ def choice(random_state, size=None, a=2, replace=True, p=None, ndim=None,
broadcastable=bcast)) broadcastable=bcast))
return op(random_state, size, a, replace, p) return op(random_state, size, a, replace, p)
def poisson_helper(random_state, lam, size):
"""
Helper function to draw random numbers using numpy's poisson function.
This is a generalization of numpy.random.poisson to the case where
`lam` is a tensor.
"""
return random_state.poisson(lam, size)
def poisson(random_state, size=None, lam=1.0, ndim=None, dtype='int64'):
"""
Draw samples from a Poisson distribution.
The Poisson distribution is the limit of the Binomial distribution for large N.
:param lam: float or ndarray-like of the same shape as size parameter
Expectation of interval, should be >= 0.
:param size: int or tuple of ints, optional
Output shape. If the given shape is, e.g., (m, n, k), then m * n * k samples are drawn.
:param dtype: the dtype of the return value (which will represent counts)
size or ndim must be given
"""
lam = tensor.as_tensor_variable(lam)
ndim, size, bcast = _infer_ndim_bcast(ndim, size)
op = RandomFunction(poisson_helper, tensor.TensorType(dtype=dtype,
broadcastable=bcast))
return op(random_state, size, lam)
def permutation_helper(random_state, n, shape): def permutation_helper(random_state, n, shape):
"""Helper function to generate permutations from integers. """Helper function to generate permutations from integers.
...@@ -894,6 +928,18 @@ class RandomStreamsBase(object): ...@@ -894,6 +928,18 @@ class RandomStreamsBase(object):
""" """
return self.gen(choice, size, a, replace, p, ndim=ndim, dtype=dtype) return self.gen(choice, size, a, replace, p, ndim=ndim, dtype=dtype)
def poisson(self, size=None, lam=None, ndim=None, dtype='int64'):
"""
Draw samples from a Poisson distribution.
The Poisson distribution is the limit of the Binomial distribution for large N.
If the size argument is ambiguous on the number of dimensions,
ndim may be a plain integer to supplement the missing
information.
"""
return self.gen(poisson, size, lam, ndim=ndim, dtype=dtype)
def permutation(self, size=None, n=1, ndim=None, dtype='int64'): def permutation(self, size=None, n=1, ndim=None, dtype='int64'):
""" """
Returns permutations of the integers between 0 and n-1, as many times Returns permutations of the integers between 0 and n-1, as many times
......
...@@ -505,6 +505,33 @@ class T_random_function(utt.InferShapeTester): ...@@ -505,6 +505,33 @@ class T_random_function(utt.InferShapeTester):
self.assertTrue(numpy.allclose(val0, numpy_val0)) self.assertTrue(numpy.allclose(val0, numpy_val0))
self.assertTrue(numpy.allclose(val1, numpy_val1)) self.assertTrue(numpy.allclose(val1, numpy_val1))
def test_poisson(self):
"""Test that raw_random.poisson generates the same
results as numpy."""
# Check over two calls to see if the random state is correctly updated.
rng_R = random_state_type()
# Use non-default parameters, and larger dimensions because of
# the integer nature of the result
post_r, out = poisson(rng_R, lam=5, size=(11,8))
f = compile.function(
[compile.In(rng_R,
value=numpy.random.RandomState(utt.fetch_seed()),
update=post_r, mutable=True)],
[out], accept_inplace=True)
numpy_rng = numpy.random.RandomState(utt.fetch_seed())
val0 = f()
val1 = f()
numpy_val0 = numpy_rng.poisson(5,size=(11,8))
numpy_val1 = numpy_rng.poisson(5,size=(11,8))
print val0
print numpy_val0
print val1
print numpy_val1
self.assertTrue(numpy.allclose(val0, numpy_val0))
self.assertTrue(numpy.allclose(val1, numpy_val1))
def test_permutation(self): def test_permutation(self):
"""Test that raw_random.permutation generates the same """Test that raw_random.permutation generates the same
results as numpy.""" results as numpy."""
......
...@@ -205,6 +205,23 @@ class T_SharedRandomStreams(unittest.TestCase): ...@@ -205,6 +205,23 @@ class T_SharedRandomStreams(unittest.TestCase):
assert numpy.all(fn_val0 == numpy_val0) assert numpy.all(fn_val0 == numpy_val0)
assert numpy.all(fn_val1 == numpy_val1) assert numpy.all(fn_val1 == numpy_val1)
def test_poisson(self):
"""Test that RandomStreams.poisson generates the same results as numpy"""
# Check over two calls to see if the random state is correctly updated.
random = RandomStreams(utt.fetch_seed())
fn = function([], random.poisson(lam=5, size=(11, 8)))
fn_val0 = fn()
fn_val1 = fn()
rng_seed = numpy.random.RandomState(utt.fetch_seed()).randint(2**30)
rng = numpy.random.RandomState(int(rng_seed)) #int() is for 32bit
numpy_val0 = rng.poisson(lam=5, size=(11, 8))
numpy_val1 = rng.poisson(lam=5, size=(11, 8))
assert numpy.all(fn_val0 == numpy_val0)
assert numpy.all(fn_val1 == numpy_val1)
def test_permutation(self): def test_permutation(self):
"""Test that RandomStreams.permutation generates the same results as numpy""" """Test that RandomStreams.permutation generates the same results as numpy"""
# Check over two calls to see if the random state is correctly updated. # Check over two calls to see if the random state is correctly updated.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论