提交 d2d06651 authored 作者: James Bergstra's avatar James Bergstra

factored raw_random.py into raw_random.py and rmodule.py. also separated test files

上级 fe6e6a1c
"""Define the tensor toplevel"""
__docformat__ = "restructuredtext en"
from basic import *
import opt
import blas
import raw_random
from raw_random import \
import raw_random, rmodule
from rmodule import \
RandomKit, RModule
random = RandomKit('random')
"""Imitate the numpy.random symbol with a tensor.random one"""
from elemwise import \
DimShuffle, Elemwise, CAReduce
......
"""Random number generation for Theano graphs."""
from .. import gof
import basic as tensor
"""Define random number Type (`RandomStateType`) and Op (`RandomFunction`)."""
__docformat__ = "restructuredtext en"
import sys
from copy import copy
import numpy
import functools
#local imports
import basic as tensor
import opt
from .. import compile
from ..compile import SymbolicInputKit, SymbolicInput
from copy import copy
import sys
RS = numpy.random.RandomState
from .. import gof
from ..compile import optdb
class RandomStateType(gof.Type):
"""A Type wrapper for numpy.RandomState
......@@ -48,7 +45,6 @@ class RandomStateType(gof.Type):
random_state_type = RandomStateType()
class RandomFunction(gof.Op):
"""Op that draws random numbers from a numpy.RandomState object
......@@ -89,7 +85,7 @@ class RandomFunction(gof.Op):
def __setstate__(self, state):
self.state = state
fn, outtype, args, kwargs = state
self.fn = getattr(RS, fn) if isinstance(fn, str) else fn
self.fn = getattr(numpy.random.RandomState, fn) if isinstance(fn, str) else fn
self.outtype = outtype
self.args = tuple(tensor.as_tensor(arg) for arg in args)
self.inplace = kwargs.pop('inplace', False)
......@@ -270,88 +266,6 @@ def random_make_inplace(node):
return RandomFunction(op.fn, op.outtype, *op.args, **dict(inplace=True)).make_node(*node.inputs).outputs
return False
compile.optdb.register('random_make_inplace', opt.in2out(random_make_inplace, ignore_newtrees=True), 99, 'fast_run', 'inplace')
optdb.register('random_make_inplace', opt.in2out(random_make_inplace, ignore_newtrees=True), 99, 'fast_run', 'inplace')
import sys
from functools import partial
from collections import deque
class RandomKit(SymbolicInputKit):
def __init__(self, name, value = None):
super(RandomKit, self).__init__(name)
self.value = value
def gen(self, op, *args, **kwargs):
r = gof.generic()
new_r, out = op(r, *args, **kwargs)
self.add_input(SymbolicInput(r, update = new_r))
out.rng = r
out.auto = self
return out
def distribute(self, value, indices, containers):
rg = partial(numpy.random.RandomState(int(value)).randint, 2**30)
elems = deque(zip(indices, containers))
i = 0
while elems:
index, container = elems.popleft()
while i <= index:
curr = rg()
i += 1
rs = numpy.random.RandomState(int(curr))
container.data = rs
def binomial(self, *args, **kwargs):
return self.gen(binomial, *args, **kwargs)
def uniform(self, *args, **kwargs):
return self.gen(uniform, *args, **kwargs)
def normal(self, *args, **kwargs):
return self.gen(normal, *args, **kwargs)
def random_integers(self, *args, **kwargs):
return self.gen(random_integers, *args, **kwargs)
rk = RandomKit('rk', 0xBAD5EED)
class RModule(compile.Module):
def __init__(self, components = {}, **kwcomponents):
super(RModule, self).__init__(components, **kwcomponents)
self.random = RandomKit('rkit')
self._rkit = compile.KitComponent(self.random)
def __wrapper__(self, x):
x = compile.module.wrap(x)
if isinstance(x, compile.Method):
x.kits += [self.random]
return x
def _instance_seed(self, inst, seed, recursive = True):
seedgen = numpy.random.RandomState(seed)
if recursive:
#Here, we recurse through all the components (inst2) contained in (inst)
#and seeds each subcomponent that is an RModule
for path, c in self.flat_components_map(True):
if isinstance(c, RModule):
inst2 = inst
for name in path:
inst2 = inst2[name]
# A Kit (c._rkit.kit) contains a list of io.SymbolicIn instances
# and the distribute method takes a value (seed), a list of indices
# and a list of corresponding gof.Container instances. In this
# situation it will reseed all the rngs using the containers
# associated to them.
c._rkit.kit.distribute(seedgen.random_integers(2**30),
xrange(len(inst2._rkit)), inst2._rkit)
else:
self._rkit.kit.distribute(seedgen.random_integers(2**30), xrange(len(inst._rkit)), inst._rkit)
## TODO: REDO THESE TESTS
__docformat__ = "restructuredtext en"
import sys
import unittest
import numpy as N
......@@ -62,7 +62,7 @@ class T_random_function(unittest.TestCase):
assert not numpy.allclose(f4, f4b)
def test_inplace_optimization(self):
"""Test that arguments to RandomFunction are honored"""
"""Test that FAST_RUN includes the random_make_inplace optimization"""
#inplace = False
rf2 = RandomFunction(numpy.random.RandomState.uniform, tensor.dvector, -2.0, 2.0)
rng_R = random_state_type()
......@@ -90,62 +90,6 @@ class T_random_function(unittest.TestCase):
assert not numpy.allclose(val0, val1)
class T_test_module(unittest.TestCase):
def test_state_propagation(self):
x = tensor.vector()
rk = RandomKit('rk', 1000)
f = compile.function([x, (rk, [gof.Container(r = gof.generic, storage = [123], name='bla')])], rk.binomial(tensor.shape(x)))
print "RK", rk.value
f['rk'] = 9873456
print "RK", rk.value
rvals = [f([1,2,3,4,6, 7, 8]) for i in xrange(5)]
print rvals
for i in xrange(5-1):
for j in xrange(i+1, 5):
assert not N.all(rvals[i] == rvals[j])
def test_B(self):
"""Test that random numbers change from call to call!
Also, make sure that the seeding strategy doesn't change without failing a test.
Random numbers can't be too random or experiments aren't repeatable. Email theano-dev
before updating the `rvals` in this test.
"""
class B(RModule):
def __init__(self):
super(B, self).__init__()
self.x = compile.Member(tensor.dvector())
self.r = self.random.uniform(tensor.shape(self.x))
self.f = compile.Method([self.x], self.r)
class E(RModule):
def __init__(self):
super(E, self).__init__()
self.b = B()
self.f = compile.Method([self.b.x], self.b.r)
b = E()
m = b.make()
m.seed(1000)
#print m.f(N.ones(5))
#print m.f(N.ones(5))
#print m.f(N.ones(5))
rvals = ["0.74802375876 0.872308123517 0.294830748897 0.803123780003 0.6321109955",
"0.00168744844365 0.278638315678 0.725436793755 0.7788480779 0.629885140994",
"0.545561221664 0.0992011009108 0.847112593242 0.188015424144 0.158046201298",
"0.054382248842 0.563459168529 0.192757276954 0.360455221883 0.174805216702",
"0.961942907777 0.49657319422 0.0316111492826 0.0915054717012 0.195877184515"]
for i in xrange(5):
s = " ".join([str(n) for n in m.f(N.ones(5))])
print s
assert s == rvals[i]
if __name__ == '__main__':
from theano.tests import main
main("test_raw_random")
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论