提交 792e615a authored 作者: dima's avatar dima 提交者: serdyuk

Fixed formatting

Binary mode for zip Binary mode in tests
上级 839043bb
......@@ -136,7 +136,7 @@ class PersistentNdarrayID(object):
def _resolve_name(self, obj):
"""Determine the name the object should be saved under."""
name = 'array_{}'.format(self.count)
name = 'array_{0}'.format(self.count)
self.count += 1
return name
......@@ -147,7 +147,7 @@ class PersistentNdarrayID(object):
numpy.lib.format.write_array(f, obj)
name = self._resolve_name(obj)
zipadd(write_array, self.zip_file, name)
self.seen[id(obj)] = 'ndarray.{}'.format(name)
self.seen[id(obj)] = 'ndarray.{0}'.format(name)
return self.seen[id(obj)]
......@@ -164,7 +164,7 @@ class PersistentCudaNdarrayID(PersistentNdarrayID):
numpy.lib.format.write_array(f, numpy.asarray(obj))
name = self._resolve_name(obj)
zipadd(write_array, self.zip_file, name)
self.seen[id(obj)] = 'cuda_ndarray.{}'.format(name)
self.seen[id(obj)] = 'cuda_ndarray.{0}'.format(name)
return self.seen[id(obj)]
super(PersistentCudaNdarrayID, self).__call__(obj)
......@@ -206,8 +206,8 @@ class PersistentSharedVariableID(PersistentCudaNdarrayID):
if count:
if not self.allow_duplicates:
raise ValueError("multiple shared variables with the name "
"`{}` found".format(name))
name = '{}_{}'.format(name, count + 1)
"`{0}` found".format(name))
name = '{0}_{1}'.format(name, count + 1)
self.name_counter[name] += 1
return name
return super(PersistentSharedVariableID, self)._resolve_name(obj)
......@@ -219,7 +219,7 @@ class PersistentSharedVariableID(PersistentCudaNdarrayID):
ValueError("can't pickle shared variable with name `pkl`")
self.ndarray_names[id(obj.container.storage[0])] = obj.name
elif not self.allow_unnamed:
raise ValueError("unnamed shared variable, {}".format(obj))
raise ValueError("unnamed shared variable, {0}".format(obj))
return super(PersistentSharedVariableID, self).__call__(obj)
......@@ -296,7 +296,7 @@ def dump(obj, file_handler, protocol=DEFAULT_PROTOCOL,
array(2)
"""
with closing(zipfile.ZipFile(file_handler, 'w', zipfile.ZIP_DEFLATED,
with closing(zipfile.ZipFile(file_handler, 'wb', zipfile.ZIP_DEFLATED,
allowZip64=True)) as zip_file:
def func(f):
p = pickle.Pickler(f, protocol=protocol)
......@@ -317,7 +317,7 @@ def load(f, persistent_load=PersistentNdarrayLoad):
:type persistent_load: callable, optional
"""
with closing(zipfile.ZipFile(f, 'r')) as zip_file:
with closing(zipfile.ZipFile(f, 'rb')) as zip_file:
p = pickle.Unpickler(StringIO(zip_file.open('pkl').read()))
p.persistent_load = persistent_load(zip_file)
return p.load()
......
......@@ -18,10 +18,10 @@ def test_dump_load():
x = CudaNdarraySharedVariable('x', CudaNdarrayType((1, 1), name='x'),
[[1]], False)
with open('test', 'w') as f:
with open('test', 'wb') as f:
dump(x, f)
with open('test', 'r') as f:
with open('test', 'rb') as f:
x = load(f)
assert x.name == 'x'
......@@ -34,10 +34,10 @@ def test_dump_load_mrg():
rng = MRG_RandomStreams(use_cuda=True)
with open('test', 'w') as f:
with open('test', 'wb') as f:
dump(rng, f)
with open('test', 'r') as f:
with open('test', 'rb') as f:
rng = load(f)
assert type(rng) == MRG_RandomStreams
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论