提交 1d7141d1 authored 作者: serdyuk's avatar serdyuk

Zip files not in binary format

Binary mode for test files Added import Reading in binary mode Used BytesIO
上级 533d889b
......@@ -5,6 +5,7 @@ These pickled graphs can be used, for instance, as cases for
unit tests or regression tests.
"""
import numpy
import os
import pickle
import sys
import tempfile
......@@ -13,7 +14,7 @@ import warnings
from collections import defaultdict
from contextlib import closing
from pickle import HIGHEST_PROTOCOL
from theano.compat.six import StringIO
from theano.compat.six import BytesIO
try:
from pickle import DEFAULT_PROTOCOL
except ImportError:
......@@ -296,7 +297,7 @@ def dump(obj, file_handler, protocol=DEFAULT_PROTOCOL,
array(2)
"""
with closing(zipfile.ZipFile(file_handler, 'wb', zipfile.ZIP_DEFLATED,
with closing(zipfile.ZipFile(file_handler, 'w', zipfile.ZIP_DEFLATED,
allowZip64=True)) as zip_file:
def func(f):
p = pickle.Pickler(f, protocol=protocol)
......@@ -317,8 +318,8 @@ def load(f, persistent_load=PersistentNdarrayLoad):
:type persistent_load: callable, optional
"""
with closing(zipfile.ZipFile(f, 'rb')) as zip_file:
p = pickle.Unpickler(StringIO(zip_file.open('pkl').read()))
with closing(zipfile.ZipFile(f, 'r')) as zip_file:
p = pickle.Unpickler(BytesIO(zip_file.open('pkl').read()))
p.persistent_load = persistent_load(zip_file)
return p.load()
......
......@@ -46,12 +46,12 @@ def test_dump_load_mrg():
def test_dump_zip_names():
foo_1 = theano.shared(0, name='foo')
foo_2 = theano.shared(1, name='foo')
with open('model.zip', 'w') as f:
with open('model.zip', 'wb') as f:
dump((foo_1, foo_2, numpy.array(2)), f)
keys = numpy.load('model.zip').keys()
assert keys == ['foo', 'foo_2', 'array_0', 'pkl']
foo = numpy.load('model.zip')['foo']
assert foo == numpy.array(0)
with open('model.zip') as f:
with open('model.zip', 'rb') as f:
foo_1, foo_2, array = load(f)
assert array == numpy.array(2)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论