提交 6a159276 authored 作者: bergstra@tikuanyin's avatar bergstra@tikuanyin

added a randomstreams for shared variables

上级 ee2f8a24
"""Define RandomStreams, providing random number variables for Theano graphs."""
__docformat__ = "restructuredtext en"
import sys
import numpy
from ...gof import Container
from ...tensor import raw_random
from sharedvalue import SharedVariable, shared_constructor, shared
class RandomStateSharedVariable(SharedVariable):
pass
@shared_constructor
def randomstate_constructor(value, name=None, strict=False):
"""SharedVariable Constructor for RandomState"""
if not isinstance(value, numpy.random.RandomState):
raise TypeError
return RandomStateSharedVariable(
type=raw_random.random_state_type,
value=value,
name=name,
strict=strict)
class RandomStreams(object):
"""Module component with similar interface to numpy.random (numpy.random.RandomState)"""
random_state_variables = []
"""A list of pairs of the form (input_r, output_r). This will be over-ridden by the module
instance to contain stream generators.
"""
default_instance_seed = None
"""Instance variable should take None or integer value. Used to seed the random number
generator that provides seeds for member streams"""
gen_seedgen = None
"""numpy.RandomState instance that gen() uses to seed new streams.
"""
def updates(self):
return list(self.random_state_variables)
def __init__(self, seed=None):
"""
:type seed: None or int
:param seed: a default seed to initialize the RandomState instances after build. See
`RandomStreamsInstance.__init__` for more details.
"""
super(RandomStreams, self).__init__()
self.random_state_variables = []
self.default_instance_seed = seed
self.gen_seedgen = numpy.random.RandomState(seed)
def seed(self, seed=None):
"""Re-initialize each random stream
:param seed: each random stream will be assigned a unique state that depends
deterministically on this value.
:type seed: None or integer in range 0 to 2**30
:rtype: None
"""
seed = self.default_instance_seed if seed is None else seed
seedgen = numpy.random.RandomState(seed)
for old_r, new_r in self.random_state_variables:
old_r_seed = seedgen.randint(2**30)
old_r.value = numpy.random.RandomState(int(old_r_seed))
def __getitem__(self, item):
"""Retrieve the numpy RandomState instance associated with a particular stream
:param item: a variable of type RandomStateType, associated with this RandomStream
:rtype: numpy RandomState (or None, before initialize)
:note: This is kept for compatibility with `tensor.randomstreams.RandomStreams`. The
simpler syntax ``item.rng.value`` is also valid.
"""
return item.value
def __setitem__(self, item, val):
"""Set the numpy RandomState instance associated with a particular stream
:param item: a variable of type RandomStateType, associated with this RandomStream
:param val: the new value
:type val: numpy RandomState
:rtype: None
:note: This is kept for compatibility with `tensor.randomstreams.RandomStreams`. The
simpler syntax ``item.rng.value = val`` is also valid.
"""
item.value = val
def gen(self, op, *args, **kwargs):
"""Create a new random stream in this container.
:param op: a RandomFunction instance to
:param args: interpreted by `op`
:param kwargs: interpreted by `op`
:returns: The symbolic random draw part of op()'s return value. This function stores
the updated RandomStateType Variable for use at `build` time.
:rtype: TensorVariable
"""
seed = int(self.gen_seedgen.randint(2**30))
random_state_variable = shared(numpy.random.RandomState(seed))
new_r, out = op(random_state_variable, *args, **kwargs)
out.rng = random_state_variable
self.random_state_variables.append((random_state_variable, new_r))
return out
def binomial(self, *args, **kwargs):
"""Return a symbolic binomial sample
This is a shortcut for a call to `self.gen`
"""
return self.gen(raw_random.binomial, *args, **kwargs)
def uniform(self, *args, **kwargs):
"""Return a symbolic uniform sample
This is a shortcut for a call to `self.gen`
"""
return self.gen(raw_random.uniform, *args, **kwargs)
def normal(self, *args, **kwargs):
"""Return a symbolic normal sample
This is a shortcut for a call to `self.gen`
"""
return self.gen(raw_random.normal, *args, **kwargs)
def random_integers(self, *args, **kwargs):
"""Return a symbolic random integer sample
This is a shortcut for a call to `self.gen`
"""
return self.gen(raw_random.random_integers, *args, **kwargs)
__docformat__ = "restructuredtext en"
import sys
import unittest
import numpy
from theano.tensor import raw_random
from theano.compile.sandbox.shared_randomstreams import RandomStreams
from theano.compile.sandbox.pfunc import pfunc
from theano import tensor
from theano import compile, gof
class T_RandomStreams(unittest.TestCase):
def test_basics(self):
random = RandomStreams(234)
fn = pfunc([], random.uniform((2,2)), updates=random.updates())
gn = pfunc([], random.normal((2,2)), updates=random.updates())
fn_val0 = fn()
fn_val1 = fn()
gn_val0 = gn()
rng_seed = numpy.random.RandomState(234).randint(2**30)
rng = numpy.random.RandomState(int(rng_seed)) #int() is for 32bit
#print fn_val0
numpy_val0 = rng.uniform(size=(2,2))
numpy_val1 = rng.uniform(size=(2,2))
#print numpy_val0
assert numpy.all(fn_val0 == numpy_val0)
print fn_val0
print numpy_val0
print fn_val1
print numpy_val1
assert numpy.all(fn_val1 == numpy_val1)
def test_seed_fn(self):
random = RandomStreams(234)
fn = pfunc([], random.uniform((2,2)), updates=random.updates())
random.seed(888)
fn_val0 = fn()
fn_val1 = fn()
rng_seed = numpy.random.RandomState(888).randint(2**30)
rng = numpy.random.RandomState(int(rng_seed)) #int() is for 32bit
#print fn_val0
numpy_val0 = rng.uniform(size=(2,2))
numpy_val1 = rng.uniform(size=(2,2))
#print numpy_val0
assert numpy.all(fn_val0 == numpy_val0)
assert numpy.all(fn_val1 == numpy_val1)
def test_getitem(self):
random = RandomStreams(234)
out = random.uniform((2,2))
fn = pfunc([], out, updates=random.updates())
random.seed(888)
rng = numpy.random.RandomState()
rng.set_state(random[out.rng].get_state()) #tests getitem
fn_val0 = fn()
fn_val1 = fn()
numpy_val0 = rng.uniform(size=(2,2))
numpy_val1 = rng.uniform(size=(2,2))
assert numpy.all(fn_val0 == numpy_val0)
assert numpy.all(fn_val1 == numpy_val1)
def test_setitem(self):
random = RandomStreams(234)
out = random.uniform((2,2))
fn = pfunc([], out, updates=random.updates())
random.seed(888)
rng = numpy.random.RandomState(823874)
random[out.rng] = numpy.random.RandomState(823874)
fn_val0 = fn()
fn_val1 = fn()
numpy_val0 = rng.uniform(size=(2,2))
numpy_val1 = rng.uniform(size=(2,2))
assert numpy.all(fn_val0 == numpy_val0)
assert numpy.all(fn_val1 == numpy_val1)
if __name__ == '__main__':
from theano.tests import main
main("test_randomstreams")
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论