提交 1de3484e authored 作者: James Bergstra's avatar James Bergstra

removed random_function from raw_random

上级 019b7039
......@@ -5,6 +5,7 @@ import numpy as N
from theano.tests import unittest_tools
from theano.tensor.raw_random import *
from theano.tensor import raw_random
from theano import tensor
......@@ -12,7 +13,7 @@ from theano import compile, gof
class T_random_function(unittest.TestCase):
def test_basic_usage(self):
rf = RandomFunction(numpy.random.RandomState.uniform, tensor.dvector, -2.0, 2.0)
rf = RandomFunction(numpy.random.RandomState.uniform, tensor.dvector)
assert not rf.inplace
assert getattr(rf, 'destroy_map', {}) == {}
......@@ -32,23 +33,21 @@ class T_random_function(unittest.TestCase):
assert numpy.all(f_0 == f_1)
def test_inplace_norun(self):
rf = RandomFunction(numpy.random.RandomState.uniform, tensor.dvector, -2.0, 2.0,
inplace=True)
rf = RandomFunction(numpy.random.RandomState.uniform, tensor.dvector, inplace=True)
assert rf.inplace
assert getattr(rf, 'destroy_map', {}) != {}
def test_args(self):
"""Test that arguments to RandomFunction are honored"""
rf2 = RandomFunction(numpy.random.RandomState.uniform, tensor.dvector, -2.0, 2.0)
rf4 = RandomFunction(numpy.random.RandomState.uniform, tensor.dvector, -4.0, 4.0,
inplace=True)
rf2 = RandomFunction(numpy.random.RandomState.uniform, tensor.dvector)
rf4 = RandomFunction(numpy.random.RandomState.uniform, tensor.dvector, inplace=True)
rng_R = random_state_type()
# use make_node to override some of the self.args
post_r2, out2 = rf2(rng_R, (4,))
post_r2_4, out2_4 = rf2(rng_R, (4,), -4.0)
post_r2, out2 = rf2(rng_R, (4,), -2, 2)
post_r2_4, out2_4 = rf2(rng_R, (4,), -4.0, 2)
post_r2_4_4, out2_4_4 = rf2(rng_R, (4,), -4.0, 4.0)
post_r4, out4 = rf4(rng_R, (4,))
post_r4, out4 = rf4(rng_R, (4,), -4, 4)
f = compile.function(
[compile.In(rng_R, value=numpy.random.RandomState(55), update=post_r4, mutable=True)],
......@@ -65,7 +64,7 @@ class T_random_function(unittest.TestCase):
def test_inplace_optimization(self):
"""Test that FAST_RUN includes the random_make_inplace optimization"""
#inplace = False
rf2 = RandomFunction(numpy.random.RandomState.uniform, tensor.dvector, -2.0, 2.0)
rf2 = RandomFunction(numpy.random.RandomState.uniform, tensor.dvector)
rng_R = random_state_type()
# use make_node to override some of the self.args
......@@ -92,19 +91,18 @@ class T_random_function(unittest.TestCase):
def test_random_function_ndim(self):
"""Test that random_function helper function accepts ndim as first argument"""
rf2 = random_function(numpy.random.RandomState.uniform, 'float64', -2.0, 2.0)
rng_R = random_state_type()
# ndim is an optional argument indicating the length of the 'shape'
# ndim not specified, OK
post_out4, out4 = rf2(rng_R, (4,))
post_out4, out4 = uniform(rng_R, (4,))
# ndim specified, consistent with shape, OK
post_out1_4, out1_4 = rf2(rng_R, 1, (4,))
post_out2_4_4, out2_4_4= rf2(rng_R, 2, (4, 4))
post_out1_4, out1_4 = uniform(rng_R, (4,), ndim=1)
post_out2_4_4, out2_4_4= uniform(rng_R, (4, 4), ndim=2)
# ndim specified, but not compatible with shape
post_out2_4, out2_4 = rf2(rng_R, 2, (4,))
post_out2_4, out2_4 = uniform(rng_R, (4,), ndim=2)
f_ok = compile.function(
[compile.In(rng_R, value=numpy.random.RandomState(55), update=post_out2_4_4, mutable=True)],
......@@ -132,18 +130,31 @@ class T_random_function(unittest.TestCase):
# Specifying a different ndim_added will change the Op's output ndim,
# so numpy.uniform will produce a result of incorrect shape,
# and a ValueError should be raised.
uni_1 = random_function(numpy.random.RandomState.uniform, 'float64', -2.0, 2.0, ndim_added=1)
uni_0 = random_function(numpy.random.RandomState.uniform, 'float64', -2.0, 2.0, ndim_added=0)
uni_m1 = random_function(numpy.random.RandomState.uniform, 'float64', -2.0, 2.0, ndim_added=-1)
def ndim_added_deco(ndim_added):
def randomfunction(random_state, size=(), low=0.0, high=0.0, ndim=None):
ndim, size = raw_random._infer_ndim(ndim, size)
op = RandomFunction('uniform',
tensor.TensorType(dtype = 'float64', broadcastable =
(False,)*(ndim+ndim_added)),
ndim_added=ndim_added)
return op(random_state, size, low, high)
return randomfunction
uni_1 = ndim_added_deco(1)
uni_0 = ndim_added_deco(0)
uni_m1 = ndim_added_deco(-1)
#uni_1 = random_function(numpy.random.RandomState.uniform, 'float64', -2.0, 2.0, ndim_added=1)
#uni_0 = random_function(numpy.random.RandomState.uniform, 'float64', -2.0, 2.0, ndim_added=0)
#uni_m1 = random_function(numpy.random.RandomState.uniform, 'float64', -2.0, 2.0, ndim_added=-1)
rng_R = random_state_type()
p_uni11, uni11 = uni_1(rng_R, 1, (4,))
p_uni12, uni12 = uni_1(rng_R, 2, (3,4))
p_uni01, uni01 = uni_0(rng_R, 1, (4,))
p_uni02, uni02 = uni_0(rng_R, 2, (3,4))
p_unim11, unim11 = uni_m1(rng_R, 1, (4,))
p_unim12, unim12 = uni_m1(rng_R, 2, (3,4))
p_uni11, uni11 = uni_1(rng_R, size=(4,))
p_uni12, uni12 = uni_1(rng_R, size=(3,4))
p_uni01, uni01 = uni_0(rng_R, size=(4,))
p_uni02, uni02 = uni_0(rng_R, size=(3,4))
p_unim11, unim11 = uni_m1(rng_R, size=(4,))
p_unim12, unim12 = uni_m1(rng_R, size=(3,4))
self.assertEqual(uni11.ndim, 2)
self.assertEqual(uni12.ndim, 3)
......@@ -320,7 +331,8 @@ class T_random_function(unittest.TestCase):
def test_permutation(self):
"""Test that raw_random.permutation generates the same results as numpy."""
rng_R = random_state_type()
post_r, out = permutation(rng_R, (9,), 6)
post_r, out = permutation(rng_R, size=(9,), n=6)
print 'OUT NDIM', out.ndim
f = compile.function(
[compile.In(rng_R, value=numpy.random.RandomState(55), update=post_r, mutable=True)],
[out], accept_inplace=True)
......@@ -365,6 +377,24 @@ class T_random_function(unittest.TestCase):
self.assertTrue(val0.shape == (7,3,5))
self.assertTrue(val1.shape == (7,3,5))
def test_symbolic_shape(self):
rng_R = random_state_type()
shape = tensor.lvector()
post_r, out = uniform(rng_R, shape, ndim=2)
f = compile.function([rng_R, shape], out)
rng_state0 = numpy.random.RandomState(55)
assert f(rng_state0, [2,3]).shape == (2,3)
assert f(rng_state0, [4,8]).shape == (4,8)
self.assertRaises(ValueError, f, rng_state0, [4])
self.assertRaises(ValueError, f, rng_state0, [4,3,4,5])
if __name__ == '__main__':
from theano.tests import main
main("test_raw_random")
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论