提交 7eafc8ae authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add a test for the broken pickle workaround.

上级 32b5c779
...@@ -566,6 +566,36 @@ class T_picklefunction(unittest.TestCase): ...@@ -566,6 +566,36 @@ class T_picklefunction(unittest.TestCase):
assert numpy.all(nl[6][nl[2]] == numpy.asarray([2, 3., 4])) assert numpy.all(nl[6][nl[2]] == numpy.asarray([2, 3., 4]))
def test_broken_pickle_with_shared(self):
saves = []
def pers_save(obj):
if isinstance(obj, numpy.ndarray):
saves.append(obj)
return len(saves)-1
else:
return None
def pers_load(id):
return saves[id]
a = numpy.random.rand(4, 5)
b = numpy.random.rand(5, 4)
x = theano.tensor.matrix()
y = theano.shared(b)
f = theano.function([x], theano.tensor.dot(x, y))
import StringIO
fp = StringIO.StringIO()
p = cPickle.Pickler(fp, 2)
p.persistent_id = pers_save
p.dump(f)
fp2 = StringIO.StringIO(fp.getvalue())
fp.close()
p = cPickle.Unpickler(fp2)
p.persistent_load = pers_load
f2 = p.load()
fp2.close()
def test_pickle_class_with_functions(self): def test_pickle_class_with_functions(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论