提交 859bba60 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3050 from lamblin/fix_randomstreams_pkl

Do not try to pickle `method_descriptor` objects
......@@ -866,8 +866,11 @@ def multinomial(random_state, size=None, n=1, pvals=[0.5, 0.5],
def random_make_inplace(node):
op = node.op
if isinstance(op, RandomFunction) and not op.inplace:
new_op = RandomFunction(op.fn, op.outtype, inplace=True,
ndim_added=op.ndim_added)
# Read op_fn from op.state, not from op.fn, since op.fn
# may not be picklable.
op_fn, op_outtype, op_inplace, op_ndim_added = op.__getstate__()
new_op = RandomFunction(op_fn, op_outtype, inplace=True,
ndim_added=op_ndim_added)
return new_op.make_node(*node.inputs).outputs
return False
......
from __future__ import print_function
__docformat__ = "restructuredtext en"
import numpy
import pickle
from theano.tests import unittest_tools as utt
from theano.tensor.raw_random import *
......@@ -10,6 +11,8 @@ from theano import tensor
from theano import compile, config, gof
__docformat__ = "restructuredtext en"
class T_random_function(utt.InferShapeTester):
def setUp(self):
......@@ -1181,6 +1184,19 @@ class T_random_function(utt.InferShapeTester):
pvals_val], RandomFunction)
"""
def test_pkl(self):
# Test pickling of RandomFunction.
# binomial was created by calling RandomFunction on a string,
# random_integers by calling it on a function.
rng_r = random_state_type()
post_bin_r, bin_sample = binomial(rng_r, (3, 5), 1, .3)
f = theano.function([rng_r], [post_bin_r, bin_sample])
pkl_f = pickle.dumps(f)
post_int_r, int_sample = random_integers(rng_r, (3, 5), -1, 8)
g = theano.function([rng_r], [post_int_r, int_sample])
pkl_g = pickle.dumps(g)
if __name__ == '__main__':
from theano.tests import main
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论