提交 0d070957 authored 作者: James Bergstra's avatar James Bergstra

added RandomStateType, comments, unit tests to raw_random

上级 070ae21e
......@@ -13,31 +13,122 @@ import sys
RS = numpy.random.RandomState
class RandomStateType(gof.Type):
"""A Type wrapper for numpy.RandomState
The reason this exists (and `Generic` doesn't suffice) is that RandomState objects that
would appear to be equal do not compare equal with the '==' operator. This Type exists to
provide an equals function that is used by DebugMode.
"""
def __str__(self):
return 'RandomStateType'
def filter(self, data, strict=False):
if self.is_valid_value(data):
return data
else:
raise TypeError()
def is_valid_value(self, a):
return type(a) == numpy.random.RandomState
def values_eq(self, a, b):
sa = a.get_state()
sb = b.get_state()
for aa, bb in zip(sa, sb):
if isinstance(aa, numpy.ndarray):
if not numpy.all(aa == bb):
return False
else:
if not aa == bb:
return False
return True
random_state_type = RandomStateType()
class RandomFunction(gof.Op):
"""Op that draws random numbers from a numpy.RandomState object
"""
def __init__(self, fn, outtype, *args, **kwargs):
"""
fn: a random function with the same signature as functions in numpy.random.RandomState
outtype: the type of the output
args: a list of default arguments for the function
kwargs: if the 'inplace' key is there, its value will be used to determine if the op operates inplace or not
:param fn: a member function of numpy.RandomState
Technically, any function with a signature like the ones in numpy.random.RandomState
will do. This function must accept the shape (sometimes called size) of the output as
the last positional argument.
:type fn: string or function reference. A string will be interpreted as the name of a
member function of numpy.random.RandomState.
:param outtype: the theano Type of the output
:param args: a list of default arguments for the function
:param kwargs: if the 'inplace' key is there, its value will be used to determine if the op operates inplace or not
"""
self.__setstate__([fn, outtype, args, kwargs])
def __eq__(self, other):
return type(self) == type(other) \
and self.fn == other.fn\
and self.outtype == other.outtype\
and self.args == other.args\
and self.inplace == other.inplace
def __hash__(self):
return hash(type(self)) ^ hash(self.fn) \
^ hash(self.outtype) ^ hash(self.args) ^ hash(self.inplace)
def __getstate__(self):
return self.state
def __setstate__(self, state):
self.state = state
fn, outtype, args, kwargs = state
self.fn = getattr(RS, fn) if isinstance(fn, str) else fn
self.outtype = outtype
self.args = tuple(tensor.as_tensor(arg) for arg in args)
self.inplace = kwargs.pop('inplace', False)
if self.inplace:
self.destroy_map = {0: [0]}
def make_node(self, r, shape, *args):
"""
in: r -> RandomState (gof.generic),
shape -> lvector
args -> the arguments expected by the numpy function
out: r2 -> the new RandomState (gof.generic)
out -> the random numbers we generated
:param r: a numpy.RandomState instance, or a Result of Type RandomStateType that will
contain a RandomState instance.
:param shape: an lvector with the shape of the tensor output by this Op. At runtime,
the value associated with this lvector must have a length that matches the number of
dimensions promised by `self.outtype`.
:param args: the values associated with these results will be passed to the RandomState
function during perform as extra "*args"-style arguments. These should be castable to
results of Type Tensor.
:rtype: Apply
:return: Apply with two outputs. The first output is a gof.generic Result from which
to draw further random numbers. The second output is the outtype() instance holding
the random draw.
"""
args = map(tensor.as_tensor, args)
if shape == () or shape == []:
shape = tensor.lvector()
else:
shape = tensor.as_tensor(shape)
assert shape.type == tensor.lvector
shape = tensor.as_tensor(shape, ndim=1)
#print 'SHAPE TYPE', shape.type, tensor.lvector
assert shape.type.ndim == 1
assert shape.type.dtype == 'int64'
if not isinstance(r.type, RandomStateType):
print >> sys.stderr, 'WARNING: RandomState instances should be in RandomStateType'
if 0:
raise TypeError('r must be RandomStateType instance', r)
# assert shape.type == tensor.lvector doesn't work because we want to ignore the
# broadcastable vector
assert len(args) <= len(self.args)
args += (None,) * (len(self.args) - len(args))
inputs = []
......@@ -51,6 +142,7 @@ class RandomFunction(gof.Op):
def perform(self, node, inputs, (rout, out)):
r, shape, args = inputs[0], inputs[1], inputs[2:]
assert type(r) == numpy.random.RandomState
r_orig = r
assert self.outtype.ndim == len(shape)
if not self.inplace:
......@@ -64,31 +156,7 @@ class RandomFunction(gof.Op):
out[0] = rval
def grad(self, inputs, outputs):
return [None] * len(inputs)
def __eq__(self, other):
return type(self) == type(other) \
and self.fn == other.fn\
and self.outtype == other.outtype\
and self.args == other.args\
and self.inplace == other.inplace
def __hash__(self):
return hash(self.fn) ^ hash(self.outtype) ^ hash(self.args) ^ hash(self.inplace)
def __getstate__(self):
return self.state
def __setstate__(self, state):
self.state = state
fn, outtype, args, kwargs = state
self.fn = getattr(RS, fn) if isinstance(fn, str) else fn
self.outtype = outtype
self.args = tuple(tensor.as_tensor(arg) for arg in args)
self.inplace = kwargs.pop('inplace', False)
if self.inplace:
self.destroy_map = {0: [0]}
return [None for i in inputs]
......@@ -131,7 +199,7 @@ def random_function(fn, dtype, *rfargs, **rfkwargs):
ndim = tensor.get_vector_length(shape)
if ndim is None:
raise ValueError('Cannot infer the number of dimensions from the shape argument.')
# note: rf should probably be cached for future use
# note: rf could be cached for future use
rf = RandomFunction(fn, tensor.Tensor(dtype = dtype, broadcastable = (False,)*ndim), *rfargs, **rfkwargs)
return rf(r, shape, *args, **kwargs)
return f
......
## TODO: REDO THESE TESTS
import sys
import unittest
import numpy as N
......@@ -9,6 +9,39 @@ from theano import tensor
from theano import compile, gof
class T_random_function(unittest.TestCase):
def test_basic_usage(self):
rf = RandomFunction(numpy.random.RandomState.uniform, tensor.dvector, -2.0, 2.0)
assert not rf.inplace
assert getattr(rf, 'destroy_map', {}) == {}
rng_R = random_state_type()
print rng_R
post_r, out = rf(rng_R, (4,))
assert out.type == tensor.dvector
f = compile.function([rng_R], out)
rng_state0 = numpy.random.RandomState(55)
f_0 = f(rng_state0)
f_1 = f(rng_state0)
assert numpy.all(f_0 == f_1)
def test_inplace_norun(self):
rf = RandomFunction(numpy.random.RandomState.uniform, tensor.dvector, -2.0, 2.0,
inplace=True)
assert rf.inplace
assert getattr(rf, 'destroy_map', {}) != {}
def test_inplace_optimization(self):
print >> sys.stderr, "WARNING NOT IMPLEMENTED T_random_function.test_inplace_optimization"
class T_test_module(unittest.TestCase):
def test_state_propagation(self):
x = tensor.vector()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论