提交 effeb390 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add compatibility unpickler for python 3, use it in test

上级 36694a6d
...@@ -13,6 +13,8 @@ __license__ = "3-clause BSD" ...@@ -13,6 +13,8 @@ __license__ = "3-clause BSD"
import pickle import pickle
import sys import sys
import theano import theano
from theano.compat import PY3
from theano.compat.six import string_types
sys.setrecursionlimit(3000) sys.setrecursionlimit(3000)
...@@ -46,3 +48,50 @@ class StripPickler(Pickler): ...@@ -46,3 +48,50 @@ class StripPickler(Pickler):
del obj.__dict__['__doc__'] del obj.__dict__['__doc__']
return Pickler.save(self, obj) return Pickler.save(self, obj)
# Make an unpickler that tries encoding byte streams before raising TypeError.
# This is useful with python 3, in order to unpickle files created with
# python 2.
# This code is taken from Pandas, https://github.com/pydata/pandas,
# under the same 3-clause BSD license.
def load_reduce(self):
stack = self.stack
args = stack.pop()
func = stack[-1]
try:
value = func(*args)
except Exception:
# try to reencode the arguments
if self.encoding is not None:
new_args = []
for arg in args:
if isinstance(arg, string_types):
new_args.append(arg.encode(self.encoding))
else:
new_args.append(arg)
args = tuple(new_args)
try:
stack[-1] = func(*args)
return
except Exception:
pass
if self.is_verbose:
print(sys.exc_info())
print(func, args)
raise
stack[-1] = value
if PY3:
class CompatUnpickler(pickle._Unpickler):
pass
# Register `load_reduce` defined above in CompatUnpickler
CompatUnpickler.dispatch[pickle.REDUCE[0]] = load_reduce
else:
class CompatUnpickler(pickle.Unpickler):
pass
...@@ -15,7 +15,7 @@ from numpy.testing.noseclasses import KnownFailureTest ...@@ -15,7 +15,7 @@ from numpy.testing.noseclasses import KnownFailureTest
import theano import theano
import theano.scalar as scal import theano.scalar as scal
from theano.compat.six import StringIO from theano.compat.six import PY3, StringIO
from theano import compile from theano import compile
from theano.compile import deep_copy_op, DeepCopyOp from theano.compile import deep_copy_op, DeepCopyOp
from theano import config from theano import config
...@@ -2659,8 +2659,19 @@ class test_shapeoptimizer(unittest.TestCase): ...@@ -2659,8 +2659,19 @@ class test_shapeoptimizer(unittest.TestCase):
pkl_filename = os.path.join(os.path.dirname(theano.__file__), pkl_filename = os.path.join(os.path.dirname(theano.__file__),
'tensor', 'tests', 'shape_opt_cycle.pkl') 'tensor', 'tests', 'shape_opt_cycle.pkl')
fn_args = pickle.load(open(pkl_filename, "rb")) # Due to incompatibilities between python 2 and 3 in the format
theano.function(**fn_args) # of pickled numpy ndarray, we have to force an encoding
from theano.misc.pkl_utils import CompatUnpickler
pkl_file = open(pkl_filename, "rb")
try:
if PY3:
u = CompatUnpickler(pkl_file, encoding="latin1")
else:
u = CompatUnpickler(pkl_file)
fn_args = u.load()
theano.function(**fn_args)
finally:
pkl_file.close()
class test_assert(utt.InferShapeTester): class test_assert(utt.InferShapeTester):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论