提交 31450d80 authored 作者: Bart van Merriënboer's avatar Bart van Merriënboer

Merge pull request #2980 from Theano/pkl_fix

Fix error and remove unnecessary constructor
...@@ -153,9 +153,6 @@ class PersistentNdarrayID(object): ...@@ -153,9 +153,6 @@ class PersistentNdarrayID(object):
class PersistentCudaNdarrayID(PersistentNdarrayID): class PersistentCudaNdarrayID(PersistentNdarrayID):
def __init__(self, zip_file):
super(PersistentCudaNdarrayID, self).__init__(zip_file)
def __call__(self, obj): def __call__(self, obj):
if (cuda_ndarray is not None and if (cuda_ndarray is not None and
type(obj) is cuda_ndarray.cuda_ndarray.CudaNdarray): type(obj) is cuda_ndarray.cuda_ndarray.CudaNdarray):
...@@ -203,12 +200,12 @@ class PersistentSharedVariableID(PersistentCudaNdarrayID): ...@@ -203,12 +200,12 @@ class PersistentSharedVariableID(PersistentCudaNdarrayID):
if id(obj) in self.ndarray_names: if id(obj) in self.ndarray_names:
name = self.ndarray_names[id(obj)] name = self.ndarray_names[id(obj)]
count = self.name_counter[name] count = self.name_counter[name]
self.name_counter[name] += 1
if count: if count:
if not self.allow_duplicates: if not self.allow_duplicates:
raise ValueError("multiple shared variables with the name " raise ValueError("multiple shared variables with the name "
"`{0}` found".format(name)) "`{0}` found".format(name))
name = '{0}_{1}'.format(name, count + 1) name = '{0}_{1}'.format(name, count + 1)
self.name_counter[name] += 1
return name return name
return super(PersistentSharedVariableID, self)._resolve_name(obj) return super(PersistentSharedVariableID, self)._resolve_name(obj)
......
...@@ -43,12 +43,13 @@ def test_dump_load_mrg(): ...@@ -43,12 +43,13 @@ def test_dump_load_mrg():
def test_dump_zip_names(): def test_dump_zip_names():
foo_1 = theano.shared(0, name='foo') foo_1 = theano.shared(0, name='foo')
foo_2 = theano.shared(1, name='foo') foo_2 = theano.shared(1, name='foo')
foo_3 = theano.shared(2, name='foo')
with open('model.zip', 'wb') as f: with open('model.zip', 'wb') as f:
dump((foo_1, foo_2, numpy.array(2)), f) dump((foo_1, foo_2, foo_3, numpy.array(3)), f)
keys = numpy.load('model.zip').keys() keys = numpy.load('model.zip').keys()
assert keys == ['foo', 'foo_2', 'array_0', 'pkl'] assert keys == ['foo', 'foo_2', 'foo_3', 'array_0', 'pkl']
foo = numpy.load('model.zip')['foo'] foo_3 = numpy.load('model.zip')['foo_3']
assert foo == numpy.array(0) assert foo_3 == numpy.array(2)
with open('model.zip', 'rb') as f: with open('model.zip', 'rb') as f:
foo_1, foo_2, array = load(f) foo_1, foo_2, foo_3, array = load(f)
assert array == numpy.array(2) assert array == numpy.array(3)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论