提交 9c57a03c authored 作者: texot's avatar texot

Fix PersistentNdarrayLoad problem restoring same object twice

上级 1efb1539
...@@ -76,7 +76,7 @@ class SharedVariable(Variable): ...@@ -76,7 +76,7 @@ class SharedVariable(Variable):
raise TypeError('value and strict are ignored if you pass ' raise TypeError('value and strict are ignored if you pass '
'a container here') 'a container here')
else: else:
if container is not None: if value is not None:
raise TypeError('Error to specify both value and container') raise TypeError('Error to specify both value and container')
self.container = Container( self.container = Container(
self, self,
......
...@@ -275,10 +275,14 @@ class PersistentNdarrayLoad(object): ...@@ -275,10 +275,14 @@ class PersistentNdarrayLoad(object):
""" """
def __init__(self, zip_file): def __init__(self, zip_file):
self.zip_file = zip_file self.zip_file = zip_file
self.cache = {}
def __call__(self, persid): def __call__(self, persid):
array_type, name = persid.split('.') array_type, name = persid.split('.')
if name in self.cache:
return self.cache[name]
ret = None
array = numpy.lib.format.read_array(self.zip_file.open(name)) array = numpy.lib.format.read_array(self.zip_file.open(name))
if array_type == 'cuda_ndarray': if array_type == 'cuda_ndarray':
if config.experimental.unpickle_gpu_on_cpu: if config.experimental.unpickle_gpu_on_cpu:
...@@ -286,14 +290,16 @@ class PersistentNdarrayLoad(object): ...@@ -286,14 +290,16 @@ class PersistentNdarrayLoad(object):
warnings.warn("config.experimental.unpickle_gpu_on_cpu is set " warnings.warn("config.experimental.unpickle_gpu_on_cpu is set "
"to True. Unpickling CudaNdarray as " "to True. Unpickling CudaNdarray as "
"numpy.ndarray") "numpy.ndarray")
return array ret = array
elif cuda_ndarray: elif cuda_ndarray:
return cuda_ndarray.cuda_ndarray.CudaNdarray(array) ret = cuda_ndarray.cuda_ndarray.CudaNdarray(array)
else: else:
raise ImportError("Cuda not found. Cannot unpickle " raise ImportError("Cuda not found. Cannot unpickle "
"CudaNdarray") "CudaNdarray")
else: else:
return array ret = array
self.cache[name] = ret
return ret
def dump(obj, file_handler, protocol=DEFAULT_PROTOCOL, def dump(obj, file_handler, protocol=DEFAULT_PROTOCOL,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论