提交 bb429155 authored 作者: Benjamin Scellier's avatar Benjamin Scellier

numpy.random.choice is always available from numpy 1.7 onwards

上级 17e9d105
......@@ -647,11 +647,6 @@ def choice(random_state, size=None, a=2, replace=True, p=None, ndim=None,
If size is None, a scalar will be returned.
"""
# numpy.random.choice is only available for numpy versions >= 1.7
major, minor, _ = numpy.version.short_version.split('.')
if (int(major), int(minor)) < (1, 7):
raise ImportError('choice requires at NumPy version >= 1.7 '
'(%s)' % numpy.__version__)
a = tensor.as_tensor_variable(a)
if isinstance(replace, bool):
replace = tensor.constant(replace, dtype='int8')
......
......@@ -459,12 +459,6 @@ class T_random_function(utt.InferShapeTester):
def test_choice(self):
"""Test that raw_random.choice generates the same
results as numpy."""
# numpy.random.choice is only available for numpy versions >= 1.7
major, minor, _ = numpy.version.short_version.split('.')
if (int(major), int(minor)) < (1, 7):
raise utt.SkipTest('choice requires at NumPy version >= 1.7 '
'(%s)' % numpy.__version__)
# Check over two calls to see if the random state is correctly updated.
rng_R = random_state_type()
# Use non-default parameters, and larger dimensions because of
......
......@@ -190,12 +190,6 @@ class T_SharedRandomStreams(unittest.TestCase):
def test_choice(self):
"""Test that RandomStreams.choice generates the same results as numpy"""
# numpy.random.choice is only available for numpy versions >= 1.7
major, minor, _ = numpy.version.short_version.split('.')
if (int(major), int(minor)) < (1, 7):
raise utt.SkipTest('choice requires at NumPy version >= 1.7 '
'(%s)' % numpy.__version__)
# Check over two calls to see if the random state is correctly updated.
random = RandomStreams(utt.fetch_seed())
fn = function([], random.choice((11, 8), 10, 1, 0))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论