提交 107c28d1 authored 作者: Frederic Bastien's avatar Frederic Bastien

Fix RandomFunction.__getstate__

上级 f0dd8a2c
...@@ -133,15 +133,28 @@ class RandomFunction(gof.Op): ...@@ -133,15 +133,28 @@ class RandomFunction(gof.Op):
self.__setstate__([fn, outtype, inplace, ndim_added]) self.__setstate__([fn, outtype, inplace, ndim_added])
def __getstate__(self): def __getstate__(self):
return self.state d = dict(self.__dict__)
del d['exec_fn']
if 'destroy_map' in d:
del d['destroy_map']
return d
def __setstate__(self, dct):
if isinstance(dct, dict):
state = [dct['fn'],
dct['outtype'],
dct['inplace'],
dct['ndim_added']]
else:
state = dct
def __setstate__(self, state):
self.state = state self.state = state
fn, outtype, inplace, ndim_added = state fn, outtype, inplace, ndim_added = state
self.fn = fn
if isinstance(fn, string_types): if isinstance(fn, string_types):
self.fn = getattr(numpy.random.RandomState, fn) self.exec_fn = getattr(numpy.random.RandomState, fn)
else: else:
self.fn = fn self.exec_fn = fn
self.outtype = outtype self.outtype = outtype
self.inplace = inplace self.inplace = inplace
if self.inplace: if self.inplace:
...@@ -149,7 +162,7 @@ class RandomFunction(gof.Op): ...@@ -149,7 +162,7 @@ class RandomFunction(gof.Op):
self.ndim_added = ndim_added self.ndim_added = ndim_added
def __str__(self): def __str__(self):
return 'RandomFunction{%s}' % self.fn.__name__ return 'RandomFunction{%s}' % self.exec_fn.__name__
def make_node(self, r, shape, *args): def make_node(self, r, shape, *args):
""" """
...@@ -247,7 +260,7 @@ class RandomFunction(gof.Op): ...@@ -247,7 +260,7 @@ class RandomFunction(gof.Op):
if not self.inplace: if not self.inplace:
r = copy(r) r = copy(r)
rout[0] = r rout[0] = r
rval = self.fn(r, *(args + [shape])) rval = self.exec_fn(r, *(args + [shape]))
if (not isinstance(rval, numpy.ndarray) or if (not isinstance(rval, numpy.ndarray) or
str(rval.dtype) != node.outputs[1].type.dtype): str(rval.dtype) != node.outputs[1].type.dtype):
rval = theano._asarray(rval, dtype=node.outputs[1].type.dtype) rval = theano._asarray(rval, dtype=node.outputs[1].type.dtype)
...@@ -904,7 +917,7 @@ def random_make_inplace(node): ...@@ -904,7 +917,7 @@ def random_make_inplace(node):
if isinstance(op, RandomFunction) and not op.inplace: if isinstance(op, RandomFunction) and not op.inplace:
# Read op_fn from op.state, not from op.fn, since op.fn # Read op_fn from op.state, not from op.fn, since op.fn
# may not be picklable. # may not be picklable.
op_fn, op_outtype, op_inplace, op_ndim_added = op.__getstate__() op_fn, op_outtype, op_inplace, op_ndim_added = op._props()
new_op = RandomFunction(op_fn, op_outtype, inplace=True, new_op = RandomFunction(op_fn, op_outtype, inplace=True,
ndim_added=op_ndim_added) ndim_added=op_ndim_added)
return new_op.make_node(*node.inputs).outputs return new_op.make_node(*node.inputs).outputs
......
...@@ -1199,6 +1199,7 @@ class T_random_function(utt.InferShapeTester): ...@@ -1199,6 +1199,7 @@ class T_random_function(utt.InferShapeTester):
post_int_r, int_sample = random_integers(rng_r, (3, 5), -1, 8) post_int_r, int_sample = random_integers(rng_r, (3, 5), -1, 8)
g = theano.function([rng_r], [post_int_r, int_sample], mode=mode) g = theano.function([rng_r], [post_int_r, int_sample], mode=mode)
pkl_g = pickle.dumps(g) pkl_g = pickle.dumps(g)
pickle.loads(pkl_g)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论