提交 f0ff84d9 authored 作者: Cesar Laurent's avatar Cesar Laurent

pep8 and changed memory_test.

上级 b39ce3fd
...@@ -18,11 +18,11 @@ def scan_with_checkpoints(fn, sequences=[], outputs_info=None, ...@@ -18,11 +18,11 @@ def scan_with_checkpoints(fn, sequences=[], outputs_info=None,
Current assumptions : Current assumptions :
- Every sequence has the same length. - Every sequence has the same length.
- If n_steps is specified, it has the same value as the length of any - If n_steps is specified, it has the same value as the length of any
sequence. sequence.
- The value of "save_every_N" divides the number of steps the Scan will - The value of "save_every_N" divides the number of steps the Scan will
run without remainder. run without remainder.
- Only singly-recurrent and non-recurrent outputs are used. - Only singly-recurrent and non-recurrent outputs are used.
No multiple recurrences. No multiple recurrences.
- Only the last timestep of any output will ever be used. - Only the last timestep of any output will ever be used.
Parameters Parameters
......
...@@ -44,6 +44,8 @@ class TestScanCheckpoint(unittest.TestCase): ...@@ -44,6 +44,8 @@ class TestScanCheckpoint(unittest.TestCase):
out, out_check = f(range(10), 100) out, out_check = f(range(10), 100)
assert numpy.allclose(out, out_check) assert numpy.allclose(out, out_check)
@unittest.skipUnless(theano.gpuarray.type._context_reg[None],
'Requires gpuarray backend.')
def test_memory(self): def test_memory(self):
"""Test that scan_checkpoint reduces memory usage.""" """Test that scan_checkpoint reduces memory usage."""
k = T.iscalar("k") k = T.iscalar("k")
...@@ -66,7 +68,8 @@ class TestScanCheckpoint(unittest.TestCase): ...@@ -66,7 +68,8 @@ class TestScanCheckpoint(unittest.TestCase):
updates=updates + updates_check) updates=updates + updates_check)
f_check = theano.function(inputs=[A, k], outputs=grad_A_check, f_check = theano.function(inputs=[A, k], outputs=grad_A_check,
updates=updates + updates_check) updates=updates + updates_check)
data = numpy.ones(10000, dtype=theano.config.floatX) free_gmem = theano.gpuarray.type._context_reg[None].free_gmem
data = numpy.ones(free_gmem / 40., dtype=numpy.float32)
# Check that it works with the checkpoints # Check that it works with the checkpoints
f_check(data, 1000000) f_check(data, 1000000)
# Check that the basic scan fails in that case # Check that the basic scan fails in that case
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论