提交 54c0b49b authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add test that the storage is really cleared.

上级 b76520e7
...@@ -18,6 +18,7 @@ from theano.compile.pfunc import rebuild_collect_shared ...@@ -18,6 +18,7 @@ from theano.compile.pfunc import rebuild_collect_shared
from theano.gof.python25 import any from theano.gof.python25 import any
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
import theano.scalar.sharedvar import theano.scalar.sharedvar
from theano.scan_module.scan_op import Scan
from theano.gof.python25 import OrderedDict from theano.gof.python25 import OrderedDict
from theano.compat import PY3 from theano.compat import PY3
...@@ -261,6 +262,45 @@ class T_Scan(unittest.TestCase): ...@@ -261,6 +262,45 @@ class T_Scan(unittest.TestCase):
theano_values = my_f(state, steps) theano_values = my_f(state, steps)
utt.assert_allclose(numpy_values, theano_values) utt.assert_allclose(numpy_values, theano_values)
# Test that the inner input_storage and output_storage are
# properly cleared
def test_inner_storage_leak(self):
def f_pow2(x_tm1):
return 2 * x_tm1
state = theano.tensor.scalar('state')
n_steps = theano.tensor.iscalar('nsteps')
output, updates = theano.scan(f_pow2,
[],
state,
[],
n_steps=n_steps)
f = theano.function([state, n_steps],
output,
updates=updates,
allow_input_downcast=True)
scan_node = [node for node in f.maker.fgraph.toposort()
if isinstance(node.op, Scan)]
assert len(scan_node) == 1
scan_node = scan_node[0]
# Make sure they start out as None
assert all(i.value is None for i in scan_node.op.fn.input_storage)
assert all(o.value is None for o in scan_node.op.fn.output_storage)
rng = numpy.random.RandomState(utt.fetch_seed())
state = rng.uniform()
steps = 5
f(state, steps)
# And that they stay that way
assert all(i.value is None for i in scan_node.op.fn.input_storage)
assert all(o.value is None for o in scan_node.op.fn.output_storage)
# generator network, only one output , type scalar ; no sequence or # generator network, only one output , type scalar ; no sequence or
# non sequence arguments # non sequence arguments
def test_generator_one_output_scalar(self): def test_generator_one_output_scalar(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论