提交 acd4da11 authored 作者: dima's avatar dima 提交者: serdyuk

Test names in a zip file produced by dump

上级 9cdbcdb4
import numpy
from numpy.testing import assert_allclose from numpy.testing import assert_allclose
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
import theano
import theano.sandbox.cuda as cuda_ndarray import theano.sandbox.cuda as cuda_ndarray
from theano.sandbox.cuda.type import CudaNdarrayType from theano.sandbox.cuda.type import CudaNdarrayType
...@@ -8,11 +10,11 @@ from theano.sandbox.cuda.var import CudaNdarraySharedVariable ...@@ -8,11 +10,11 @@ from theano.sandbox.cuda.var import CudaNdarraySharedVariable
from theano.sandbox.rng_mrg import MRG_RandomStreams from theano.sandbox.rng_mrg import MRG_RandomStreams
from theano.misc.pkl_utils import dump, load from theano.misc.pkl_utils import dump, load
if not cuda_ndarray.cuda_enabled:
raise SkipTest('Optional package cuda disabled')
def test_dump_load(): def test_dump_load():
if not cuda_ndarray.cuda_enabled:
raise SkipTest('Optional package cuda disabled')
x = CudaNdarraySharedVariable('x', CudaNdarrayType((1, 1), name='x'), x = CudaNdarraySharedVariable('x', CudaNdarrayType((1, 1), name='x'),
[[1]], False) [[1]], False)
...@@ -27,6 +29,9 @@ def test_dump_load(): ...@@ -27,6 +29,9 @@ def test_dump_load():
def test_dump_load_mrg(): def test_dump_load_mrg():
if not cuda_ndarray.cuda_enabled:
raise SkipTest('Optional package cuda disabled')
rng = MRG_RandomStreams(use_cuda=True) rng = MRG_RandomStreams(use_cuda=True)
with open('test', 'w') as f: with open('test', 'w') as f:
...@@ -36,3 +41,17 @@ def test_dump_load_mrg(): ...@@ -36,3 +41,17 @@ def test_dump_load_mrg():
rng = load(f) rng = load(f)
assert type(rng) == MRG_RandomStreams assert type(rng) == MRG_RandomStreams
def test_dump_zip_names():
foo_1 = theano.shared(0, name='foo')
foo_2 = theano.shared(1, name='foo')
with open('model.zip', 'w') as f:
dump((foo_1, foo_2, numpy.array(2)), f)
keys = numpy.load('model.zip').keys()
assert keys == ['foo', 'foo_2', 'array_0', 'pkl']
foo = numpy.load('model.zip')['foo']
assert foo == numpy.array(0)
with open('model.zip') as f:
foo_1, foo_2, array = load(f)
assert array == numpy.array(2)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论