提交 7999e2d3 authored 作者: James Bergstra's avatar James Bergstra

sandbox/rng_mrg - committed parallel version

上级 d6d8c839
"""
Implementation of MRG31k3p random number generator for Theano
Generator code in SSJ package (L'Ecuyer & Simard)
http://www.iro.umontreal.ca/~simardr/ssj/indexe.html
"""
import sys import sys
import numpy import numpy
...@@ -59,7 +66,7 @@ def ff_2p134(rstate): ...@@ -59,7 +66,7 @@ def ff_2p134(rstate):
def ff_2p72(rstate): def ff_2p72(rstate):
return multMatVect(rstate, A1p72, M1, A2p72, M2) return multMatVect(rstate, A1p72, M1, A2p72, M2)
def mrg_next_value(rstate): def mrg_next_value(rstate, new_rstate):
x11, x12, x13, x21, x22, x23 = rstate x11, x12, x13, x21, x22, x23 = rstate
assert type(x11) == numpy.int32 assert type(x11) == numpy.int32
...@@ -100,20 +107,28 @@ def mrg_next_value(rstate): ...@@ -100,20 +107,28 @@ def mrg_next_value(rstate):
x21 = y2; x21 = y2;
# Must never return either 0 or M1+1 # Must never return either 0 or M1+1
new_rstate = numpy.asarray([x11, x12, x13, x21, x22, x23]) new_rstate[...] = [x11, x12, x13, x21, x22, x23]
assert new_rstate.dtype == numpy.int32 assert new_rstate.dtype == numpy.int32
if (x11 <= x21): if (x11 <= x21):
return (x11 - x21 + M1) * NORM, new_rstate return (x11 - x21 + M1) * NORM
else: else:
return (x11 - x21) * NORM, new_rstate return (x11 - x21) * NORM
class mrg_uniform(Op): class mrg_uniform(Op):
def __init__(self, output_type): def __init__(self, output_type, inplace=False):
self.output_type = output_type self.output_type = output_type
self.inplace=inplace
def __eq__(self, other):
return type(self) == type(other) \
and self.output_type == other.output_type \
and self.inplace == other.inplace
def __hash__(self):
return hash(type(self)) ^ hash(self.output_type) ^ hash(self.inplace)
@classmethod @classmethod
def apply(cls, rstate, ndim, dtype, size, low, high): def new(cls, rstate, ndim, dtype, size, low, high):
v_size = as_tensor_variable(size) v_size = as_tensor_variable(size)
if ndim is None: if ndim is None:
ndim = get_vector_length(v_size) ndim = get_vector_length(v_size)
...@@ -126,21 +141,24 @@ class mrg_uniform(Op): ...@@ -126,21 +141,24 @@ class mrg_uniform(Op):
[rstate.type(), self.output_type()]) [rstate.type(), self.output_type()])
def perform(self, node, (rstate, size, low, high), (o_rstate, o_sample)): def perform(self, node, (rstate, size, low, high), (o_rstate, o_sample)):
n_elements = 1 n_elements = 1
rstate = rstate.copy() if not self.inplace:
rstate = rstate.copy()
for s in size: for s in size:
n_elements *= s n_elements *= s
n_streams,_ = rstate.shape
rval = numpy.zeros(n_elements, dtype=self.output_type.dtype) rval = numpy.zeros(n_elements, dtype=self.output_type.dtype)
for i in xrange(n_elements): for i in xrange(n_elements):
sample, rstate = mrg_next_value(rstate) sample = mrg_next_value(rstate[i%n_streams], rstate[i%n_streams])
rval[i] = sample rval[i] = sample
o_rstate[0] = rstate.copy() o_rstate[0] = rstate.copy()
o_sample[0] = rval.reshape(size) o_sample[0] = rval.reshape(size)
class MRG_RandomStreams(raw_random.RandomStreamsBase): class MRG_RandomStreams(object):
"""Module component with similar interface to numpy.random (numpy.random.RandomState)""" """Module component with similar interface to numpy.random (numpy.random.RandomState)"""
def __init__(self, seed=None): def __init__(self, seed=None):
...@@ -154,33 +172,49 @@ class MRG_RandomStreams(raw_random.RandomStreamsBase): ...@@ -154,33 +172,49 @@ class MRG_RandomStreams(raw_random.RandomStreamsBase):
self.rstate = numpy.asarray([12345]*6, dtype='int32') self.rstate = numpy.asarray([12345]*6, dtype='int32')
def inc_rstate(self): def inc_rstate(self):
"""Skip self.rstate forward to the next stream point""" """Update self.rstate to be skipped 2^134 steps forward to the next stream start"""
print >> sys.stderr, "TODO: skip forward the state" self.rstate = ff_2p134(self.rstate)
assert self.rstate.dtype == numpy.int32
def gen(self, op, *args, **kwargs):
"""Create a new random stream in this container.
:param op: one of the functions in numpy.raw_random
:param args: interpreted by `op` def get_substream_rstates(self, n_streams, inc_rstate=True):
"""Initialize a matrix in which each row is a MRG stream state,
and they are spaced by 2**72 samples.
"""
assert n_streams < 2**72
assert n_streams > 0
rval = numpy.zeros((n_streams,6), dtype='int32')
rval[0] = self.rstate
for i in xrange(1, n_streams):
rval[i] = ff_2p72(rval[i-1])
if inc_rstate:
self.inc_rstate()
return rval
def n_streams(self, size):
r = 1
for s in size:
r *= s
return r
:param kwargs: interpreted by `op` def pretty_return(self, node_rstate, new_rstate, sample):
sample.rstate = node_rstate
sample.update = (node_rstate, new_rstate)
node_rstate.default_update = new_rstate
return sample
:returns: The symbolic random draw part of op()'s return value. This function stores
the updated RandomStateType Variable for use at `build` time.
:rtype: TensorVariable def uniform(self, size=None, low=0.0, high=1.0, ndim=None, dtype=config.floatX):
""" """
ndim = kwargs.pop('ndim', None) Sample a tensor of given size whose element from a uniform
dtype = kwargs.pop('dtype', None) distribution between low and high.
assert dtype is not None
node_rstate = shared(self.rstate.copy())
new_r, sample = globals()['mrg_'+op.__name__].apply(node_rstate, ndim, dtype, *args, **kwargs)
sample.rstate = node_rstate
sample.update = (node_rstate, new_r)
node_rstate.default_update = new_r
return sample
If the size argument is ambiguous on the number of dimensions,
ndim may be a plain integer to supplement the missing
information.
"""
node_rstate = shared(self.get_substream_rstates(self.n_streams(size)))
return self.pretty_return(node_rstate,
*mrg_uniform.new(node_rstate, ndim, dtype, size, low, high))
# #
# #
...@@ -197,7 +231,9 @@ def test_rng0(): ...@@ -197,7 +231,9 @@ def test_rng0():
f = theano.function([], u) f = theano.function([], u)
print 'random sample?', f() print 'random?', f()
print 'random sample?', f() print 'random?', f()
print 'random sample?', f()
print 'random sample?', f() l = [f() for i in xrange(1000)]
print 'mean', numpy.mean(l), numpy.std(l) / numpy.sqrt(1000)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论