提交 12fa8f65 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add sanity check on seed in rng_mrg, to ensure the results' validity

上级 f3fedb46
...@@ -553,7 +553,7 @@ class GPU_mrg_uniform(mrg_uniform_base): ...@@ -553,7 +553,7 @@ class GPU_mrg_uniform(mrg_uniform_base):
{ {
PyErr_Format(PyExc_RuntimeError, "Cuda error: %%s: %%s.\\n", "mrg_uniform", cudaGetErrorString(err)); PyErr_Format(PyExc_RuntimeError, "Cuda error: %%s: %%s.\\n", "mrg_uniform", cudaGetErrorString(err));
%(fail)s; %(fail)s;
} }
} }
//////// </ code generated by mrg_uniform> //////// </ code generated by mrg_uniform>
...@@ -564,15 +564,31 @@ class MRG_RandomStreams(object): ...@@ -564,15 +564,31 @@ class MRG_RandomStreams(object):
def __init__(self, seed=12345, use_cuda=None): def __init__(self, seed=12345, use_cuda=None):
""" """
:type seed: None or int :type seed: int or list of 6 int.
:param seed: a default seed to initialize the random state.
If a single int is given, it will be replicated 6 times.
The first 3 values of the seed must all be less than M1 = 2147483647,
and not all 0; and the last 3 values must all be less than
M2 = 2147462579, and not all 0.
:param seed: a default seed to initialize the RandomState instances after build. See
`RandomStreamsInstance.__init__` for more details.
""" """
super(MRG_RandomStreams, self).__init__() super(MRG_RandomStreams, self).__init__()
if isinstance(seed, int): if isinstance(seed, int):
if seed == 0:
raise ValueError('seed should not be 0', seed)
elif seed >= M2:
raise ValueError('seed should be less than %i' % M2, seed)
self.rstate = numpy.asarray([seed]*6, dtype='int32') self.rstate = numpy.asarray([seed]*6, dtype='int32')
elif len(seed)==6: elif len(seed)==6:
if seed[0] == 0 and seed[1] == 0 and seed[2] == 0:
raise ValueError('The first 3 values of seed should not be all 0', seed)
if seed[3] == 0 and seed[4] == 0 and seed[5] == 0:
raise ValueError('The last 3 values of seed should not be all 0', seed)
if seed[0] >= M1 or seed[1] >= M1 or seed[2] >= M1:
raise ValueError('The first 3 values of seed should be less than %i' % M1, seed)
if seed[3] >= M2 or seed[4] >= M2 or seed[5] >= M2:
raise ValueError('The last 3 values of seed should be less than %i' % M2, seed)
self.rstate = numpy.asarray(seed, dtype='int32') self.rstate = numpy.asarray(seed, dtype='int32')
else: else:
raise TypeError("seed should be 1 integer or 6 integers") raise TypeError("seed should be 1 integer or 6 integers")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论