提交 e2f11c78 authored 作者: Cesar Laurent's avatar Cesar Laurent

PEP8 and adressed comments.

上级 bac0fde1
import theano
import theano.tensor as T
def scan_with_checkpoints(fn, sequences=[], outputs_info=None,
non_sequences=[], name="checkpointscan_fn",
n_steps=None, save_every_N=10):
"""
Current assumptions :
- Every sequence has the same length
......@@ -14,8 +13,8 @@ def scan_with_checkpoints(fn, sequences=[], outputs_info=None,
- Only singly-recurrent and non-recurrent outputs are used.
No multiple recurrences.
- Only the last timestep of any output will ever be used.
"""
"""
# Standardize the format of input arguments
if not isinstance(sequences, list):
sequences = [sequences]
......@@ -31,14 +30,14 @@ def scan_with_checkpoints(fn, sequences=[], outputs_info=None,
n_steps = n_steps
# Compute the number of steps of the inner and of the outer scan
o_n_steps = n_steps / save_every_N
o_n_steps = theano.tensor.cast(n_steps / save_every_N, 'int64')
i_n_steps = save_every_N
# Establish the input variables of the outer scan
o_sequences = [s.reshape([s.shape[0] / save_every_N, save_every_N] +
[s.shape[i] for i in range(1, s.ndim)], s.ndim + 1) for s in sequences]
[s.shape[i] for i in range(1, s.ndim)],
s.ndim + 1) for s in sequences]
new_nitsots = [i for i in outputs_info if i is None]
new_sitsots = [i for i in outputs_info if i is not None]
o_nonsequences = non_sequences + [i_n_steps]
def outer_step(*args):
......@@ -47,14 +46,12 @@ def scan_with_checkpoints(fn, sequences=[], outputs_info=None,
i_sequences = list(args[:len(o_sequences)])
i_prev_outputs = list(args[len(o_sequences):-len(o_nonsequences)])
i_non_sequences = list(args[-len(o_nonsequences):])
# Assemble the correct outputs_info list for the inner_scan
i_outputs_info = []
i_outputs_infos = i_prev_outputs + [None, ] * len(new_nitsots)
# Call the user-provided function with the proper arguments
results, updates = theano.scan(fn=fn,
sequences=i_sequences,
outputs_info=i_prev_outputs + [None,] * len(new_nitsots),
outputs_info=i_outputs_infos,
non_sequences=i_non_sequences[:-1],
name=name + "_inner",
n_steps=i_non_sequences[-1])
......@@ -75,8 +72,4 @@ def scan_with_checkpoints(fn, sequences=[], outputs_info=None,
n_steps=o_n_steps, allow_gc=True)
# Keep only the last timestep of every output but keep all the updates
return results, updates # TODO is it a bug?
if not isinstance(results, list):
return results[-1:], updates
else:
return [r[-1:] for r in results], updates
return results, updates
......@@ -31,7 +31,7 @@ def example1(checkpoint=True):
# compiled function that returns A**k
start_compile = time.time()
power = theano.function(inputs=[A,k], outputs=result, updates=updates)
power = theano.function(inputs=[A, k], outputs=result, updates=updates)
time_compile = time.time() - start_compile
start_exec = time.time()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论