提交 7d72fa72 authored 作者: James Bergstra's avatar James Bergstra

some docs for tensor_random

上级 687bd55b
...@@ -30,8 +30,8 @@ class T_Random(unittest.TestCase): ...@@ -30,8 +30,8 @@ class T_Random(unittest.TestCase):
def test1(self): def test1(self):
rng = RandomState(12345) rng = RandomState(12345)
f0 = compile.function([], [rng.uniform((3,))]) f0 = compile.function([], [rng.gen('uniform', (3,))])
f1 = compile.function([], [rng.uniform((3,))]) f1 = compile.function([], [rng.gen('uniform', (3,))])
v0, v1 = f0(), f1() v0, v1 = f0(), f1()
...@@ -52,7 +52,7 @@ class T_Random(unittest.TestCase): ...@@ -52,7 +52,7 @@ class T_Random(unittest.TestCase):
def test3(self): def test3(self):
rng = RandomState(12345) rng = RandomState(12345)
template = tensor.fmatrix() template = tensor.fmatrix()
f0 = compile.function([template], [rng.uniform_like(template)]) f0 = compile.function([template], [rng.gen_like('uniform', template)])
v0 = f0(numpy.zeros((2,3))) v0 = f0(numpy.zeros((2,3)))
self.failUnless(str(v0[1,2]).startswith('0.595544')) self.failUnless(str(v0[1,2]).startswith('0.595544'))
......
"""Random number generation for Theano graphs."""
import gof import gof
import tensor import tensor
import numpy import numpy
import functools import functools
# the optional argument implements a closure
# the cache is used so that we we can be sure that
# id(self.fn) in NumpyGenerator identifies
# the computation performed.
def fn_from_dist(dist, cache={}):
if callable(dist):
return dist
if isinstance(dist, str):
return getattr(numpy.random.RandomState, dist)
name, kwargs = dist
key = (name, tuple(kwargs.items()))
if key not in cache:
fn = getattr(numpy.random.RandomState, name)
fn = functools.partial(fn, **kwargs)
cache[key] = fn
return cache[key]
class RandomState(object): class RandomState(object):
"""The Theano version of numpy.RandomState
This class generates a sequence of L{Op} instances via the gen() and
gen_like() methods.
@ivar seed: an integer which determines the initial state of the L{Op}
instances returned by gen(), gen_like()
@type seed: int
"""
@staticmethod
def _fn_from_dist(dist, cache={}):
"""Return a function from a distribution description
@param dist: identifier of a sampling distribution.
@type dist: callable or str or tuple(str, dict)
@param cache: The optional cache argument implements a closure, which ensures that
multiple requests for the same sampling function will get the same
sampling function. L{NumpyGenerator}.__hash__ depends on this.
@type cache: dict
"""
if callable(dist):
return dist
if isinstance(dist, str):
return getattr(numpy.random.RandomState, dist)
name, kwargs = dist
key = (name, tuple(kwargs.items()))
if key not in cache:
fn = getattr(numpy.random.RandomState, name)
fn = functools.partial(fn, **kwargs)
cache[key] = fn
return cache[key]
def __init__(self, seed): def __init__(self, seed):
self.seed = seed self.seed = seed
def uniform(self, shape, ndim=None):
return self.gen('uniform', shape, ndim)
def uniform_like(self, x):
return self.gen_like('uniform', x)
def gen(self, dist, shape=(), ndim=None): def gen(self, dist, shape=(), ndim=None):
"""
@param dist: identifier of a sampling distribution. See L{_fn_from_dist}.
@param shape: tuple
@return: A tensor of random numbers, with given shape.
@rtype: L{Result} (output of L{Apply} of L{NumpyGenerator} instance)
"""
self.seed += 1 self.seed += 1
fn = fn_from_dist(dist) fn = RandomState._fn_from_dist(dist)
if isinstance(shape, tuple): if isinstance(shape, tuple):
return NumpyGenerator(self.seed-1, len(shape),fn) (shape) return NumpyGenerator(self.seed-1, len(shape),fn) (shape)
return NumpyGenerator(self.seed - 1, ndim, fn)(shape) return NumpyGenerator(self.seed - 1, ndim, fn)(shape)
def gen_like(self, dist, x): def gen_like(self, dist, x):
"""
@param dist: identifier of a sampling distribution. See L{_fn_from_dist}.
@param x: L{Result} of type L{Tensor}
@return: A tensor of random numbers, with the same shape as x.
@rtype: L{Result} (output of L{Apply} of L{NumpyGenerator} instance)
"""
self.seed += 1 self.seed += 1
fn = fn_from_dist(dist) fn = RandomState._fn_from_dist(dist)
return NumpyGenerator(self.seed-1, x.type.ndim, fn)(tensor.shape(x)) return NumpyGenerator(self.seed-1, x.type.ndim, fn)(tensor.shape(x))
class NumpyGenerator(gof.op.Op): class NumpyGenerator(gof.op.Op):
"""Supply a sequence of random tensors of a given shape, from a given
distribution.
@param seed: initial state for instances of this L{Op}.
@type seed: anything that numpy.random.RandomState accepts.
@param ndim: the rank of random tensors produced by this op.
@type ndim: non-negative integer
@param fn: a sampling function
@type fn: a callable that can reply to fn(numpy.RandomState(), size=<tuple>)
"""
destroy_map = {0: [0]} destroy_map = {0: [0]}
def __init__(self, seed, ndim, fn, **kwargs): def __init__(self, seed, ndim, fn, **kwargs):
...@@ -51,6 +87,9 @@ class NumpyGenerator(gof.op.Op): ...@@ -51,6 +87,9 @@ class NumpyGenerator(gof.op.Op):
self.seed = seed self.seed = seed
self.ndim = ndim self.ndim = ndim
self.fn = fn self.fn = fn
assert numpy.random.RandomState(seed) #test the seed
assert 'int' in str(type(ndim))
assert callable(self.fn)
def __eq__(self, other): def __eq__(self, other):
return (type(self) is type(other))\ return (type(self) is type(other))\
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论