提交 fd40e50e authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #1607 from sisp/raw_random_choice

added numpy.random.choice function
......@@ -575,6 +575,54 @@ def random_integers(random_state, size=None, low=0, high=1, ndim=None,
return op(random_state, size, low, high)
def choice_helper(random_state, a, replace, p, size):
"""
Helper function to draw random numbers using numpy's choice function.
This is a generalization of numpy.random.choice to the case where `a`,
`replace` and `p` are tensors.
"""
if a.ndim > 1:
raise ValueError('a.ndim (%i) must be 0 or 1' % a.ndim)
if p.ndim == 1:
if p.size == 0:
p = None
else:
raise ValueError('p.ndim (%i) must be 1' % p.ndim)
replace = bool(replace)
return random_state.choice(a, size, replace, p)
def choice(random_state, size=None, a=2, replace=True, p=None, ndim=None,
dtype='int64'):
"""
Choose values from `a` with or without replacement. `a` can be a 1-D array
or a positive scalar. If `a` is a scalar, the samples are drawn from the
range 0,...,a-1.
If the size argument is ambiguous on the number of dimensions, ndim
may be a plain integer to supplement the missing information.
If size is None, a scalar will be returned.
"""
# numpy.random.choice is only available for numpy versions >= 1.7
major, minor, _ = numpy.version.short_version.split('.')
if (int(major), int(minor)) < (1, 7):
raise ImportError('choice requires at NumPy version >= 1.7 '
'(%s)' % numpy.__version__)
a = tensor.as_tensor_variable(a)
if isinstance(replace, bool):
replace = tensor.constant(replace, dtype='int8')
else:
replace = tensor.as_tensor_variable(replace)
# encode p=None as an empty vector
p = tensor.as_tensor_variable(p or [])
ndim, size, bcast = _infer_ndim_bcast(ndim, size)
op = RandomFunction(choice_helper, tensor.TensorType(dtype=dtype,
broadcastable=bcast))
return op(random_state, size, a, replace, p)
def permutation_helper(random_state, n, shape):
"""Helper function to generate permutations from integers.
......@@ -832,6 +880,19 @@ class RandomStreamsBase(object):
"""
return self.gen(random_integers, size, low, high, ndim=ndim,
dtype=dtype)
def choice(self, size=None, a=2, replace=True, p=None, ndim=None,
dtype='int64'):
"""
Choose values from `a` with or without replacement. `a` can be a 1-D
array or a positive scalar. If `a` is a scalar, the samples are drawn
from the range 0,...,a-1.
If the size argument is ambiguous on the number of dimensions,
ndim may be a plain integer to supplement the missing
information.
"""
return self.gen(choice, size, a, replace, p, ndim=ndim, dtype=dtype)
def permutation(self, size=None, n=1, ndim=None, dtype='int64'):
"""
......
__docformat__ = "restructuredtext en"
import sys
import unittest
import numpy as N
import numpy
from theano.tests import unittest_tools as utt
from theano.tensor.raw_random import *
from theano.tensor import (raw_random, ivector, dvector, iscalar, dcol,
dtensor3)
from theano.tests import unittest_tools as utt
from theano import tensor
from theano import compile, config, gof
......@@ -474,6 +471,39 @@ class T_random_function(utt.InferShapeTester):
update=post_r2, mutable=True)],
[out2], accept_inplace=True)
self.assertRaises(ValueError, f2)
def test_choice(self):
"""Test that raw_random.choice generates the same
results as numpy."""
# numpy.random.choice is only available for numpy versions >= 1.7
major, minor, _ = numpy.version.short_version.split('.')
if (int(major), int(minor)) < (1, 7):
raise utt.SkipTest('choice requires at NumPy version >= 1.7 '
'(%s)' % numpy.__version__)
# Check over two calls to see if the random state is correctly updated.
rng_R = random_state_type()
# Use non-default parameters, and larger dimensions because of
# the integer nature of the result
post_r, out = choice(rng_R, (11, 8), 10, 1, 0)
f = compile.function(
[compile.In(rng_R,
value=numpy.random.RandomState(utt.fetch_seed()),
update=post_r, mutable=True)],
[out], accept_inplace=True)
numpy_rng = numpy.random.RandomState(utt.fetch_seed())
val0 = f()
val1 = f()
numpy_val0 = numpy_rng.choice(10, (11, 8), True, None)
numpy_val1 = numpy_rng.choice(10, (11, 8), True, None)
print val0
print numpy_val0
print val1
print numpy_val1
self.assertTrue(numpy.allclose(val0, numpy_val0))
self.assertTrue(numpy.allclose(val1, numpy_val1))
def test_permutation(self):
"""Test that raw_random.permutation generates the same
......
......@@ -182,6 +182,28 @@ class T_SharedRandomStreams(unittest.TestCase):
assert numpy.all(fn_val0 == numpy_val0)
assert numpy.all(fn_val1 == numpy_val1)
def test_choice(self):
"""Test that RandomStreams.choice generates the same results as numpy"""
# numpy.random.choice is only available for numpy versions >= 1.7
major, minor, _ = numpy.version.short_version.split('.')
if (int(major), int(minor)) < (1, 7):
raise utt.SkipTest('choice requires at NumPy version >= 1.7 '
'(%s)' % numpy.__version__)
# Check over two calls to see if the random state is correctly updated.
random = RandomStreams(utt.fetch_seed())
fn = function([], random.choice((11, 8), 10, 1, 0))
fn_val0 = fn()
fn_val1 = fn()
rng_seed = numpy.random.RandomState(utt.fetch_seed()).randint(2**30)
rng = numpy.random.RandomState(int(rng_seed)) #int() is for 32bit
numpy_val0 = rng.choice(10, (11, 8), True, None)
numpy_val1 = rng.choice(10, (11, 8), True, None)
assert numpy.all(fn_val0 == numpy_val0)
assert numpy.all(fn_val1 == numpy_val1)
def test_permutation(self):
"""Test that RandomStreams.permutation generates the same results as numpy"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论