提交 e2f11c78 authored 作者: Cesar Laurent's avatar Cesar Laurent

PEP8 and adressed comments.

上级 bac0fde1
import theano import theano
import theano.tensor as T
def scan_with_checkpoints(fn, sequences=[], outputs_info=None, def scan_with_checkpoints(fn, sequences=[], outputs_info=None,
non_sequences=[], name="checkpointscan_fn", non_sequences=[], name="checkpointscan_fn",
n_steps=None, save_every_N=10): n_steps=None, save_every_N=10):
""" """
Current assumptions : Current assumptions :
- Every sequence has the same length - Every sequence has the same length
...@@ -14,8 +13,8 @@ def scan_with_checkpoints(fn, sequences=[], outputs_info=None, ...@@ -14,8 +13,8 @@ def scan_with_checkpoints(fn, sequences=[], outputs_info=None,
- Only singly-recurrent and non-recurrent outputs are used. - Only singly-recurrent and non-recurrent outputs are used.
No multiple recurrences. No multiple recurrences.
- Only the last timestep of any output will ever be used. - Only the last timestep of any output will ever be used.
"""
"""
# Standardize the format of input arguments # Standardize the format of input arguments
if not isinstance(sequences, list): if not isinstance(sequences, list):
sequences = [sequences] sequences = [sequences]
...@@ -31,14 +30,14 @@ def scan_with_checkpoints(fn, sequences=[], outputs_info=None, ...@@ -31,14 +30,14 @@ def scan_with_checkpoints(fn, sequences=[], outputs_info=None,
n_steps = n_steps n_steps = n_steps
# Compute the number of steps of the inner and of the outer scan # Compute the number of steps of the inner and of the outer scan
o_n_steps = n_steps / save_every_N o_n_steps = theano.tensor.cast(n_steps / save_every_N, 'int64')
i_n_steps = save_every_N i_n_steps = save_every_N
# Establish the input variables of the outer scan # Establish the input variables of the outer scan
o_sequences = [s.reshape([s.shape[0] / save_every_N, save_every_N] + o_sequences = [s.reshape([s.shape[0] / save_every_N, save_every_N] +
[s.shape[i] for i in range(1, s.ndim)], s.ndim + 1) for s in sequences] [s.shape[i] for i in range(1, s.ndim)],
s.ndim + 1) for s in sequences]
new_nitsots = [i for i in outputs_info if i is None] new_nitsots = [i for i in outputs_info if i is None]
new_sitsots = [i for i in outputs_info if i is not None]
o_nonsequences = non_sequences + [i_n_steps] o_nonsequences = non_sequences + [i_n_steps]
def outer_step(*args): def outer_step(*args):
...@@ -47,14 +46,12 @@ def scan_with_checkpoints(fn, sequences=[], outputs_info=None, ...@@ -47,14 +46,12 @@ def scan_with_checkpoints(fn, sequences=[], outputs_info=None,
i_sequences = list(args[:len(o_sequences)]) i_sequences = list(args[:len(o_sequences)])
i_prev_outputs = list(args[len(o_sequences):-len(o_nonsequences)]) i_prev_outputs = list(args[len(o_sequences):-len(o_nonsequences)])
i_non_sequences = list(args[-len(o_nonsequences):]) i_non_sequences = list(args[-len(o_nonsequences):])
i_outputs_infos = i_prev_outputs + [None, ] * len(new_nitsots)
# Assemble the correct outputs_info list for the inner_scan
i_outputs_info = []
# Call the user-provided function with the proper arguments # Call the user-provided function with the proper arguments
results, updates = theano.scan(fn=fn, results, updates = theano.scan(fn=fn,
sequences=i_sequences, sequences=i_sequences,
outputs_info=i_prev_outputs + [None,] * len(new_nitsots), outputs_info=i_outputs_infos,
non_sequences=i_non_sequences[:-1], non_sequences=i_non_sequences[:-1],
name=name + "_inner", name=name + "_inner",
n_steps=i_non_sequences[-1]) n_steps=i_non_sequences[-1])
...@@ -75,8 +72,4 @@ def scan_with_checkpoints(fn, sequences=[], outputs_info=None, ...@@ -75,8 +72,4 @@ def scan_with_checkpoints(fn, sequences=[], outputs_info=None,
n_steps=o_n_steps, allow_gc=True) n_steps=o_n_steps, allow_gc=True)
# Keep only the last timestep of every output but keep all the updates # Keep only the last timestep of every output but keep all the updates
return results, updates # TODO is it a bug? return results, updates
if not isinstance(results, list):
return results[-1:], updates
else:
return [r[-1:] for r in results], updates
...@@ -31,7 +31,7 @@ def example1(checkpoint=True): ...@@ -31,7 +31,7 @@ def example1(checkpoint=True):
# compiled function that returns A**k # compiled function that returns A**k
start_compile = time.time() start_compile = time.time()
power = theano.function(inputs=[A,k], outputs=result, updates=updates) power = theano.function(inputs=[A, k], outputs=result, updates=updates)
time_compile = time.time() - start_compile time_compile = time.time() - start_compile
start_exec = time.time() start_exec = time.time()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论