提交 d8f28f6f authored 作者: James Bergstra's avatar James Bergstra

fixed randomstreams module to be shared when it occurs multiple times in a module graph

上级 dd22a040
......@@ -109,15 +109,20 @@ class RandomStreams(Component):
def allocate(self, memo):
"""override `Component.allocate` """
for old_r, new_r in self.random_state_variables:
assert old_r not in memo
memo[old_r] = In(old_r,
value=Container(old_r, storage=[None]),
update=new_r,
mutable=True)
if old_r in memo:
assert memo[old_r].update is new_r
else:
memo[old_r] = In(old_r,
value=Container(old_r, storage=[None]),
update=new_r,
mutable=True)
def build(self, mode, memo):
"""override `Component.build` """
return RandomStreamsInstance(self, memo, self.default_instance_seed)
if self not in memo:
print 'creating RandomStreamsInstance'
memo[self] = RandomStreamsInstance(self, memo, self.default_instance_seed)
return memo[self]
def gen(self, op, *args, **kwargs):
"""Create a new random stream in this container.
......
......@@ -122,6 +122,22 @@ class T_RandomStreams(unittest.TestCase):
assert numpy.all(fn_val0 == numpy_val0)
assert numpy.all(fn_val1 == numpy_val1)
def test_multiple(self):
M = Module()
M.random = RandomStreams(234)
out = M.random.uniform((2,2))
M.m2 = Module()
M.m2.random = M.random
out2 = M.m2.random.uniform((2,2))
M.fn = Method([], out)
M.m2.fn2 = Method([], out2)
m = M.make()
m.random.initialize()
m.m2.initialize()
assert m.random is m.m2.random
if __name__ == '__main__':
from theano.tests import main
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论