提交 9b304769 authored 作者: James Bergstra's avatar James Bergstra

rng_mrg - writing it

上级 33a2a9f2
......@@ -517,7 +517,7 @@ class GPU_mrg_uniform(mrg_uniform_base):
PyErr_Format(PyExc_ValueError, "rstate len must be multiple of 6");
%(fail)s;
}
n_streams = CudaNdarray_HOST_DIMS(%(o_rstate)s)[0]/6;
n_streams = std::min(CudaNdarray_HOST_DIMS(%(o_rstate)s)[0]/6, n_elements);
{
unsigned int threads_per_block = std::min(n_streams, (unsigned int)NUM_VECTOR_OP_THREADS_PER_BLOCK);
......@@ -549,7 +549,7 @@ class GPU_mrg_uniform(mrg_uniform_base):
class MRG_RandomStreams(object):
"""Module component with similar interface to numpy.random (numpy.random.RandomState)"""
def __init__(self, seed=None, use_cuda=None):
def __init__(self, seed=12345, use_cuda=None):
"""
:type seed: None or int
......@@ -557,7 +557,12 @@ class MRG_RandomStreams(object):
`RandomStreamsInstance.__init__` for more details.
"""
super(MRG_RandomStreams, self).__init__()
self.rstate = numpy.asarray([12345]*6, dtype='int32')
if isinstance(seed, int):
self.rstate = numpy.asarray([seed]*6, dtype='int32')
elif len(seed)==6:
self.rstate = numpy.asarray(seed, dtype='int32')
else:
raise TypeError("seed should be 1 integer or 6 integers")
if use_cuda is None:
self.use_cuda = cuda_enabled
else:
......@@ -583,10 +588,19 @@ class MRG_RandomStreams(object):
return rval
def n_streams(self, size):
r = 1
for s in size:
r *= s
return r
if isinstance(size, (tuple, list)):
r = 1
for s in size:
r *= s
return r
try:
rval = int(size)
assert rval > 0
return rval
except:
pass
print >> sys.stderr, "MRG_RandomStreams Can't determine #streams from size (%s), guessing 30*256"%str(size)
return 30*256
def pretty_return(self, node_rstate, new_rstate, sample):
sample.rstate = node_rstate
......@@ -630,6 +644,12 @@ class MRG_RandomStreams(object):
raise NotImplementedError( 'Increase the size to match the broadcasting pattern of `low` and `high` arguments')
return r
def binomial(self, size=None, n=1, prob=0.5, ndim=None, dtype='int64'):
if n == 1:
return cast(self.uniform(size=size) < prob, dtype)
else:
raise NotImplementedError("MRG_RandomStreams.binomial with n > 1")
@local_optimizer([None])
def mrg_random_make_inplace(node):
op = node.op
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论