提交 b02ea197 authored 作者: lamblin's avatar lamblin

Merge pull request #1130 from gdesjardins/master

BUG FIX: MRG_RandomStreams shared state because of member-class initialization.
......@@ -376,6 +376,71 @@ For example:
>>> rv_u.rng.set_value(rng, borrow=True)
>>> v2 = f() # v2 != v1
Copying Random State Between Theano Graphs
------------------------------------------
In some use cases, a user might want to transfer the "state" of all random
number generators associated with a given theano graph (e.g. g1, with compiled
function f1 below) to a second graph (e.g. g2, with function f2). This might
arise, for example, if you are trying to initialize the state of a model from
the parameters of a pickled version of a previous model. For
:class:`theano.tensor.shared_randomstreams.RandomStreams` and
:class:`theano.sandbox.rng_mrg.MRG_RandomStreams`
this can be achieved by copying elements of the `state_updates` attribute.
Each time a random variable is drawn from a RandomStreams object, a tuple is
added to the `state_updates` list. The first element is a shared variable,
which represents the state of the random number generator associated with this
*particular* variable, while the second represents the theano graph
corresponding to the random number generation process (e.g. RandomFunction{uniform}.0).
An example of how "random states" can be transferred from one theano function
to another is shown below.
.. code-block:: python
import theano
import numpy
import theano.tensor as T
from theano.sandbox.rng_mrg import MRG_RandomStreams
from theano.tensor.shared_randomstreams import RandomStreams
# Minimal "graph" holder: one random stream plus one uniform draw from it.
class Graph():
def __init__(self, seed=123):
self.rng = RandomStreams(seed)
self.y = self.rng.uniform(size=(1,))
# Build two graphs from different seeds and compile one function for each.
g1 = Graph(seed=123)
f1 = theano.function([], g1.y)
g2 = Graph(seed=987)
f2 = theano.function([], g2.y)
print 'By default, the two functions are out of sync.'
print 'f1() returns ', f1()
print 'f2() returns ', f2()
# Copy the random state held by g1's generator into g2's generator.
def copy_random_state(g1, g2):
# MRG_RandomStreams additionally carries an `rstate` attribute; hand it over.
if isinstance(g1.rng, MRG_RandomStreams):
g2.rng.rstate = g1.rng.rstate
# Pair up the state_updates entries of both graphs and copy the value of each
# state shared variable (element 0 of the pair) from g1 into g2.
for (su1, su2) in zip(g1.rng.state_updates, g2.rng.state_updates):
su2[0].set_value(su1[0].get_value())
print 'We now copy the state of the theano random number generators.'
copy_random_state(g1, g2)
print 'f1() returns ', f1()
print 'f2() returns ', f2()
This gives the following output:
.. code-block:: shell
By default, the two functions are out of sync.
f1() returns [ 0.72803009]
f2() returns [ 0.55056769]
We now copy the state of the theano random number generators.
f1() returns [ 0.59044123]
f2() returns [ 0.59044123]
Other Random Distributions
---------------------------
......
......@@ -637,10 +637,6 @@ def guess_n_streams(size, warn=True):
class MRG_RandomStreams(object):
"""Module component with similar interface to numpy.random (numpy.random.RandomState)"""
state_updates = []
"""A list of pairs of the form (input_r, output_r), representing the
update rules of all the random states generated by this RandomStreams"""
def updates(self):
return list(self.state_updates)
......@@ -655,6 +651,10 @@ class MRG_RandomStreams(object):
M2 = 2147462579, and not all 0.
"""
# A list of pairs of the form (input_r, output_r), representing the
# update rules of all the random states generated by this RandomStreams.
self.state_updates = []
super(MRG_RandomStreams, self).__init__()
if isinstance(seed, int):
if seed == 0:
......
......@@ -677,3 +677,36 @@ class T_MRG(unittest.TestCase):
self.assertRaises(ValueError, R.binomial, size)
self.assertRaises(ValueError, R.multinomial, size, 1, [])
self.assertRaises(ValueError, R.normal, size)
def test_multiple_rng_aliasing():
"""
Test that two MRG_RandomStreams instances do not alias the state_updates
member. `state_updates` can be useful when attempting to copy the (random)
state between two similar theano graphs. The test is meant to detect a
previous bug where state_updates was initialized as a class attribute
instead of inside the __init__ function.
"""
rng1 = MRG_RandomStreams(1234)
rng2 = MRG_RandomStreams(2392)
# Identity check: each instance must own its own list object.
assert rng1.state_updates is not rng2.state_updates
def test_random_state_transfer():
"""
Test that random state can be transferred from one theano graph to another.
"""
# Minimal graph holder: one MRG stream and one uniform draw from it.
class Graph():
def __init__(self, seed=123):
self.rng = MRG_RandomStreams(seed)
self.y = self.rng.uniform(size=(1,))
# Two graphs built from different seeds, each compiled to its own function.
g1 = Graph(seed=123)
f1 = theano.function([], g1.y)
g2 = Graph(seed=987)
f2 = theano.function([], g2.y)
# Transfer g1's state to g2: first the generator-level rstate, then the value
# of each state shared variable (element 0 of every state_updates pair).
g2.rng.rstate = g1.rng.rstate
for (su1, su2) in zip(g1.rng.state_updates, g2.rng.state_updates):
su2[0].set_value(su1[0].get_value())
# After the transfer both functions are expected to return the same draw.
numpy.testing.assert_array_almost_equal(f1(), f2(), decimal=6)
......@@ -122,20 +122,6 @@ class RandomStreams(Component, raw_random.RandomStreamsBase):
"""
random_state_variables = []
"""A list of pairs of the form (input_r, output_r). This will be
over-ridden by the module instance to contain stream
generators.
"""
default_instance_seed = None
"""Instance variable should take None or integer value. Used to
seed the random number generator that provides seeds for member
streams
"""
def __init__(self, seed=None, no_warn=False):
""":type seed: None or int
......@@ -147,7 +133,13 @@ class RandomStreams(Component, raw_random.RandomStreamsBase):
if not no_warn:
deprecation_warning()
super(RandomStreams, self).__init__(no_warn=True)
# A list of pairs of the form (input_r, output_r). This will be
# over-ridden by the module instance to contain stream generators.
self.random_state_variables = []
# Instance variable should take None or integer value. Used to seed the
# random number generator that provides seeds for member streams
self.default_instance_seed = seed
def allocate(self, memo):
......
......@@ -35,27 +35,9 @@ def randomstate_constructor(value, name=None, strict=False,
class RandomStreams(raw_random.RandomStreamsBase):
"""Module component with similar interface to numpy.random
(numpy.random.RandomState)
"""
state_updates = []
"""A list of pairs of the form (input_r, output_r). This will be
over-ridden by the module instance to contain stream
generators.
"""
default_instance_seed = None
"""Instance variable should take None or integer value. Used to
seed the random number generator that provides seeds for member
streams
"""
gen_seedgen = None
"""numpy.RandomState instance that gen() uses to seed new streams.
Module component with similar interface to numpy.random
(numpy.random.RandomState)
"""
def updates(self):
......@@ -71,8 +53,13 @@ class RandomStreams(raw_random.RandomStreamsBase):
"""
super(RandomStreams, self).__init__()
# A list of pairs of the form (input_r, output_r). This will be
# over-ridden by the module instance to contain stream generators.
self.state_updates = []
# Instance variable should take None or integer value. Used to seed the
# random number generator that provides seeds for member streams.
self.default_instance_seed = seed
# numpy.RandomState instance that gen() uses to seed new streams.
self.gen_seedgen = numpy.random.RandomState(seed)
def seed(self, seed=None):
......
......@@ -682,6 +682,17 @@ class T_RandomStreams(unittest.TestCase):
assert val1.dtype == 'int8'
assert numpy.all(abs(val1) <= 1)
def test_multiple_rng(self):
"""
Test that when we have multiple random number generators, we do not alias
the random_state_variables member: each RandomStreams instance must own
its own list. The test is meant to detect a bug where the list is
initialized as a class attribute instead of inside the __init__ function.
"""
rng1 = RandomStreams(1234)
rng2 = RandomStreams(2392)
# Identity check: the two instances must not share one list object.
assert rng1.random_state_variables is not rng2.random_state_variables
if __name__ == '__main__':
from theano.tests import main
......
......@@ -715,7 +715,36 @@ class T_SharedRandomStreams(unittest.TestCase):
s_rng.set_value(rr, borrow=True)
assert rr is s_rng.container.storage[0]
def test_multiple_rng_aliasing(self):
"""
Test that when we have multiple random number generators, we do not alias
the state_updates member or the gen_seedgen member. `state_updates` can be
useful when attempting to copy the (random) state between two similar
theano graphs. The test is meant to detect a previous bug where
state_updates was initialized as a class attribute instead of inside the
__init__ function.
"""
rng1 = RandomStreams(1234)
rng2 = RandomStreams(2392)
# Identity checks: each instance must own its own objects.
assert rng1.state_updates is not rng2.state_updates
assert rng1.gen_seedgen is not rng2.gen_seedgen
def test_random_state_transfer(self):
"""
Test that random state can be transferred from one theano graph to another.
"""
# Minimal graph holder: one RandomStreams object and one uniform draw from it.
class Graph():
def __init__(self, seed=123):
self.rng = RandomStreams(seed)
self.y = self.rng.uniform(size=(1,))
# Two graphs built from different seeds, each compiled to its own function.
g1 = Graph(seed=123)
f1 = function([], g1.y)
g2 = Graph(seed=987)
f2 = function([], g2.y)
# Copy the value of each state shared variable (element 0 of every
# state_updates pair) from g1 into the matching variable of g2.
for (su1, su2) in zip(g1.rng.state_updates, g2.rng.state_updates):
su2[0].set_value(su1[0].get_value())
# After the transfer both functions are expected to return the same draw.
numpy.testing.assert_array_almost_equal(f1(), f2(), decimal=6)
if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论