提交 a8178d62 authored 作者: Xavier Bouthillier/'s avatar Xavier Bouthillier/

Added numpy wrapper for poisson distribution

上级 94506569
......@@ -622,6 +622,40 @@ def choice(random_state, size=None, a=2, replace=True, p=None, ndim=None,
broadcastable=bcast))
return op(random_state, size, a, replace, p)
def poisson_helper(random_state, lam, size):
"""
Helper function to draw random numbers using numpy's poisson function.
This is a generalization of numpy.random.poisson to the case where
`lam` is a tensor.
"""
return random_state.poisson(lam, size)
def poisson(random_state, size=None, lam=1.0, ndim=None, dtype='int64'):
"""
Draw samples from a Poisson distribution.
The Poisson distribution is the limit of the Binomial distribution for large N.
:param lam: float
Expectation of interval, should be >= 0.
:param size: int or tuple of ints, optional
Output shape. If the given shape is, e.g., (m, n, k), then m * n * k samples are drawn.
:param dtype: the dtype of the return value (which will represent counts)
size or ndim must be given
"""
lam = tensor.as_tensor_variable(lam)
ndim, size, bcast = _infer_ndim_bcast(ndim, size)
op = RandomFunction(poisson_helper, tensor.TensorType(dtype=dtype,
broadcastable=bcast))
return op(random_state, size, lam)
def permutation_helper(random_state, n, shape):
"""Helper function to generate permutations from integers.
......
......@@ -505,6 +505,33 @@ class T_random_function(utt.InferShapeTester):
self.assertTrue(numpy.allclose(val0, numpy_val0))
self.assertTrue(numpy.allclose(val1, numpy_val1))
def test_poisson(self):
"""Test that raw_random.poisson generates the same
results as numpy."""
# Check over two calls to see if the random state is correctly updated.
rng_R = random_state_type()
# Use non-default parameters, and larger dimensions because of
# the integer nature of the result
post_r, out = poisson(rng_R, lam=5, size=(11,8))
f = compile.function(
[compile.In(rng_R,
value=numpy.random.RandomState(utt.fetch_seed()),
update=post_r, mutable=True)],
[out], accept_inplace=True)
numpy_rng = numpy.random.RandomState(utt.fetch_seed())
val0 = f()
val1 = f()
numpy_val0 = numpy_rng.poisson(5,size=(11,8))
numpy_val1 = numpy_rng.poisson(5,size=(11,8))
print val0
print numpy_val0
print val1
print numpy_val1
self.assertTrue(numpy.allclose(val0, numpy_val0))
self.assertTrue(numpy.allclose(val1, numpy_val1))
def test_permutation(self):
"""Test that raw_random.permutation generates the same
results as numpy."""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论