提交 effeb390 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add compatibility unpickler for python 3, use it in test

上级 36694a6d
......@@ -13,6 +13,8 @@ __license__ = "3-clause BSD"
import pickle
import sys
import theano
from theano.compat import PY3
from theano.compat.six import string_types
sys.setrecursionlimit(3000)
......@@ -46,3 +48,50 @@ class StripPickler(Pickler):
del obj.__dict__['__doc__']
return Pickler.save(self, obj)
# Make an unpickler that tries encoding byte streams before raising TypeError.
# This is useful with python 3, in order to unpickle files created with
# python 2.
# This code is taken from Pandas, https://github.com/pydata/pandas,
# under the same 3-clause BSD license.
def load_reduce(self):
stack = self.stack
args = stack.pop()
func = stack[-1]
try:
value = func(*args)
except Exception:
# try to reencode the arguments
if self.encoding is not None:
new_args = []
for arg in args:
if isinstance(arg, string_types):
new_args.append(arg.encode(self.encoding))
else:
new_args.append(arg)
args = tuple(new_args)
try:
stack[-1] = func(*args)
return
except Exception:
pass
if self.is_verbose:
print(sys.exc_info())
print(func, args)
raise
stack[-1] = value
if PY3:
class CompatUnpickler(pickle._Unpickler):
pass
# Register `load_reduce` defined above in CompatUnpickler
CompatUnpickler.dispatch[pickle.REDUCE[0]] = load_reduce
else:
class CompatUnpickler(pickle.Unpickler):
pass
......@@ -15,7 +15,7 @@ from numpy.testing.noseclasses import KnownFailureTest
import theano
import theano.scalar as scal
from theano.compat.six import StringIO
from theano.compat.six import PY3, StringIO
from theano import compile
from theano.compile import deep_copy_op, DeepCopyOp
from theano import config
......@@ -2659,8 +2659,19 @@ class test_shapeoptimizer(unittest.TestCase):
pkl_filename = os.path.join(os.path.dirname(theano.__file__),
'tensor', 'tests', 'shape_opt_cycle.pkl')
fn_args = pickle.load(open(pkl_filename, "rb"))
theano.function(**fn_args)
# Due to incompatibilities between python 2 and 3 in the format
# of pickled numpy ndarray, we have to force an encoding
from theano.misc.pkl_utils import CompatUnpickler
pkl_file = open(pkl_filename, "rb")
try:
if PY3:
u = CompatUnpickler(pkl_file, encoding="latin1")
else:
u = CompatUnpickler(pkl_file)
fn_args = u.load()
theano.function(**fn_args)
finally:
pkl_file.close()
class test_assert(utt.InferShapeTester):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论