提交 60481d47 authored 作者: Sigurd Spieckermann's avatar Sigurd Spieckermann

added check to ensure numpy version >= 1.7 is installed in order to use

the choice function
上级 af18317e
...@@ -602,6 +602,11 @@ def choice(random_state, size=None, a=2, replace=True, p=None, ndim=None, ...@@ -602,6 +602,11 @@ def choice(random_state, size=None, a=2, replace=True, p=None, ndim=None,
If size is None, a scalar will be returned. If size is None, a scalar will be returned.
""" """
# numpy.random.choice is only available for numpy versions >= 1.7
major, minor, _ = numpy.version.short_version.split('.')
if (int(major), int(minor)) < (1, 7):
raise ImportError('choice requires at NumPy version >= 1.7 '
'(%s)' % numpy.__version__)
a = tensor.as_tensor_variable(a) a = tensor.as_tensor_variable(a)
if isinstance(replace, bool): if isinstance(replace, bool):
replace = tensor.constant(replace, dtype='int8') replace = tensor.constant(replace, dtype='int8')
......
...@@ -478,6 +478,12 @@ class T_random_function(utt.InferShapeTester): ...@@ -478,6 +478,12 @@ class T_random_function(utt.InferShapeTester):
def test_choice(self): def test_choice(self):
"""Test that raw_random.choice generates the same """Test that raw_random.choice generates the same
results as numpy.""" results as numpy."""
# numpy.random.choice is only available for numpy versions >= 1.7
major, minor, _ = numpy.version.short_version.split('.')
if (int(major), int(minor)) < (1, 7):
raise utt.SkipTest('choice requires at NumPy version >= 1.7 '
'(%s)' % numpy.__version__)
# Check over two calls to see if the random state is correctly updated. # Check over two calls to see if the random state is correctly updated.
rng_R = random_state_type() rng_R = random_state_type()
# Use non-default parameters, and larger dimensions because of # Use non-default parameters, and larger dimensions because of
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论