提交 107c28d1 authored 作者: Frederic Bastien's avatar Frederic Bastien

Fix RandomFunction.__getstate__

上级 f0dd8a2c
......@@ -133,15 +133,28 @@ class RandomFunction(gof.Op):
self.__setstate__([fn, outtype, inplace, ndim_added])
def __getstate__(self):
return self.state
d = dict(self.__dict__)
del d['exec_fn']
if 'destroy_map' in d:
del d['destroy_map']
return d
def __setstate__(self, dct):
if isinstance(dct, dict):
state = [dct['fn'],
dct['outtype'],
dct['inplace'],
dct['ndim_added']]
else:
state = dct
def __setstate__(self, state):
self.state = state
fn, outtype, inplace, ndim_added = state
self.fn = fn
if isinstance(fn, string_types):
self.fn = getattr(numpy.random.RandomState, fn)
self.exec_fn = getattr(numpy.random.RandomState, fn)
else:
self.fn = fn
self.exec_fn = fn
self.outtype = outtype
self.inplace = inplace
if self.inplace:
......@@ -149,7 +162,7 @@ class RandomFunction(gof.Op):
self.ndim_added = ndim_added
def __str__(self):
return 'RandomFunction{%s}' % self.fn.__name__
return 'RandomFunction{%s}' % self.exec_fn.__name__
def make_node(self, r, shape, *args):
"""
......@@ -247,7 +260,7 @@ class RandomFunction(gof.Op):
if not self.inplace:
r = copy(r)
rout[0] = r
rval = self.fn(r, *(args + [shape]))
rval = self.exec_fn(r, *(args + [shape]))
if (not isinstance(rval, numpy.ndarray) or
str(rval.dtype) != node.outputs[1].type.dtype):
rval = theano._asarray(rval, dtype=node.outputs[1].type.dtype)
......@@ -904,7 +917,7 @@ def random_make_inplace(node):
if isinstance(op, RandomFunction) and not op.inplace:
# Read op_fn from op.state, not from op.fn, since op.fn
# may not be picklable.
op_fn, op_outtype, op_inplace, op_ndim_added = op.__getstate__()
op_fn, op_outtype, op_inplace, op_ndim_added = op._props()
new_op = RandomFunction(op_fn, op_outtype, inplace=True,
ndim_added=op_ndim_added)
return new_op.make_node(*node.inputs).outputs
......
......@@ -1199,6 +1199,7 @@ class T_random_function(utt.InferShapeTester):
post_int_r, int_sample = random_integers(rng_r, (3, 5), -1, 8)
g = theano.function([rng_r], [post_int_r, int_sample], mode=mode)
pkl_g = pickle.dumps(g)
pickle.loads(pkl_g)
if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论