提交 4969ddfe authored 作者: Cesar Laurent's avatar Cesar Laurent

Finalized tests and docstring.

上级 f0ff84d9
......@@ -63,6 +63,21 @@ def scan_with_checkpoints(fn, sequences=[], outputs_info=None,
the computations of scan (ie they will have to be recomputed
during the gradient computation).
Returns
-------
tuple
Tuple of the form (outputs, updates); ``outputs`` is either a
Theano variable or a list of Theano variables representing the
outputs of ``scan`` (in the same order as in ``outputs_info``).
``updates`` is a subclass of dictionary specifying the update rules for
all shared variables used in scan.
This dictionary should be passed to ``theano.function`` when you compile
your function. The change compared to a normal dictionary is that we
validate that keys are SharedVariable and addition of those dictionary
are validated to be consistent.
Note that only the last time step of ``outputs`` can be used with this
type of scan.
See Also
--------
scan : Looping in Theano.
......@@ -76,6 +91,11 @@ def scan_with_checkpoints(fn, sequences=[], outputs_info=None,
if not isinstance(non_sequences, list):
non_sequences = [non_sequences]
# Check that outputs_info has no taps:
for element in outputs_info:
if isinstance(element, dict) and 'taps' in element:
raise RuntimeError("scan_with_checkpoints doesn't work with taps.")
# Determine how many steps the original scan would run
if n_steps is None:
n_steps = sequences[0].shape[0]
......
......@@ -6,25 +6,25 @@ import unittest
import theano
import theano.tensor as T
from pygpu.gpuarray import GpuArrayException
class TestScanCheckpoint(unittest.TestCase):
def setUp(self):
k = T.iscalar("k")
A = T.vector("A")
self.k = k
self.A = A
self.k = T.iscalar("k")
self.A = T.vector("A")
result, _ = theano.scan(
fn=lambda prior_result, A: prior_result * A,
outputs_info=T.ones_like(A),
non_sequences=A,
n_steps=k)
outputs_info=T.ones_like(self.A),
non_sequences=self.A,
n_steps=self.k)
result_check, _ = theano.scan_with_checkpoints(
fn=lambda prior_result, A: prior_result * A,
outputs_info=T.ones_like(A),
non_sequences=A,
n_steps=k,
save_every_N=50)
outputs_info=T.ones_like(self.A),
non_sequences=self.A,
n_steps=self.k,
save_every_N=100)
self.result = result[-1]
self.result_check = result_check[-1]
self.grad_A = T.grad(self.result.sum(), self.A)
......@@ -44,33 +44,22 @@ class TestScanCheckpoint(unittest.TestCase):
out, out_check = f(range(10), 100)
assert numpy.allclose(out, out_check)
@unittest.skipUnless(theano.gpuarray.type._context_reg[None],
'Requires gpuarray backend.')
def test_memory(self):
"""Test that scan_checkpoint reduces memory usage."""
k = T.iscalar("k")
A = T.vector("A")
result, updates = theano.scan(fn=lambda prior_result, A: prior_result * A,
outputs_info=T.ones_like(A),
non_sequences=A,
n_steps=k)
result_check, updates_check = theano.scan_with_checkpoints(
fn=lambda prior_result, A: prior_result * A,
outputs_info=T.ones_like(A),
non_sequences=A,
n_steps=k,
save_every_N=10000)
result = result[-1]
result_check = result_check[-1]
grad_A = T.grad(result.sum(), A)
grad_A_check = T.grad(result_check.sum(), A)
f = theano.function(inputs=[A, k], outputs=grad_A,
updates=updates + updates_check)
f_check = theano.function(inputs=[A, k], outputs=grad_A_check,
updates=updates + updates_check)
if None not in theano.gpuarray.type.list_contexts():
return unittest.SkipTest('Requires gpuarray backend.')
f = theano.function(inputs=[self.A, self.k],
outputs=self.grad_A)
f_check = theano.function(inputs=[self.A, self.k],
outputs=self.grad_A_check)
free_gmem = theano.gpuarray.type._context_reg[None].free_gmem
data = numpy.ones(free_gmem / 40., dtype=numpy.float32)
data = numpy.ones(free_gmem / 3000, dtype=numpy.float32)
# Check that it works with the checkpoints
f_check(data, 1000000)
f_check(data, 1000)
# Check that the basic scan fails in that case
self.assertRaises(MemoryError, f, data, 1000000)
self.assertRaises(GpuArrayException, f, data, 1000)
def test_taps_error(self):
"""Test that an error rises if we use taps in outputs_info."""
self.assertRaises(RuntimeError, theano.scan_with_checkpoints,
lambda: None, [], {'initial': self.A, 'taps': [-2]})
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论