提交 d8f28f6f authored 作者: James Bergstra's avatar James Bergstra

fixed randomstreams module to be shared when it occurs multiple times in a module graph

上级 dd22a040
...@@ -109,7 +109,9 @@ class RandomStreams(Component): ...@@ -109,7 +109,9 @@ class RandomStreams(Component):
def allocate(self, memo): def allocate(self, memo):
"""override `Component.allocate` """ """override `Component.allocate` """
for old_r, new_r in self.random_state_variables: for old_r, new_r in self.random_state_variables:
assert old_r not in memo if old_r in memo:
assert memo[old_r].update is new_r
else:
memo[old_r] = In(old_r, memo[old_r] = In(old_r,
value=Container(old_r, storage=[None]), value=Container(old_r, storage=[None]),
update=new_r, update=new_r,
...@@ -117,7 +119,10 @@ class RandomStreams(Component): ...@@ -117,7 +119,10 @@ class RandomStreams(Component):
def build(self, mode, memo): def build(self, mode, memo):
"""override `Component.build` """ """override `Component.build` """
return RandomStreamsInstance(self, memo, self.default_instance_seed) if self not in memo:
print 'creating RandomStreamsInstance'
memo[self] = RandomStreamsInstance(self, memo, self.default_instance_seed)
return memo[self]
def gen(self, op, *args, **kwargs): def gen(self, op, *args, **kwargs):
"""Create a new random stream in this container. """Create a new random stream in this container.
......
...@@ -122,6 +122,22 @@ class T_RandomStreams(unittest.TestCase): ...@@ -122,6 +122,22 @@ class T_RandomStreams(unittest.TestCase):
assert numpy.all(fn_val0 == numpy_val0) assert numpy.all(fn_val0 == numpy_val0)
assert numpy.all(fn_val1 == numpy_val1) assert numpy.all(fn_val1 == numpy_val1)
def test_multiple(self):
M = Module()
M.random = RandomStreams(234)
out = M.random.uniform((2,2))
M.m2 = Module()
M.m2.random = M.random
out2 = M.m2.random.uniform((2,2))
M.fn = Method([], out)
M.m2.fn2 = Method([], out2)
m = M.make()
m.random.initialize()
m.m2.initialize()
assert m.random is m.m2.random
if __name__ == '__main__': if __name__ == '__main__':
from theano.tests import main from theano.tests import main
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论