提交 44afb3d9 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Do not try to pickle `method_descriptor` objects

上级 d9fc9d73
...@@ -866,8 +866,11 @@ def multinomial(random_state, size=None, n=1, pvals=[0.5, 0.5], ...@@ -866,8 +866,11 @@ def multinomial(random_state, size=None, n=1, pvals=[0.5, 0.5],
def random_make_inplace(node): def random_make_inplace(node):
op = node.op op = node.op
if isinstance(op, RandomFunction) and not op.inplace: if isinstance(op, RandomFunction) and not op.inplace:
new_op = RandomFunction(op.fn, op.outtype, inplace=True, # Read op_fn from op.state, not from op.fn, since op.fn
ndim_added=op.ndim_added) # may not be picklable.
op_fn, op_outtype, op_inplace, op_ndim_added = op.__getstate__()
new_op = RandomFunction(op_fn, op_outtype, inplace=True,
ndim_added=op_ndim_added)
return new_op.make_node(*node.inputs).outputs return new_op.make_node(*node.inputs).outputs
return False return False
......
from __future__ import print_function from __future__ import print_function
__docformat__ = "restructuredtext en"
import numpy import numpy
import pickle
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.tensor.raw_random import * from theano.tensor.raw_random import *
...@@ -10,6 +11,8 @@ from theano import tensor ...@@ -10,6 +11,8 @@ from theano import tensor
from theano import compile, config, gof from theano import compile, config, gof
__docformat__ = "restructuredtext en"
class T_random_function(utt.InferShapeTester): class T_random_function(utt.InferShapeTester):
def setUp(self): def setUp(self):
...@@ -1181,6 +1184,19 @@ class T_random_function(utt.InferShapeTester): ...@@ -1181,6 +1184,19 @@ class T_random_function(utt.InferShapeTester):
pvals_val], RandomFunction) pvals_val], RandomFunction)
""" """
def test_pkl(self):
# Test pickling of RandomFunction.
# binomial was created by calling RandomFunction on a string,
# random_integers by calling it on a function.
rng_r = random_state_type()
post_bin_r, bin_sample = binomial(rng_r, (3, 5), 1, .3)
f = theano.function([rng_r], [post_bin_r, bin_sample])
pkl_f = pickle.dumps(f)
post_int_r, int_sample = random_integers(rng_r, (3, 5), -1, 8)
g = theano.function([rng_r], [post_int_r, int_sample])
pkl_g = pickle.dumps(g)
if __name__ == '__main__': if __name__ == '__main__':
from theano.tests import main from theano.tests import main
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论