提交 792e615a authored 作者: dima's avatar dima 提交者: serdyuk

Fixed formatting

Binary mode for zip Binary mode in tests
上级 839043bb
...@@ -136,7 +136,7 @@ class PersistentNdarrayID(object): ...@@ -136,7 +136,7 @@ class PersistentNdarrayID(object):
def _resolve_name(self, obj): def _resolve_name(self, obj):
"""Determine the name the object should be saved under.""" """Determine the name the object should be saved under."""
name = 'array_{}'.format(self.count) name = 'array_{0}'.format(self.count)
self.count += 1 self.count += 1
return name return name
...@@ -147,7 +147,7 @@ class PersistentNdarrayID(object): ...@@ -147,7 +147,7 @@ class PersistentNdarrayID(object):
numpy.lib.format.write_array(f, obj) numpy.lib.format.write_array(f, obj)
name = self._resolve_name(obj) name = self._resolve_name(obj)
zipadd(write_array, self.zip_file, name) zipadd(write_array, self.zip_file, name)
self.seen[id(obj)] = 'ndarray.{}'.format(name) self.seen[id(obj)] = 'ndarray.{0}'.format(name)
return self.seen[id(obj)] return self.seen[id(obj)]
...@@ -164,7 +164,7 @@ class PersistentCudaNdarrayID(PersistentNdarrayID): ...@@ -164,7 +164,7 @@ class PersistentCudaNdarrayID(PersistentNdarrayID):
numpy.lib.format.write_array(f, numpy.asarray(obj)) numpy.lib.format.write_array(f, numpy.asarray(obj))
name = self._resolve_name(obj) name = self._resolve_name(obj)
zipadd(write_array, self.zip_file, name) zipadd(write_array, self.zip_file, name)
self.seen[id(obj)] = 'cuda_ndarray.{}'.format(name) self.seen[id(obj)] = 'cuda_ndarray.{0}'.format(name)
return self.seen[id(obj)] return self.seen[id(obj)]
super(PersistentCudaNdarrayID, self).__call__(obj) super(PersistentCudaNdarrayID, self).__call__(obj)
...@@ -206,8 +206,8 @@ class PersistentSharedVariableID(PersistentCudaNdarrayID): ...@@ -206,8 +206,8 @@ class PersistentSharedVariableID(PersistentCudaNdarrayID):
if count: if count:
if not self.allow_duplicates: if not self.allow_duplicates:
raise ValueError("multiple shared variables with the name " raise ValueError("multiple shared variables with the name "
"`{}` found".format(name)) "`{0}` found".format(name))
name = '{}_{}'.format(name, count + 1) name = '{0}_{1}'.format(name, count + 1)
self.name_counter[name] += 1 self.name_counter[name] += 1
return name return name
return super(PersistentSharedVariableID, self)._resolve_name(obj) return super(PersistentSharedVariableID, self)._resolve_name(obj)
...@@ -219,7 +219,7 @@ class PersistentSharedVariableID(PersistentCudaNdarrayID): ...@@ -219,7 +219,7 @@ class PersistentSharedVariableID(PersistentCudaNdarrayID):
ValueError("can't pickle shared variable with name `pkl`") ValueError("can't pickle shared variable with name `pkl`")
self.ndarray_names[id(obj.container.storage[0])] = obj.name self.ndarray_names[id(obj.container.storage[0])] = obj.name
elif not self.allow_unnamed: elif not self.allow_unnamed:
raise ValueError("unnamed shared variable, {}".format(obj)) raise ValueError("unnamed shared variable, {0}".format(obj))
return super(PersistentSharedVariableID, self).__call__(obj) return super(PersistentSharedVariableID, self).__call__(obj)
...@@ -296,7 +296,7 @@ def dump(obj, file_handler, protocol=DEFAULT_PROTOCOL, ...@@ -296,7 +296,7 @@ def dump(obj, file_handler, protocol=DEFAULT_PROTOCOL,
array(2) array(2)
""" """
with closing(zipfile.ZipFile(file_handler, 'w', zipfile.ZIP_DEFLATED, with closing(zipfile.ZipFile(file_handler, 'wb', zipfile.ZIP_DEFLATED,
allowZip64=True)) as zip_file: allowZip64=True)) as zip_file:
def func(f): def func(f):
p = pickle.Pickler(f, protocol=protocol) p = pickle.Pickler(f, protocol=protocol)
...@@ -317,7 +317,7 @@ def load(f, persistent_load=PersistentNdarrayLoad): ...@@ -317,7 +317,7 @@ def load(f, persistent_load=PersistentNdarrayLoad):
:type persistent_load: callable, optional :type persistent_load: callable, optional
""" """
with closing(zipfile.ZipFile(f, 'r')) as zip_file: with closing(zipfile.ZipFile(f, 'rb')) as zip_file:
p = pickle.Unpickler(StringIO(zip_file.open('pkl').read())) p = pickle.Unpickler(StringIO(zip_file.open('pkl').read()))
p.persistent_load = persistent_load(zip_file) p.persistent_load = persistent_load(zip_file)
return p.load() return p.load()
......
...@@ -18,10 +18,10 @@ def test_dump_load(): ...@@ -18,10 +18,10 @@ def test_dump_load():
x = CudaNdarraySharedVariable('x', CudaNdarrayType((1, 1), name='x'), x = CudaNdarraySharedVariable('x', CudaNdarrayType((1, 1), name='x'),
[[1]], False) [[1]], False)
with open('test', 'w') as f: with open('test', 'wb') as f:
dump(x, f) dump(x, f)
with open('test', 'r') as f: with open('test', 'rb') as f:
x = load(f) x = load(f)
assert x.name == 'x' assert x.name == 'x'
...@@ -34,10 +34,10 @@ def test_dump_load_mrg(): ...@@ -34,10 +34,10 @@ def test_dump_load_mrg():
rng = MRG_RandomStreams(use_cuda=True) rng = MRG_RandomStreams(use_cuda=True)
with open('test', 'w') as f: with open('test', 'wb') as f:
dump(rng, f) dump(rng, f)
with open('test', 'r') as f: with open('test', 'rb') as f:
rng = load(f) rng = load(f)
assert type(rng) == MRG_RandomStreams assert type(rng) == MRG_RandomStreams
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论