提交 9b02fa85 authored 作者: Cesar Laurent's avatar Cesar Laurent

Added tests for scant_checkpoint.

上级 e2f11c78
import numpy import numpy
import time import unittest
import theano import theano
import theano.tensor as T import theano.tensor as T
def example1(checkpoint=True): class TestScanCheckpoint(unittest.TestCase):
k = T.iscalar("k") def setUp(self):
A = T.vector("A") k = T.iscalar("k")
A = T.vector("A")
# Symbolic description of the result self.k = k
if checkpoint: self.A = A
result, updates = theano.scan_with_checkpoints( result, _ = theano.scan(
fn=lambda prior_result, A: prior_result * A,
outputs_info=T.ones_like(A),
non_sequences=A,
n_steps=k)
result_check, _ = theano.scan_with_checkpoints(
fn=lambda prior_result, A: prior_result * A, fn=lambda prior_result, A: prior_result * A,
outputs_info=T.ones_like(A), outputs_info=T.ones_like(A),
non_sequences=A, non_sequences=A,
n_steps=k, n_steps=k,
save_every_N=20) save_every_N=50)
else: self.result = result[-1]
self.result_check = result_check[-1]
self.grad_A = T.grad(self.result.sum(), self.A)
self.grad_A_check = T.grad(self.result_check.sum(), self.A)
def test_forward_pass(self):
"""Test forward computation of A**k."""
f = theano.function(inputs=[self.A, self.k],
outputs=[self.result, self.result_check])
out, out_check = f(range(10), 100)
assert numpy.allclose(out, out_check)
def test_backward_pass(self):
"""Test gradient computation of A**k."""
f = theano.function(inputs=[self.A, self.k],
outputs=[self.grad_A, self.grad_A_check])
out, out_check = f(range(10), 100)
assert numpy.allclose(out, out_check)
def test_memory(self):
"""Test that scan_checkpoint reduces memory usage."""
k = T.iscalar("k")
A = T.vector("A")
result, updates = theano.scan(fn=lambda prior_result, A: prior_result * A, result, updates = theano.scan(fn=lambda prior_result, A: prior_result * A,
outputs_info=T.ones_like(A), outputs_info=T.ones_like(A),
non_sequences=A, non_sequences=A,
n_steps=k) n_steps=k)
result_check, updates_check = theano.scan_with_checkpoints(
# We only care about A**k, but scan has provided us with A**1 through A**k. fn=lambda prior_result, A: prior_result * A,
# Discard the values that we don't care about. Scan is smart enough to outputs_info=T.ones_like(A),
# notice this and not waste memory saving them. non_sequences=A,
result = result[-1] n_steps=k,
save_every_N=10000)
# compiled function that returns A**k result = result[-1]
start_compile = time.time() result_check = result_check[-1]
power = theano.function(inputs=[A, k], outputs=result, updates=updates) grad_A = T.grad(result.sum(), A)
time_compile = time.time() - start_compile grad_A_check = T.grad(result_check.sum(), A)
f = theano.function(inputs=[A, k], outputs=grad_A,
start_exec = time.time() updates=updates + updates_check)
out = power(range(10), 100) f_check = theano.function(inputs=[A, k], outputs=grad_A_check,
time_exec = time.time() - start_exec updates=updates + updates_check)
data = numpy.ones(10000, dtype=theano.config.floatX)
if checkpoint: # Check that it works with the checkpoints
print("Example 1 with checkpoints") f_check(data, 1000000)
else: # Check that the basic scan fails in that case
print("Example 1 without checkpoints") self.assertRaises(MemoryError, f, data, 1000000)
print("Compile time:", time_compile)
print("Exec time:", time_exec)
print("Output:", out)
def example2(checkpoint=True):
up_to = T.iscalar("up_to")
# define a named function, rather than using lambda
def accumulate_by_adding(arange_val, sum_to_date):
return sum_to_date + arange_val
seq = T.arange(up_to)
outputs_info = T.as_tensor_variable(numpy.asarray(0, seq.dtype))
if checkpoint:
scan_result, scan_updates = theano.scan_with_checkpoints(
fn=accumulate_by_adding,
outputs_info=outputs_info,
sequences=seq,
save_every_N=10)
else:
scan_result, scan_updates = theano.scan(fn=accumulate_by_adding,
outputs_info=outputs_info,
sequences=seq)
start_compile = time.time()
triangular_sequence = theano.function(inputs=[up_to], outputs=scan_result)
time_compile = time.time() - start_compile
start_exec = time.time()
out = triangular_sequence(100)[-1]
time_exec = time.time() - start_exec
if checkpoint:
print("Example 2 with checkpoints")
else:
print("Example 2 without checkpoints")
print("Compile time:", time_compile)
print("Exec time:", time_exec)
print("Output:", out)
def test_scan_checkpoint():
example1(False)
example1(True)
print("----")
example2(False)
example2(True)
print("----")
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论