提交 eab932a1 authored 作者: james@X40's avatar james@X40

disabling RModule tests

上级 d7aebfad
...@@ -43,7 +43,7 @@ from compile import \ ...@@ -43,7 +43,7 @@ from compile import \
Mode, \ Mode, \
predefined_modes, predefined_linkers, predefined_optimizers, \ predefined_modes, predefined_linkers, predefined_optimizers, \
FunctionMaker, function, OpFromGraph, \ FunctionMaker, function, OpFromGraph, \
Component, External, Member, KitComponent, Method, \ Component, External, Member, Method, \
Composite, ComponentList, ComponentDict, Module, \ Composite, ComponentList, ComponentDict, Module, \
ProfileMode ProfileMode
......
...@@ -4,127 +4,14 @@ import sys ...@@ -4,127 +4,14 @@ import sys
import unittest import unittest
import numpy as N import numpy as N
from theano.tensor.rmodule import * from theano.tensor.deprecated.rmodule import *
from theano import tensor from theano import tensor
from theano import compile, gof from theano import compile, gof
class T_RandomStreams(unittest.TestCase): if 0:
def test_basics(self): class T_test_module(unittest.TestCase):
m = Module()
m.random = RandomStreams(234)
m.fn = Method([], m.random.uniform((2,2)))
m.gn = Method([], m.random.normal((2,2)))
made = m.make()
made.random.initialize()
fn_val0 = made.fn()
fn_val1 = made.fn()
gn_val0 = made.gn()
rng_seed = numpy.random.RandomState(234).randint(2**30)
rng = numpy.random.RandomState(int(rng_seed)) #int() is for 32bit
#print fn_val0
numpy_val0 = rng.uniform(size=(2,2))
numpy_val1 = rng.uniform(size=(2,2))
#print numpy_val0
assert numpy.all(fn_val0 == numpy_val0)
assert numpy.all(fn_val1 == numpy_val1)
def test_seed_in_initialize(self):
m = Module()
m.random = RandomStreams(234)
m.fn = Method([], m.random.uniform((2,2)))
made = m.make()
made.random.initialize(seed=888)
fn_val0 = made.fn()
fn_val1 = made.fn()
rng_seed = numpy.random.RandomState(888).randint(2**30)
rng = numpy.random.RandomState(int(rng_seed)) #int() is for 32bit
#print fn_val0
numpy_val0 = rng.uniform(size=(2,2))
numpy_val1 = rng.uniform(size=(2,2))
#print numpy_val0
assert numpy.all(fn_val0 == numpy_val0)
assert numpy.all(fn_val1 == numpy_val1)
def test_seed_fn(self):
m = Module()
m.random = RandomStreams(234)
m.fn = Method([], m.random.uniform((2,2)))
made = m.make()
made.random.initialize(seed=789)
made.random.seed(888)
fn_val0 = made.fn()
fn_val1 = made.fn()
rng_seed = numpy.random.RandomState(888).randint(2**30)
rng = numpy.random.RandomState(int(rng_seed)) #int() is for 32bit
#print fn_val0
numpy_val0 = rng.uniform(size=(2,2))
numpy_val1 = rng.uniform(size=(2,2))
#print numpy_val0
assert numpy.all(fn_val0 == numpy_val0)
assert numpy.all(fn_val1 == numpy_val1)
def test_getitem(self):
m = Module()
m.random = RandomStreams(234)
out = m.random.uniform((2,2))
m.fn = Method([], out)
made = m.make()
made.random.initialize(seed=789)
made.random.seed(888)
rng = numpy.random.RandomState()
rng.set_state(made.random[out.rng].get_state())
fn_val0 = made.fn()
fn_val1 = made.fn()
numpy_val0 = rng.uniform(size=(2,2))
numpy_val1 = rng.uniform(size=(2,2))
assert numpy.all(fn_val0 == numpy_val0)
assert numpy.all(fn_val1 == numpy_val1)
def test_setitem(self):
m = Module()
m.random = RandomStreams(234)
out = m.random.uniform((2,2))
m.fn = Method([], out)
made = m.make()
made.random.initialize(seed=789)
made.random.seed(888)
rng = numpy.random.RandomState(823874)
made.random[out.rng] = numpy.random.RandomState(823874)
fn_val0 = made.fn()
fn_val1 = made.fn()
numpy_val0 = rng.uniform(size=(2,2))
numpy_val1 = rng.uniform(size=(2,2))
assert numpy.all(fn_val0 == numpy_val0)
assert numpy.all(fn_val1 == numpy_val1)
class T_test_module(unittest.TestCase):
def test_state_propagation(self): def test_state_propagation(self):
if 1: if 1:
print >> sys.stderr, "RModule deprecated" print >> sys.stderr, "RModule deprecated"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论