提交 e2f11c78 authored 作者: Cesar Laurent's avatar Cesar Laurent

PEP8 and adressed comments.

上级 bac0fde1
import theano
import theano.tensor as T
def scan_with_checkpoints(fn, sequences=[], outputs_info=None,
non_sequences=[], name="checkpointscan_fn",
n_steps=None, save_every_N=10):
"""
Current assumptions :
Current assumptions :
- Every sequence has the same length
- If n_steps is specified, it has the same value as the length of any sequence
- The value of "save_every_N" divides the number of steps the Scan will
......@@ -14,8 +13,8 @@ def scan_with_checkpoints(fn, sequences=[], outputs_info=None,
- Only singly-recurrent and non-recurrent outputs are used.
No multiple recurrences.
- Only the last timestep of any output will ever be used.
"""
# Standardize the format of input arguments
if not isinstance(sequences, list):
sequences = [sequences]
......@@ -23,7 +22,7 @@ def scan_with_checkpoints(fn, sequences=[], outputs_info=None,
outputs_info = [outputs_info]
if not isinstance(non_sequences, list):
non_sequences = [non_sequences]
# Determine how many steps the original scan would run
if n_steps is None:
n_steps = sequences[0].shape[0]
......@@ -31,14 +30,14 @@ def scan_with_checkpoints(fn, sequences=[], outputs_info=None,
n_steps = n_steps
# Compute the number of steps of the inner and of the outer scan
o_n_steps = n_steps / save_every_N
o_n_steps = theano.tensor.cast(n_steps / save_every_N, 'int64')
i_n_steps = save_every_N
# Establish the input variables of the outer scan
o_sequences = [s.reshape([s.shape[0] / save_every_N, save_every_N] +
[s.shape[i] for i in range(1, s.ndim)], s.ndim + 1) for s in sequences]
[s.shape[i] for i in range(1, s.ndim)],
s.ndim + 1) for s in sequences]
new_nitsots = [i for i in outputs_info if i is None]
new_sitsots = [i for i in outputs_info if i is not None]
o_nonsequences = non_sequences + [i_n_steps]
def outer_step(*args):
......@@ -47,14 +46,12 @@ def scan_with_checkpoints(fn, sequences=[], outputs_info=None,
i_sequences = list(args[:len(o_sequences)])
i_prev_outputs = list(args[len(o_sequences):-len(o_nonsequences)])
i_non_sequences = list(args[-len(o_nonsequences):])
# Assemble the correct outputs_info list for the inner_scan
i_outputs_info = []
i_outputs_infos = i_prev_outputs + [None, ] * len(new_nitsots)
# Call the user-provided function with the proper arguments
results, updates = theano.scan(fn=fn,
sequences=i_sequences,
outputs_info=i_prev_outputs + [None,] * len(new_nitsots),
outputs_info=i_outputs_infos,
non_sequences=i_non_sequences[:-1],
name=name + "_inner",
n_steps=i_non_sequences[-1])
......@@ -65,18 +62,14 @@ def scan_with_checkpoints(fn, sequences=[], outputs_info=None,
if not isinstance(results, list):
return results[-1], updates
else:
return [r[-1] for r in results], updates
return [r[-1] for r in results], updates
results, updates = theano.scan(fn=outer_step,
sequences=o_sequences,
outputs_info=outputs_info,
non_sequences=o_nonsequences,
name=name + "_outer",
n_steps=o_n_steps, allow_gc=True)
sequences=o_sequences,
outputs_info=outputs_info,
non_sequences=o_nonsequences,
name=name + "_outer",
n_steps=o_n_steps, allow_gc=True)
# Keep only the last timestep of every output but keep all the updates
return results, updates # TODO is it a bug?
if not isinstance(results, list):
return results[-1:], updates
else:
return [r[-1:] for r in results], updates
return results, updates
......@@ -13,11 +13,11 @@ def example1(checkpoint=True):
# Symbolic description of the result
if checkpoint:
result, updates = theano.scan_with_checkpoints(
fn=lambda prior_result, A: prior_result * A,
outputs_info=T.ones_like(A),
non_sequences=A,
n_steps=k,
save_every_N=20)
fn=lambda prior_result, A: prior_result * A,
outputs_info=T.ones_like(A),
non_sequences=A,
n_steps=k,
save_every_N=20)
else:
result, updates = theano.scan(fn=lambda prior_result, A: prior_result * A,
outputs_info=T.ones_like(A),
......@@ -31,13 +31,13 @@ def example1(checkpoint=True):
# compiled function that returns A**k
start_compile = time.time()
power = theano.function(inputs=[A,k], outputs=result, updates=updates)
power = theano.function(inputs=[A, k], outputs=result, updates=updates)
time_compile = time.time() - start_compile
start_exec = time.time()
out = power(range(10), 100)
time_exec = time.time() - start_exec
if checkpoint:
print("Example 1 with checkpoints")
else:
......@@ -45,7 +45,7 @@ def example1(checkpoint=True):
print("Compile time:", time_compile)
print("Exec time:", time_exec)
print("Output:", out)
def example2(checkpoint=True):
......@@ -57,26 +57,26 @@ def example2(checkpoint=True):
seq = T.arange(up_to)
outputs_info = T.as_tensor_variable(numpy.asarray(0, seq.dtype))
if checkpoint:
scan_result, scan_updates = theano.scan_with_checkpoints(
fn=accumulate_by_adding,
outputs_info=outputs_info,
sequences=seq,
save_every_N=10)
fn=accumulate_by_adding,
outputs_info=outputs_info,
sequences=seq,
save_every_N=10)
else:
scan_result, scan_updates = theano.scan(fn=accumulate_by_adding,
outputs_info=outputs_info,
sequences=seq)
start_compile = time.time()
start_compile = time.time()
triangular_sequence = theano.function(inputs=[up_to], outputs=scan_result)
time_compile = time.time() - start_compile
start_exec = time.time()
out = triangular_sequence(100)[-1]
time_exec = time.time() - start_exec
if checkpoint:
print("Example 2 with checkpoints")
else:
......@@ -92,4 +92,4 @@ def test_scan_checkpoint():
print("----")
example2(False)
example2(True)
print("----")
\ No newline at end of file
print("----")
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论