提交 e2f11c78 作者: Cesar Laurent

PEP8 and addressed comments.

上级 bac0fde1
import theano import theano
import theano.tensor as T
def scan_with_checkpoints(fn, sequences=[], outputs_info=None, def scan_with_checkpoints(fn, sequences=[], outputs_info=None,
non_sequences=[], name="checkpointscan_fn", non_sequences=[], name="checkpointscan_fn",
n_steps=None, save_every_N=10): n_steps=None, save_every_N=10):
""" """
Current assumptions : Current assumptions :
- Every sequence has the same length - Every sequence has the same length
- If n_steps is specified, it has the same value as the length of any sequence - If n_steps is specified, it has the same value as the length of any sequence
- The value of "save_every_N" divides the number of steps the Scan will - The value of "save_every_N" divides the number of steps the Scan will
...@@ -14,8 +13,8 @@ def scan_with_checkpoints(fn, sequences=[], outputs_info=None, ...@@ -14,8 +13,8 @@ def scan_with_checkpoints(fn, sequences=[], outputs_info=None,
- Only singly-recurrent and non-recurrent outputs are used. - Only singly-recurrent and non-recurrent outputs are used.
No multiple recurrences. No multiple recurrences.
- Only the last timestep of any output will ever be used. - Only the last timestep of any output will ever be used.
""" """
# Standardize the format of input arguments # Standardize the format of input arguments
if not isinstance(sequences, list): if not isinstance(sequences, list):
sequences = [sequences] sequences = [sequences]
...@@ -23,7 +22,7 @@ def scan_with_checkpoints(fn, sequences=[], outputs_info=None, ...@@ -23,7 +22,7 @@ def scan_with_checkpoints(fn, sequences=[], outputs_info=None,
outputs_info = [outputs_info] outputs_info = [outputs_info]
if not isinstance(non_sequences, list): if not isinstance(non_sequences, list):
non_sequences = [non_sequences] non_sequences = [non_sequences]
# Determine how many steps the original scan would run # Determine how many steps the original scan would run
if n_steps is None: if n_steps is None:
n_steps = sequences[0].shape[0] n_steps = sequences[0].shape[0]
...@@ -31,14 +30,14 @@ def scan_with_checkpoints(fn, sequences=[], outputs_info=None, ...@@ -31,14 +30,14 @@ def scan_with_checkpoints(fn, sequences=[], outputs_info=None,
n_steps = n_steps n_steps = n_steps
# Compute the number of steps of the inner and of the outer scan # Compute the number of steps of the inner and of the outer scan
o_n_steps = n_steps / save_every_N o_n_steps = theano.tensor.cast(n_steps / save_every_N, 'int64')
i_n_steps = save_every_N i_n_steps = save_every_N
# Establish the input variables of the outer scan # Establish the input variables of the outer scan
o_sequences = [s.reshape([s.shape[0] / save_every_N, save_every_N] + o_sequences = [s.reshape([s.shape[0] / save_every_N, save_every_N] +
[s.shape[i] for i in range(1, s.ndim)], s.ndim + 1) for s in sequences] [s.shape[i] for i in range(1, s.ndim)],
s.ndim + 1) for s in sequences]
new_nitsots = [i for i in outputs_info if i is None] new_nitsots = [i for i in outputs_info if i is None]
new_sitsots = [i for i in outputs_info if i is not None]
o_nonsequences = non_sequences + [i_n_steps] o_nonsequences = non_sequences + [i_n_steps]
def outer_step(*args): def outer_step(*args):
...@@ -47,14 +46,12 @@ def scan_with_checkpoints(fn, sequences=[], outputs_info=None, ...@@ -47,14 +46,12 @@ def scan_with_checkpoints(fn, sequences=[], outputs_info=None,
i_sequences = list(args[:len(o_sequences)]) i_sequences = list(args[:len(o_sequences)])
i_prev_outputs = list(args[len(o_sequences):-len(o_nonsequences)]) i_prev_outputs = list(args[len(o_sequences):-len(o_nonsequences)])
i_non_sequences = list(args[-len(o_nonsequences):]) i_non_sequences = list(args[-len(o_nonsequences):])
i_outputs_infos = i_prev_outputs + [None, ] * len(new_nitsots)
# Assemble the correct outputs_info list for the inner_scan
i_outputs_info = []
# Call the user-provided function with the proper arguments # Call the user-provided function with the proper arguments
results, updates = theano.scan(fn=fn, results, updates = theano.scan(fn=fn,
sequences=i_sequences, sequences=i_sequences,
outputs_info=i_prev_outputs + [None,] * len(new_nitsots), outputs_info=i_outputs_infos,
non_sequences=i_non_sequences[:-1], non_sequences=i_non_sequences[:-1],
name=name + "_inner", name=name + "_inner",
n_steps=i_non_sequences[-1]) n_steps=i_non_sequences[-1])
...@@ -65,18 +62,14 @@ def scan_with_checkpoints(fn, sequences=[], outputs_info=None, ...@@ -65,18 +62,14 @@ def scan_with_checkpoints(fn, sequences=[], outputs_info=None,
if not isinstance(results, list): if not isinstance(results, list):
return results[-1], updates return results[-1], updates
else: else:
return [r[-1] for r in results], updates return [r[-1] for r in results], updates
results, updates = theano.scan(fn=outer_step, results, updates = theano.scan(fn=outer_step,
sequences=o_sequences, sequences=o_sequences,
outputs_info=outputs_info, outputs_info=outputs_info,
non_sequences=o_nonsequences, non_sequences=o_nonsequences,
name=name + "_outer", name=name + "_outer",
n_steps=o_n_steps, allow_gc=True) n_steps=o_n_steps, allow_gc=True)
# Keep only the last timestep of every output but keep all the updates # Keep only the last timestep of every output but keep all the updates
return results, updates # TODO is it a bug? return results, updates
if not isinstance(results, list):
return results[-1:], updates
else:
return [r[-1:] for r in results], updates
...@@ -13,11 +13,11 @@ def example1(checkpoint=True): ...@@ -13,11 +13,11 @@ def example1(checkpoint=True):
# Symbolic description of the result # Symbolic description of the result
if checkpoint: if checkpoint:
result, updates = theano.scan_with_checkpoints( result, updates = theano.scan_with_checkpoints(
fn=lambda prior_result, A: prior_result * A, fn=lambda prior_result, A: prior_result * A,
outputs_info=T.ones_like(A), outputs_info=T.ones_like(A),
non_sequences=A, non_sequences=A,
n_steps=k, n_steps=k,
save_every_N=20) save_every_N=20)
else: else:
result, updates = theano.scan(fn=lambda prior_result, A: prior_result * A, result, updates = theano.scan(fn=lambda prior_result, A: prior_result * A,
outputs_info=T.ones_like(A), outputs_info=T.ones_like(A),
...@@ -31,13 +31,13 @@ def example1(checkpoint=True): ...@@ -31,13 +31,13 @@ def example1(checkpoint=True):
# compiled function that returns A**k # compiled function that returns A**k
start_compile = time.time() start_compile = time.time()
power = theano.function(inputs=[A,k], outputs=result, updates=updates) power = theano.function(inputs=[A, k], outputs=result, updates=updates)
time_compile = time.time() - start_compile time_compile = time.time() - start_compile
start_exec = time.time() start_exec = time.time()
out = power(range(10), 100) out = power(range(10), 100)
time_exec = time.time() - start_exec time_exec = time.time() - start_exec
if checkpoint: if checkpoint:
print("Example 1 with checkpoints") print("Example 1 with checkpoints")
else: else:
...@@ -45,7 +45,7 @@ def example1(checkpoint=True): ...@@ -45,7 +45,7 @@ def example1(checkpoint=True):
print("Compile time:", time_compile) print("Compile time:", time_compile)
print("Exec time:", time_exec) print("Exec time:", time_exec)
print("Output:", out) print("Output:", out)
def example2(checkpoint=True): def example2(checkpoint=True):
...@@ -57,26 +57,26 @@ def example2(checkpoint=True): ...@@ -57,26 +57,26 @@ def example2(checkpoint=True):
seq = T.arange(up_to) seq = T.arange(up_to)
outputs_info = T.as_tensor_variable(numpy.asarray(0, seq.dtype)) outputs_info = T.as_tensor_variable(numpy.asarray(0, seq.dtype))
if checkpoint: if checkpoint:
scan_result, scan_updates = theano.scan_with_checkpoints( scan_result, scan_updates = theano.scan_with_checkpoints(
fn=accumulate_by_adding, fn=accumulate_by_adding,
outputs_info=outputs_info, outputs_info=outputs_info,
sequences=seq, sequences=seq,
save_every_N=10) save_every_N=10)
else: else:
scan_result, scan_updates = theano.scan(fn=accumulate_by_adding, scan_result, scan_updates = theano.scan(fn=accumulate_by_adding,
outputs_info=outputs_info, outputs_info=outputs_info,
sequences=seq) sequences=seq)
start_compile = time.time() start_compile = time.time()
triangular_sequence = theano.function(inputs=[up_to], outputs=scan_result) triangular_sequence = theano.function(inputs=[up_to], outputs=scan_result)
time_compile = time.time() - start_compile time_compile = time.time() - start_compile
start_exec = time.time() start_exec = time.time()
out = triangular_sequence(100)[-1] out = triangular_sequence(100)[-1]
time_exec = time.time() - start_exec time_exec = time.time() - start_exec
if checkpoint: if checkpoint:
print("Example 2 with checkpoints") print("Example 2 with checkpoints")
else: else:
...@@ -92,4 +92,4 @@ def test_scan_checkpoint(): ...@@ -92,4 +92,4 @@ def test_scan_checkpoint():
print("----") print("----")
example2(False) example2(False)
example2(True) example2(True)
print("----") print("----")
\ No newline at end of file
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论