提交 1d7141d1 authored 作者: serdyuk's avatar serdyuk

Zip files not in binary format

Binary mode for test files Added import Reading in binary mode Used BytesIO
上级 533d889b
...@@ -5,6 +5,7 @@ These pickled graphs can be used, for instance, as cases for ...@@ -5,6 +5,7 @@ These pickled graphs can be used, for instance, as cases for
unit tests or regression tests. unit tests or regression tests.
""" """
import numpy import numpy
import os
import pickle import pickle
import sys import sys
import tempfile import tempfile
...@@ -13,7 +14,7 @@ import warnings ...@@ -13,7 +14,7 @@ import warnings
from collections import defaultdict from collections import defaultdict
from contextlib import closing from contextlib import closing
from pickle import HIGHEST_PROTOCOL from pickle import HIGHEST_PROTOCOL
from theano.compat.six import StringIO from theano.compat.six import BytesIO
try: try:
from pickle import DEFAULT_PROTOCOL from pickle import DEFAULT_PROTOCOL
except ImportError: except ImportError:
...@@ -296,7 +297,7 @@ def dump(obj, file_handler, protocol=DEFAULT_PROTOCOL, ...@@ -296,7 +297,7 @@ def dump(obj, file_handler, protocol=DEFAULT_PROTOCOL,
array(2) array(2)
""" """
with closing(zipfile.ZipFile(file_handler, 'wb', zipfile.ZIP_DEFLATED, with closing(zipfile.ZipFile(file_handler, 'w', zipfile.ZIP_DEFLATED,
allowZip64=True)) as zip_file: allowZip64=True)) as zip_file:
def func(f): def func(f):
p = pickle.Pickler(f, protocol=protocol) p = pickle.Pickler(f, protocol=protocol)
...@@ -317,8 +318,8 @@ def load(f, persistent_load=PersistentNdarrayLoad): ...@@ -317,8 +318,8 @@ def load(f, persistent_load=PersistentNdarrayLoad):
:type persistent_load: callable, optional :type persistent_load: callable, optional
""" """
with closing(zipfile.ZipFile(f, 'rb')) as zip_file: with closing(zipfile.ZipFile(f, 'r')) as zip_file:
p = pickle.Unpickler(StringIO(zip_file.open('pkl').read())) p = pickle.Unpickler(BytesIO(zip_file.open('pkl').read()))
p.persistent_load = persistent_load(zip_file) p.persistent_load = persistent_load(zip_file)
return p.load() return p.load()
......
...@@ -46,12 +46,12 @@ def test_dump_load_mrg(): ...@@ -46,12 +46,12 @@ def test_dump_load_mrg():
def test_dump_zip_names(): def test_dump_zip_names():
foo_1 = theano.shared(0, name='foo') foo_1 = theano.shared(0, name='foo')
foo_2 = theano.shared(1, name='foo') foo_2 = theano.shared(1, name='foo')
with open('model.zip', 'w') as f: with open('model.zip', 'wb') as f:
dump((foo_1, foo_2, numpy.array(2)), f) dump((foo_1, foo_2, numpy.array(2)), f)
keys = numpy.load('model.zip').keys() keys = numpy.load('model.zip').keys()
assert keys == ['foo', 'foo_2', 'array_0', 'pkl'] assert keys == ['foo', 'foo_2', 'array_0', 'pkl']
foo = numpy.load('model.zip')['foo'] foo = numpy.load('model.zip')['foo']
assert foo == numpy.array(0) assert foo == numpy.array(0)
with open('model.zip') as f: with open('model.zip', 'rb') as f:
foo_1, foo_2, array = load(f) foo_1, foo_2, array = load(f)
assert array == numpy.array(2) assert array == numpy.array(2)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论