提交 bb429155 authored 作者: Benjamin Scellier's avatar Benjamin Scellier

numpy.random.choice is always available from numpy 1.7 onwards

上级 17e9d105
...@@ -647,11 +647,6 @@ def choice(random_state, size=None, a=2, replace=True, p=None, ndim=None, ...@@ -647,11 +647,6 @@ def choice(random_state, size=None, a=2, replace=True, p=None, ndim=None,
If size is None, a scalar will be returned. If size is None, a scalar will be returned.
""" """
# numpy.random.choice is only available for numpy versions >= 1.7
major, minor, _ = numpy.version.short_version.split('.')
if (int(major), int(minor)) < (1, 7):
raise ImportError('choice requires at NumPy version >= 1.7 '
'(%s)' % numpy.__version__)
a = tensor.as_tensor_variable(a) a = tensor.as_tensor_variable(a)
if isinstance(replace, bool): if isinstance(replace, bool):
replace = tensor.constant(replace, dtype='int8') replace = tensor.constant(replace, dtype='int8')
......
...@@ -459,12 +459,6 @@ class T_random_function(utt.InferShapeTester): ...@@ -459,12 +459,6 @@ class T_random_function(utt.InferShapeTester):
def test_choice(self): def test_choice(self):
"""Test that raw_random.choice generates the same """Test that raw_random.choice generates the same
results as numpy.""" results as numpy."""
# numpy.random.choice is only available for numpy versions >= 1.7
major, minor, _ = numpy.version.short_version.split('.')
if (int(major), int(minor)) < (1, 7):
raise utt.SkipTest('choice requires at NumPy version >= 1.7 '
'(%s)' % numpy.__version__)
# Check over two calls to see if the random state is correctly updated. # Check over two calls to see if the random state is correctly updated.
rng_R = random_state_type() rng_R = random_state_type()
# Use non-default parameters, and larger dimensions because of # Use non-default parameters, and larger dimensions because of
......
...@@ -190,12 +190,6 @@ class T_SharedRandomStreams(unittest.TestCase): ...@@ -190,12 +190,6 @@ class T_SharedRandomStreams(unittest.TestCase):
def test_choice(self): def test_choice(self):
"""Test that RandomStreams.choice generates the same results as numpy""" """Test that RandomStreams.choice generates the same results as numpy"""
# numpy.random.choice is only available for numpy versions >= 1.7
major, minor, _ = numpy.version.short_version.split('.')
if (int(major), int(minor)) < (1, 7):
raise utt.SkipTest('choice requires at NumPy version >= 1.7 '
'(%s)' % numpy.__version__)
# Check over two calls to see if the random state is correctly updated. # Check over two calls to see if the random state is correctly updated.
random = RandomStreams(utt.fetch_seed()) random = RandomStreams(utt.fetch_seed())
fn = function([], random.choice((11, 8), 10, 1, 0)) fn = function([], random.choice((11, 8), 10, 1, 0))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论