提交 72abdcb7 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

I have made scan always clone constants to remove and reference of an old

env. This used to make pickling of scan impossible. As I've broken the pickling mechanism several times before, I've added a test that checks if scan is still picklable. I'm not sure is the best way of doing it though ..
上级 5ac28381
...@@ -696,6 +696,11 @@ def reconstruct_graph(inputs, outputs, tag = None): ...@@ -696,6 +696,11 @@ def reconstruct_graph(inputs, outputs, tag = None):
givens = {} givens = {}
for nw_x, x in zip(nw_inputs, inputs): for nw_x, x in zip(nw_inputs, inputs):
givens[x] = nw_x givens[x] = nw_x
allinputs = theano.gof.graph.inputs(outputs)
for inp in allinputs:
if isinstance(inp, theano.Constant):
givens[inp] = inp.clone()
nw_outputs = clone( outputs, replace=givens) nw_outputs = clone( outputs, replace=givens)
return (nw_inputs, nw_outputs) return (nw_inputs, nw_outputs)
......
import time import time
import unittest import unittest
import cPickle
import numpy import numpy
import theano import theano
...@@ -187,6 +188,33 @@ class T_Scan(unittest.TestCase): ...@@ -187,6 +188,33 @@ class T_Scan(unittest.TestCase):
def setUp(self): def setUp(self):
utt.seed_rng() utt.seed_rng()
# generator network, only one output , type scalar ; no sequence or
# non sequence arguments
def test_pickling(self):
def f_pow2(x_tm1):
return 2*x_tm1
state = theano.tensor.scalar('state')
n_steps = theano.tensor.iscalar('nsteps')
output, updates = theano.scan(f_pow2, [],state, [],n_steps = n_steps, truncate_gradient
= -1, go_backwards = False)
_my_f = theano.function([state,n_steps], output, updates = updates,
allow_input_downcast = True)
### TESTING PICKLE-ing this function
cPickle.dump(_my_f, open('tmp_scan_test_pickle.pkl','wb'),-1)
my_f = cPickle.load(open('tmp_scan_test_pickle.pkl'))
rng = numpy.random.RandomState(utt.fetch_seed())
state = rng.uniform()
steps = 5
numpy_values = numpy.array([ state*(2**(k+1)) for k
in xrange(steps) ])
theano_values = my_f(state,steps)
assert numpy.allclose(numpy_values,theano_values)
# generator network, only one output , type scalar ; no sequence or # generator network, only one output , type scalar ; no sequence or
# non sequence arguments # non sequence arguments
def test_generator_one_output_scalar(self): def test_generator_one_output_scalar(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论