提交 064ee438 authored 作者: Bart van Merrienboer's avatar Bart van Merrienboer

Remove weak references, which aren't picklable

上级 1925556b
...@@ -6,7 +6,6 @@ http://www.iro.umontreal.ca/~simardr/ssj/indexe.html ...@@ -6,7 +6,6 @@ http://www.iro.umontreal.ca/~simardr/ssj/indexe.html
""" """
import warnings import warnings
import weakref
import numpy import numpy
...@@ -1222,7 +1221,7 @@ class MRG_RandomStreams(object): ...@@ -1222,7 +1221,7 @@ class MRG_RandomStreams(object):
*mrg_uniform.new(node_rstate, *mrg_uniform.new(node_rstate,
ndim, dtype, size)) ndim, dtype, size))
# Add a reference to distinguish from other shared variables # Add a reference to distinguish from other shared variables
node_rstate.rng_owner = weakref.ref(self) node_rstate.tag.is_rng = True
r = u * (high - low) + low r = u * (high - low) + low
if u.type.broadcastable != r.type.broadcastable: if u.type.broadcastable != r.type.broadcastable:
......
...@@ -5,7 +5,6 @@ __docformat__ = "restructuredtext en" ...@@ -5,7 +5,6 @@ __docformat__ = "restructuredtext en"
import copy import copy
import numpy import numpy
import weakref
from theano.compile.sharedvalue import (SharedVariable, shared_constructor, from theano.compile.sharedvalue import (SharedVariable, shared_constructor,
shared) shared)
...@@ -134,7 +133,7 @@ class RandomStreams(raw_random.RandomStreamsBase): ...@@ -134,7 +133,7 @@ class RandomStreams(raw_random.RandomStreamsBase):
seed = int(self.gen_seedgen.randint(2 ** 30)) seed = int(self.gen_seedgen.randint(2 ** 30))
random_state_variable = shared(numpy.random.RandomState(seed)) random_state_variable = shared(numpy.random.RandomState(seed))
# Add a reference to distinguish from other shared variables # Add a reference to distinguish from other shared variables
random_state_variable.rng_owner = weakref.ref(self) random_state_variable.tag.is_rng = True
new_r, out = op(random_state_variable, *args, **kwargs) new_r, out = op(random_state_variable, *args, **kwargs)
out.rng = random_state_variable out.rng = random_state_variable
out.update = (random_state_variable, new_r) out.update = (random_state_variable, new_r)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论