提交 9b304769 authored 作者: James Bergstra's avatar James Bergstra

rng_mrg - writing it

上级 33a2a9f2
...@@ -517,7 +517,7 @@ class GPU_mrg_uniform(mrg_uniform_base): ...@@ -517,7 +517,7 @@ class GPU_mrg_uniform(mrg_uniform_base):
PyErr_Format(PyExc_ValueError, "rstate len must be multiple of 6"); PyErr_Format(PyExc_ValueError, "rstate len must be multiple of 6");
%(fail)s; %(fail)s;
} }
n_streams = CudaNdarray_HOST_DIMS(%(o_rstate)s)[0]/6; n_streams = std::min(CudaNdarray_HOST_DIMS(%(o_rstate)s)[0]/6, n_elements);
{ {
unsigned int threads_per_block = std::min(n_streams, (unsigned int)NUM_VECTOR_OP_THREADS_PER_BLOCK); unsigned int threads_per_block = std::min(n_streams, (unsigned int)NUM_VECTOR_OP_THREADS_PER_BLOCK);
...@@ -549,7 +549,7 @@ class GPU_mrg_uniform(mrg_uniform_base): ...@@ -549,7 +549,7 @@ class GPU_mrg_uniform(mrg_uniform_base):
class MRG_RandomStreams(object): class MRG_RandomStreams(object):
"""Module component with similar interface to numpy.random (numpy.random.RandomState)""" """Module component with similar interface to numpy.random (numpy.random.RandomState)"""
def __init__(self, seed=None, use_cuda=None): def __init__(self, seed=12345, use_cuda=None):
""" """
:type seed: None or int :type seed: None or int
...@@ -557,7 +557,12 @@ class MRG_RandomStreams(object): ...@@ -557,7 +557,12 @@ class MRG_RandomStreams(object):
`RandomStreamsInstance.__init__` for more details. `RandomStreamsInstance.__init__` for more details.
""" """
super(MRG_RandomStreams, self).__init__() super(MRG_RandomStreams, self).__init__()
self.rstate = numpy.asarray([12345]*6, dtype='int32') if isinstance(seed, int):
self.rstate = numpy.asarray([seed]*6, dtype='int32')
elif len(seed)==6:
self.rstate = numpy.asarray(seed, dtype='int32')
else:
raise TypeError("seed should be 1 integer or 6 integers")
if use_cuda is None: if use_cuda is None:
self.use_cuda = cuda_enabled self.use_cuda = cuda_enabled
else: else:
...@@ -583,10 +588,19 @@ class MRG_RandomStreams(object): ...@@ -583,10 +588,19 @@ class MRG_RandomStreams(object):
return rval return rval
def n_streams(self, size): def n_streams(self, size):
if isinstance(size, (tuple, list)):
r = 1 r = 1
for s in size: for s in size:
r *= s r *= s
return r return r
try:
rval = int(size)
assert rval > 0
return rval
except:
pass
print >> sys.stderr, "MRG_RandomStreams Can't determine #streams from size (%s), guessing 30*256"%str(size)
return 30*256
def pretty_return(self, node_rstate, new_rstate, sample): def pretty_return(self, node_rstate, new_rstate, sample):
sample.rstate = node_rstate sample.rstate = node_rstate
...@@ -630,6 +644,12 @@ class MRG_RandomStreams(object): ...@@ -630,6 +644,12 @@ class MRG_RandomStreams(object):
raise NotImplementedError( 'Increase the size to match the broadcasting pattern of `low` and `high` arguments') raise NotImplementedError( 'Increase the size to match the broadcasting pattern of `low` and `high` arguments')
return r return r
def binomial(self, size=None, n=1, prob=0.5, ndim=None, dtype='int64'):
if n == 1:
return cast(self.uniform(size=size) < prob, dtype)
else:
raise NotImplementedError("MRG_RandomStreams.binomial with n > 1")
@local_optimizer([None]) @local_optimizer([None])
def mrg_random_make_inplace(node): def mrg_random_make_inplace(node):
op = node.op op = node.op
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论