提交 56ab770a authored 作者: Cesar Laurent's avatar Cesar Laurent

Removed save_every_N sequence size limitiation.

上级 1db72747
......@@ -4,7 +4,8 @@ import theano
def scan_checkpoints(fn, sequences=[], outputs_info=None, non_sequences=[],
name="checkpointscan_fn", n_steps=None, save_every_N=10):
name="checkpointscan_fn", n_steps=None, save_every_N=10,
no_padding=False):
"""Scan function that uses less memory, but is more restrictive.
In :func:`~theano.scan`, if you compute the gradient of the output
......@@ -65,6 +66,13 @@ def scan_checkpoints(fn, sequences=[], outputs_info=None, non_sequences=[],
the computations of ``scan`` (ie they will have to be recomputed
during the gradient computation).
no_padding
If the length of the sequences is not a multiple of ``save_every_N``,
the sequences will be zero padded to make this version of ``scan``
work properly, but will also result in a memory copy. It can be
avoided by setting ``no_padding`` to True, but you need to make
sure the length of the sequences is a multple of ``save_every_N``.
Returns
-------
tuple
......@@ -96,16 +104,27 @@ def scan_checkpoints(fn, sequences=[], outputs_info=None, non_sequences=[],
if n_steps is None:
n_steps = sequences[0].shape[0]
# Compute the number of steps of the inner and of the outer scan
o_n_steps = theano.tensor.cast(n_steps / save_every_N, 'int64')
i_n_steps = save_every_N
# Compute the number of steps of the outer scan
o_n_steps = theano.tensor.cast(theano.tensor.ceil(n_steps / save_every_N),
'int64')
i_n_steps = save_every_N * theano.tensor.ones((o_n_steps,), 'int64')
i_n_steps = theano.tensor.inc_subtensor(i_n_steps[-1],
- n_steps % save_every_N)
# Pad the sequences if needed
if not no_padding:
for i, s in enumerate(sequences):
n = s.shape[0] % save_every_N
z = theano.tensor.zeros((n, s.shape[1:]), dtype=s.dtype)
sequences[i] = theano.tensor.concatenate([s, z])
# Establish the input variables of the outer scan
o_sequences = [s.reshape([s.shape[0] / save_every_N, save_every_N] +
[s.shape[i] for i in range(1, s.ndim)],
s.ndim + 1) for s in sequences]
o_sequences.append(i_n_steps)
new_nitsots = [i for i in outputs_info if i is None]
o_nonsequences = non_sequences + [i_n_steps]
o_nonsequences = non_sequences
def outer_step(*args):
# Separate the received arguments into their respective (seq, outputs
......@@ -117,11 +136,11 @@ def scan_checkpoints(fn, sequences=[], outputs_info=None, non_sequences=[],
# Call the user-provided function with the proper arguments
results, updates = theano.scan(fn=fn,
sequences=i_sequences,
sequences=i_sequences[:-1],
outputs_info=i_outputs_infos,
non_sequences=i_non_sequences[:-1],
non_sequences=i_non_sequences,
name=name + "_inner",
n_steps=i_non_sequences[-1])
n_steps=i_sequences[-1])
if not isinstance(results, list):
results = [results]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论