提交 f51cb066 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #1593 from lamblin/fix_py3_unpickle

Add compatibility unpickler for python 3, use it in test
...@@ -13,6 +13,8 @@ __license__ = "3-clause BSD" ...@@ -13,6 +13,8 @@ __license__ = "3-clause BSD"
import pickle import pickle
import sys import sys
import theano import theano
from theano.compat import PY3
from theano.compat.six import string_types
sys.setrecursionlimit(3000) sys.setrecursionlimit(3000)
...@@ -46,3 +48,50 @@ class StripPickler(Pickler): ...@@ -46,3 +48,50 @@ class StripPickler(Pickler):
del obj.__dict__['__doc__'] del obj.__dict__['__doc__']
return Pickler.save(self, obj) return Pickler.save(self, obj)
# Make an unpickler that tries encoding byte streams before raising TypeError.
# This is useful with python 3, in order to unpickle files created with
# python 2.
# This code is taken from Pandas, https://github.com/pydata/pandas,
# under the same 3-clause BSD license.
def load_reduce(self):
stack = self.stack
args = stack.pop()
func = stack[-1]
try:
value = func(*args)
except Exception:
# try to reencode the arguments
if self.encoding is not None:
new_args = []
for arg in args:
if isinstance(arg, string_types):
new_args.append(arg.encode(self.encoding))
else:
new_args.append(arg)
args = tuple(new_args)
try:
stack[-1] = func(*args)
return
except Exception:
pass
if self.is_verbose:
print(sys.exc_info())
print(func, args)
raise
stack[-1] = value
if PY3:
class CompatUnpickler(pickle._Unpickler):
pass
# Register `load_reduce` defined above in CompatUnpickler
CompatUnpickler.dispatch[pickle.REDUCE[0]] = load_reduce
else:
class CompatUnpickler(pickle.Unpickler):
pass
...@@ -15,7 +15,7 @@ from numpy.testing.noseclasses import KnownFailureTest ...@@ -15,7 +15,7 @@ from numpy.testing.noseclasses import KnownFailureTest
import theano import theano
import theano.scalar as scal import theano.scalar as scal
from theano.compat.six import StringIO from theano.compat.six import PY3, StringIO
from theano import compile from theano import compile
from theano.compile import deep_copy_op, DeepCopyOp from theano.compile import deep_copy_op, DeepCopyOp
from theano import config from theano import config
...@@ -2659,8 +2659,19 @@ class test_shapeoptimizer(unittest.TestCase): ...@@ -2659,8 +2659,19 @@ class test_shapeoptimizer(unittest.TestCase):
pkl_filename = os.path.join(os.path.dirname(theano.__file__), pkl_filename = os.path.join(os.path.dirname(theano.__file__),
'tensor', 'tests', 'shape_opt_cycle.pkl') 'tensor', 'tests', 'shape_opt_cycle.pkl')
fn_args = pickle.load(open(pkl_filename, "rb")) # Due to incompatibilities between python 2 and 3 in the format
# of pickled numpy ndarray, we have to force an encoding
from theano.misc.pkl_utils import CompatUnpickler
pkl_file = open(pkl_filename, "rb")
try:
if PY3:
u = CompatUnpickler(pkl_file, encoding="latin1")
else:
u = CompatUnpickler(pkl_file)
fn_args = u.load()
theano.function(**fn_args) theano.function(**fn_args)
finally:
pkl_file.close()
class test_assert(utt.InferShapeTester): class test_assert(utt.InferShapeTester):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论