提交 3d665bac authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5254 from Thrandis/ccw

Removed save_every_N sequence size limitation.
......@@ -265,11 +265,9 @@ def scan(fn,
``n_steps`` is the number of steps to iterate given as an int
or Theano scalar. If any of the input sequences do not have
enough elements, scan will raise an error. If the *value is 0* the
outputs will have *0 rows*. If the value is negative, ``scan``
will run backwards in time. If the ``go_backwards`` flag is already
set and also ``n_steps`` is negative, ``scan`` will run forward
in time. If n_steps is not provided, ``scan`` will figure
out the amount of steps it should run given its input sequences.
outputs will have *0 rows*. If n_steps is not provided, ``scan`` will
figure out the amount of steps it should run given its input
sequences. ``n_steps`` < 0 is not supported anymore.
truncate_gradient
``truncate_gradient`` is the number of steps to use in truncated
......
......@@ -4,7 +4,8 @@ import theano
def scan_checkpoints(fn, sequences=[], outputs_info=None, non_sequences=[],
name="checkpointscan_fn", n_steps=None, save_every_N=10):
name="checkpointscan_fn", n_steps=None, save_every_N=10,
padding=True):
"""Scan function that uses less memory, but is more restrictive.
In :func:`~theano.scan`, if you compute the gradient of the output
......@@ -52,19 +53,23 @@ def scan_checkpoints(fn, sequences=[], outputs_info=None, non_sequences=[],
n_steps
``n_steps`` is the number of steps to iterate given as an int
or Theano scalar. If any of the input sequences do not have
enough elements, scan will raise an error. If the **value is 0**
the outputs will have **0 rows**. If the value is negative,
``scan`` will run backwards in time. If the ``go_backwards`` flag
is already set and also ``n_steps`` is negative, ``scan`` will run
forward in time. If n_steps is not provided, ``scan`` will figure
out the amount of steps it should run given its input sequences.
or Theano scalar (> 0). If any of the input sequences do not have
enough elements, scan will raise an error. If n_steps is not provided,
``scan`` will figure out the amount of steps it should run given its
input sequences.
save_every_N
``save_every_N`` is the number of steps to go without storing
the computations of ``scan`` (i.e. they will have to be recomputed
during the gradient computation).
padding
If the length of the sequences is not a multiple of ``save_every_N``,
the sequences will be zero padded to make this version of ``scan``
work properly, but will also result in a memory copy. It can be
avoided by setting ``padding`` to False, but you need to make
sure the length of the sequences is a multiple of ``save_every_N``.
Returns
-------
tuple
......@@ -96,16 +101,31 @@ def scan_checkpoints(fn, sequences=[], outputs_info=None, non_sequences=[],
if n_steps is None:
n_steps = sequences[0].shape[0]
# Compute the number of steps of the inner and of the outer scan
o_n_steps = theano.tensor.cast(n_steps / save_every_N, 'int64')
i_n_steps = save_every_N
# Compute the number of steps of the outer scan
o_n_steps = theano.tensor.cast(theano.tensor.ceil(n_steps / save_every_N),
'int64')
# Compute the number of steps of the inner scan
i_n_steps = save_every_N * theano.tensor.ones((o_n_steps,), 'int64')
mod = n_steps % save_every_N
last_n_steps = theano.tensor.switch(theano.tensor.eq(mod, 0),
save_every_N, mod)
i_n_steps = theano.tensor.set_subtensor(i_n_steps[-1], last_n_steps)
# Pad the sequences if needed
if padding:
for i, s in enumerate(sequences):
n = s.shape[0] % save_every_N
z = theano.tensor.zeros((n, s.shape[1:]), dtype=s.dtype)
sequences[i] = theano.tensor.concatenate([s, z], axis=0)
# Establish the input variables of the outer scan
o_sequences = [s.reshape([s.shape[0] / save_every_N, save_every_N] +
[s.shape[i] for i in range(1, s.ndim)],
s.ndim + 1) for s in sequences]
o_sequences.append(i_n_steps)
new_nitsots = [i for i in outputs_info if i is None]
o_nonsequences = non_sequences + [i_n_steps]
o_nonsequences = non_sequences
def outer_step(*args):
# Separate the received arguments into their respective (seq, outputs
......@@ -117,11 +137,11 @@ def scan_checkpoints(fn, sequences=[], outputs_info=None, non_sequences=[],
# Call the user-provided function with the proper arguments
results, updates = theano.scan(fn=fn,
sequences=i_sequences,
sequences=i_sequences[:-1],
outputs_info=i_outputs_infos,
non_sequences=i_non_sequences[:-1],
non_sequences=i_non_sequences,
name=name + "_inner",
n_steps=i_non_sequences[-1])
n_steps=i_sequences[-1])
if not isinstance(results, list):
results = [results]
......
......@@ -38,14 +38,14 @@ class TestScanCheckpoint(unittest.TestCase):
"""Test forward computation of A**k."""
f = theano.function(inputs=[self.A, self.k],
outputs=[self.result, self.result_check])
out, out_check = f(range(10), 100)
out, out_check = f(range(10), 101)
assert numpy.allclose(out, out_check)
def test_backward_pass(self):
"""Test gradient computation of A**k."""
f = theano.function(inputs=[self.A, self.k],
outputs=[self.grad_A, self.grad_A_check])
out, out_check = f(range(10), 100)
out, out_check = f(range(10), 101)
assert numpy.allclose(out, out_check)
@unittest.skipUnless(PYGPU_AVAILABLE, 'Requires pygpu.')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论