提交 9b02fa85 authored 作者: Cesar Laurent's avatar Cesar Laurent

Added tests for scant_checkpoint.

上级 e2f11c78
import numpy
import time
import unittest
import theano
import theano.tensor as T
def example1(checkpoint=True):
class TestScanCheckpoint(unittest.TestCase):
k = T.iscalar("k")
A = T.vector("A")
# Symbolic description of the result
if checkpoint:
result, updates = theano.scan_with_checkpoints(
def setUp(self):
k = T.iscalar("k")
A = T.vector("A")
self.k = k
self.A = A
result, _ = theano.scan(
fn=lambda prior_result, A: prior_result * A,
outputs_info=T.ones_like(A),
non_sequences=A,
n_steps=k)
result_check, _ = theano.scan_with_checkpoints(
fn=lambda prior_result, A: prior_result * A,
outputs_info=T.ones_like(A),
non_sequences=A,
n_steps=k,
save_every_N=20)
else:
save_every_N=50)
self.result = result[-1]
self.result_check = result_check[-1]
self.grad_A = T.grad(self.result.sum(), self.A)
self.grad_A_check = T.grad(self.result_check.sum(), self.A)
def test_forward_pass(self):
"""Test forward computation of A**k."""
f = theano.function(inputs=[self.A, self.k],
outputs=[self.result, self.result_check])
out, out_check = f(range(10), 100)
assert numpy.allclose(out, out_check)
def test_backward_pass(self):
"""Test gradient computation of A**k."""
f = theano.function(inputs=[self.A, self.k],
outputs=[self.grad_A, self.grad_A_check])
out, out_check = f(range(10), 100)
assert numpy.allclose(out, out_check)
def test_memory(self):
"""Test that scan_checkpoint reduces memory usage."""
k = T.iscalar("k")
A = T.vector("A")
result, updates = theano.scan(fn=lambda prior_result, A: prior_result * A,
outputs_info=T.ones_like(A),
non_sequences=A,
n_steps=k)
# We only care about A**k, but scan has provided us with A**1 through A**k.
# Discard the values that we don't care about. Scan is smart enough to
# notice this and not waste memory saving them.
result = result[-1]
# compiled function that returns A**k
start_compile = time.time()
power = theano.function(inputs=[A, k], outputs=result, updates=updates)
time_compile = time.time() - start_compile
start_exec = time.time()
out = power(range(10), 100)
time_exec = time.time() - start_exec
if checkpoint:
print("Example 1 with checkpoints")
else:
print("Example 1 without checkpoints")
print("Compile time:", time_compile)
print("Exec time:", time_exec)
print("Output:", out)
def example2(checkpoint=True):
up_to = T.iscalar("up_to")
# define a named function, rather than using lambda
def accumulate_by_adding(arange_val, sum_to_date):
return sum_to_date + arange_val
seq = T.arange(up_to)
outputs_info = T.as_tensor_variable(numpy.asarray(0, seq.dtype))
if checkpoint:
scan_result, scan_updates = theano.scan_with_checkpoints(
fn=accumulate_by_adding,
outputs_info=outputs_info,
sequences=seq,
save_every_N=10)
else:
scan_result, scan_updates = theano.scan(fn=accumulate_by_adding,
outputs_info=outputs_info,
sequences=seq)
start_compile = time.time()
triangular_sequence = theano.function(inputs=[up_to], outputs=scan_result)
time_compile = time.time() - start_compile
start_exec = time.time()
out = triangular_sequence(100)[-1]
time_exec = time.time() - start_exec
if checkpoint:
print("Example 2 with checkpoints")
else:
print("Example 2 without checkpoints")
print("Compile time:", time_compile)
print("Exec time:", time_exec)
print("Output:", out)
def test_scan_checkpoint():
example1(False)
example1(True)
print("----")
example2(False)
example2(True)
print("----")
result_check, updates_check = theano.scan_with_checkpoints(
fn=lambda prior_result, A: prior_result * A,
outputs_info=T.ones_like(A),
non_sequences=A,
n_steps=k,
save_every_N=10000)
result = result[-1]
result_check = result_check[-1]
grad_A = T.grad(result.sum(), A)
grad_A_check = T.grad(result_check.sum(), A)
f = theano.function(inputs=[A, k], outputs=grad_A,
updates=updates + updates_check)
f_check = theano.function(inputs=[A, k], outputs=grad_A_check,
updates=updates + updates_check)
data = numpy.ones(10000, dtype=theano.config.floatX)
# Check that it works with the checkpoints
f_check(data, 1000000)
# Check that the basic scan fails in that case
self.assertRaises(MemoryError, f, data, 1000000)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论