提交 57fd59cb authored 作者: Xavier Bouthillier/'s avatar Xavier Bouthillier/

(re)added poisson to RandomStreamsBase add added a unit test

上级 a8178d62
...@@ -629,6 +629,7 @@ def poisson_helper(random_state, lam, size): ...@@ -629,6 +629,7 @@ def poisson_helper(random_state, lam, size):
This is a generalization of numpy.random.poisson to the case where This is a generalization of numpy.random.poisson to the case where
`lam` is a tensor. `lam` is a tensor.
""" """
return random_state.poisson(lam, size) return random_state.poisson(lam, size)
def poisson(random_state, size=None, lam=1.0, ndim=None, dtype='int64'): def poisson(random_state, size=None, lam=1.0, ndim=None, dtype='int64'):
...@@ -637,8 +638,7 @@ def poisson(random_state, size=None, lam=1.0, ndim=None, dtype='int64'): ...@@ -637,8 +638,7 @@ def poisson(random_state, size=None, lam=1.0, ndim=None, dtype='int64'):
The Poisson distribution is the limit of the Binomial distribution for large N. The Poisson distribution is the limit of the Binomial distribution for large N.
:param lam: float :param lam: float or ndarray-like of the same shape as size parameter
Expectation of interval, should be >= 0. Expectation of interval, should be >= 0.
:param size: int or tuple of ints, optional :param size: int or tuple of ints, optional
...@@ -928,6 +928,18 @@ class RandomStreamsBase(object): ...@@ -928,6 +928,18 @@ class RandomStreamsBase(object):
""" """
return self.gen(choice, size, a, replace, p, ndim=ndim, dtype=dtype) return self.gen(choice, size, a, replace, p, ndim=ndim, dtype=dtype)
def poisson(self, size=None, lam=None, ndim=None, dtype='int64'):
"""
Draw samples from a Poisson distribution.
The Poisson distribution is the limit of the Binomial distribution for large N.
If the size argument is ambiguous on the number of dimensions,
ndim may be a plain integer to supplement the missing
information.
"""
return self.gen(poisson, size, lam, ndim=ndim, dtype=dtype)
def permutation(self, size=None, n=1, ndim=None, dtype='int64'): def permutation(self, size=None, n=1, ndim=None, dtype='int64'):
""" """
Returns permutations of the integers between 0 and n-1, as many times Returns permutations of the integers between 0 and n-1, as many times
......
...@@ -205,6 +205,23 @@ class T_SharedRandomStreams(unittest.TestCase): ...@@ -205,6 +205,23 @@ class T_SharedRandomStreams(unittest.TestCase):
assert numpy.all(fn_val0 == numpy_val0) assert numpy.all(fn_val0 == numpy_val0)
assert numpy.all(fn_val1 == numpy_val1) assert numpy.all(fn_val1 == numpy_val1)
def test_poisson(self):
"""Test that RandomStreams.poisson generates the same results as numpy"""
# Check over two calls to see if the random state is correctly updated.
random = RandomStreams(utt.fetch_seed())
fn = function([], random.poisson(lam=5, size=(11, 8)))
fn_val0 = fn()
fn_val1 = fn()
rng_seed = numpy.random.RandomState(utt.fetch_seed()).randint(2**30)
rng = numpy.random.RandomState(int(rng_seed)) #int() is for 32bit
numpy_val0 = rng.poisson(lam=5, size=(11, 8))
numpy_val1 = rng.poisson(lam=5, size=(11, 8))
assert numpy.all(fn_val0 == numpy_val0)
assert numpy.all(fn_val1 == numpy_val1)
def test_permutation(self): def test_permutation(self):
"""Test that RandomStreams.permutation generates the same results as numpy""" """Test that RandomStreams.permutation generates the same results as numpy"""
# Check over two calls to see if the random state is correctly updated. # Check over two calls to see if the random state is correctly updated.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论