提交 57fd59cb authored 作者: Xavier Bouthillier/'s avatar Xavier Bouthillier/

(re)added poisson to RandomStreamsBase add added a unit test

上级 a8178d62
......@@ -629,6 +629,7 @@ def poisson_helper(random_state, lam, size):
This is a generalization of numpy.random.poisson to the case where
`lam` is a tensor.
"""
return random_state.poisson(lam, size)
def poisson(random_state, size=None, lam=1.0, ndim=None, dtype='int64'):
......@@ -637,8 +638,7 @@ def poisson(random_state, size=None, lam=1.0, ndim=None, dtype='int64'):
The Poisson distribution is the limit of the Binomial distribution for large N.
:param lam: float
:param lam: float or ndarray-like of the same shape as size parameter
Expectation of interval, should be >= 0.
:param size: int or tuple of ints, optional
......@@ -928,6 +928,18 @@ class RandomStreamsBase(object):
"""
return self.gen(choice, size, a, replace, p, ndim=ndim, dtype=dtype)
def poisson(self, size=None, lam=None, ndim=None, dtype='int64'):
"""
Draw samples from a Poisson distribution.
The Poisson distribution is the limit of the Binomial distribution for large N.
If the size argument is ambiguous on the number of dimensions,
ndim may be a plain integer to supplement the missing
information.
"""
return self.gen(poisson, size, lam, ndim=ndim, dtype=dtype)
def permutation(self, size=None, n=1, ndim=None, dtype='int64'):
"""
Returns permutations of the integers between 0 and n-1, as many times
......
......@@ -205,6 +205,23 @@ class T_SharedRandomStreams(unittest.TestCase):
assert numpy.all(fn_val0 == numpy_val0)
assert numpy.all(fn_val1 == numpy_val1)
def test_poisson(self):
"""Test that RandomStreams.poisson generates the same results as numpy"""
# Check over two calls to see if the random state is correctly updated.
random = RandomStreams(utt.fetch_seed())
fn = function([], random.poisson(lam=5, size=(11, 8)))
fn_val0 = fn()
fn_val1 = fn()
rng_seed = numpy.random.RandomState(utt.fetch_seed()).randint(2**30)
rng = numpy.random.RandomState(int(rng_seed)) #int() is for 32bit
numpy_val0 = rng.poisson(lam=5, size=(11, 8))
numpy_val1 = rng.poisson(lam=5, size=(11, 8))
assert numpy.all(fn_val0 == numpy_val0)
assert numpy.all(fn_val1 == numpy_val1)
def test_permutation(self):
"""Test that RandomStreams.permutation generates the same results as numpy"""
# Check over two calls to see if the random state is correctly updated.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论