提交 16acde8e authored 作者: serdyuk's avatar serdyuk

Added test for MRG_RandomStream dump

上级 9ea27947
...@@ -156,7 +156,6 @@ class PersistentCudaNdarrayID(PersistentNdarrayID): ...@@ -156,7 +156,6 @@ class PersistentCudaNdarrayID(PersistentNdarrayID):
super(PersistentCudaNdarrayID, self).__init__(zip_file) super(PersistentCudaNdarrayID, self).__init__(zip_file)
def __call__(self, obj): def __call__(self, obj):
print 'try cuda'
if (cuda_ndarray is not None and if (cuda_ndarray is not None and
type(obj) is cuda_ndarray.cuda_ndarray.CudaNdarray): type(obj) is cuda_ndarray.cuda_ndarray.CudaNdarray):
print 'cuda' print 'cuda'
......
...@@ -5,6 +5,7 @@ import theano.sandbox.cuda as cuda_ndarray ...@@ -5,6 +5,7 @@ import theano.sandbox.cuda as cuda_ndarray
from theano.sandbox.cuda.type import CudaNdarrayType from theano.sandbox.cuda.type import CudaNdarrayType
from theano.sandbox.cuda.var import CudaNdarraySharedVariable from theano.sandbox.cuda.var import CudaNdarraySharedVariable
from theano.sandbox.rng_mrg import MRG_RandomStreams
from theano.misc.pkl_utils import dump, load from theano.misc.pkl_utils import dump, load
if not cuda_ndarray.cuda_enabled: if not cuda_ndarray.cuda_enabled:
...@@ -23,3 +24,16 @@ def test_dump_load(): ...@@ -23,3 +24,16 @@ def test_dump_load():
assert x.name == 'x' assert x.name == 'x'
assert_allclose(x.get_value(), [[1]]) assert_allclose(x.get_value(), [[1]])
def test_dump_load_mrg():
rng = MRG_RandomStreams(use_cuda=True)
with open('test', 'w') as f:
dump(rng, f)
with open('test', 'r') as f:
rng = load(f)
assert type(rng) == MRG_RandomStreams
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论