Commit ad55f02a, authored by Guillaume Desjardins

Recipe for copying random state from one theano graph to another.

Parent 2839b9db
@@ -376,6 +376,71 @@ For example:
>>> rv_u.rng.set_value(rng, borrow=True)
>>> v2 = f() # v2 != v1

Copying Random State Between Theano Graphs
------------------------------------------

In some use cases, a user might want to transfer the "state" of all random
number generators associated with a given Theano graph (e.g. g1, with compiled
function f1 below) to a second graph (e.g. g2, with function f2). This might
arise, for example, if you are trying to initialize the state of a model from
the parameters of a pickled version of a previous model. For
:class:`theano.tensor.shared_randomstreams.RandomStreams` and
:class:`theano.sandbox.rng_mrg.MRG_RandomStreams`,
this can be achieved by copying elements of the `state_updates` attribute.

Each time a random variable is drawn from a RandomStreams object, a tuple is
added to the `state_updates` list. The first element is a shared variable
that holds the state of the random number generator associated with this
*particular* variable, while the second is the Theano graph
corresponding to the random number generation process (i.e. RandomFunction{uniform}.0).

An example of how "random states" can be transferred from one Theano function
to another is shown below.

.. code-block:: python

    import theano
    import numpy
    import theano.tensor as T
    from theano.sandbox.rng_mrg import MRG_RandomStreams
    from theano.tensor.shared_randomstreams import RandomStreams

    class Graph():
        def __init__(self, seed=123):
            self.rng = RandomStreams(seed)
            self.y = self.rng.uniform(size=(1,))

    g1 = Graph(seed=123)
    f1 = theano.function([], g1.y)

    g2 = Graph(seed=987)
    f2 = theano.function([], g2.y)

    print 'By default, the two functions are out of sync.'
    print 'f1() returns ', f1()
    print 'f2() returns ', f2()

    def copy_random_state(g1, g2):
        if isinstance(g1.rng, MRG_RandomStreams):
            g2.rng.rstate = g1.rng.rstate
        for (su1, su2) in zip(g1.rng.state_updates, g2.rng.state_updates):
            su2[0].set_value(su1[0].get_value())

    print 'We now copy the state of the theano random number generators.'
    copy_random_state(g1, g2)
    print 'f1() returns ', f1()
    print 'f2() returns ', f2()

This gives the following output:

.. code-block:: shell

    By default, the two functions are out of sync.
    f1() returns  [ 0.72803009]
    f2() returns  [ 0.55056769]
    We now copy the state of the theano random number generators.
    f1() returns  [ 0.59044123]
    f2() returns  [ 0.59044123]

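The same copy pattern can be tried without Theano installed. The sketch below is only an analogy: it uses Python's standard ``random.Random`` in place of Theano's shared generator state. The ``Graph`` class, its ``streams`` list and the ``n_streams`` argument are hypothetical names introduced for illustration and are not part of any Theano API. As with ``state_updates`` above, the generators are paired by position, so both graphs are assumed to create their random variables in the same order.

```python
import random


class Graph(object):
    """Hypothetical stand-in for a graph that owns several generators."""

    def __init__(self, seed=123, n_streams=2):
        # One generator per random variable, loosely mirroring one
        # state_updates entry per draw (assumption: illustration only).
        self.streams = [random.Random(seed + i) for i in range(n_streams)]

    def sample(self):
        # Draw one value from every generator, advancing each state.
        return [s.random() for s in self.streams]


def copy_random_state(g1, g2):
    # Pair generators by position and overwrite g2's state with g1's.
    for s1, s2 in zip(g1.streams, g2.streams):
        s2.setstate(s1.getstate())


g1 = Graph(seed=123)
g2 = Graph(seed=987)
out_of_sync = g1.sample() != g2.sample()  # different seeds, different draws

copy_random_state(g1, g2)
in_sync = g1.sample() == g2.sample()  # identical state, identical draws

print(out_of_sync, in_sync)
```

Note that the state is copied after ``g1`` has already been sampled once, so ``g2`` resumes from ``g1``'s current position rather than from its seed, which matches the behaviour of the Theano example.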
Other Random Distributions
---------------------------