提交 3d665bac authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5254 from Thrandis/ccw

Removed save_every_N sequence size limitation.
...@@ -265,11 +265,9 @@ def scan(fn, ...@@ -265,11 +265,9 @@ def scan(fn,
``n_steps`` is the number of steps to iterate given as an int ``n_steps`` is the number of steps to iterate given as an int
or Theano scalar. If any of the input sequences do not have or Theano scalar. If any of the input sequences do not have
enough elements, scan will raise an error. If the *value is 0* the enough elements, scan will raise an error. If the *value is 0* the
outputs will have *0 rows*. If the value is negative, ``scan`` outputs will have *0 rows*. If n_steps is not provided, ``scan`` will
will run backwards in time. If the ``go_backwards`` flag is already figure out the amount of steps it should run given its input
set and also ``n_steps`` is negative, ``scan`` will run forward sequences. ``n_steps`` < 0 is not supported anymore.
in time. If n_steps is not provided, ``scan`` will figure
out the amount of steps it should run given its input sequences.
truncate_gradient truncate_gradient
``truncate_gradient`` is the number of steps to use in truncated ``truncate_gradient`` is the number of steps to use in truncated
......
...@@ -4,7 +4,8 @@ import theano ...@@ -4,7 +4,8 @@ import theano
def scan_checkpoints(fn, sequences=[], outputs_info=None, non_sequences=[], def scan_checkpoints(fn, sequences=[], outputs_info=None, non_sequences=[],
name="checkpointscan_fn", n_steps=None, save_every_N=10): name="checkpointscan_fn", n_steps=None, save_every_N=10,
padding=True):
"""Scan function that uses less memory, but is more restrictive. """Scan function that uses less memory, but is more restrictive.
In :func:`~theano.scan`, if you compute the gradient of the output In :func:`~theano.scan`, if you compute the gradient of the output
...@@ -52,19 +53,23 @@ def scan_checkpoints(fn, sequences=[], outputs_info=None, non_sequences=[], ...@@ -52,19 +53,23 @@ def scan_checkpoints(fn, sequences=[], outputs_info=None, non_sequences=[],
n_steps n_steps
``n_steps`` is the number of steps to iterate given as an int ``n_steps`` is the number of steps to iterate given as an int
or Theano scalar. If any of the input sequences do not have or Theano scalar (> 0). If any of the input sequences do not have
enough elements, scan will raise an error. If the **value is 0** enough elements, scan will raise an error. If n_steps is not provided,
the outputs will have **0 rows**. If the value is negative, ``scan`` will figure out the amount of steps it should run given its
``scan`` will run backwards in time. If the ``go_backwards`` flag input sequences.
is already set and also ``n_steps`` is negative, ``scan`` will run
forward in time. If n_steps is not provided, ``scan`` will figure
out the amount of steps it should run given its input sequences.
save_every_N save_every_N
``save_every_N`` is the number of steps to go without storing ``save_every_N`` is the number of steps to go without storing
the computations of ``scan`` (ie they will have to be recomputed the computations of ``scan`` (ie they will have to be recomputed
during the gradient computation). during the gradient computation).
padding
If the length of the sequences is not a multiple of ``save_every_N``,
the sequences will be zero padded to make this version of ``scan``
work properly, but will also result in a memory copy. It can be
avoided by setting ``padding`` to False, but you need to make
sure the length of the sequences is a multiple of ``save_every_N``.
Returns Returns
------- -------
tuple tuple
...@@ -96,16 +101,31 @@ def scan_checkpoints(fn, sequences=[], outputs_info=None, non_sequences=[], ...@@ -96,16 +101,31 @@ def scan_checkpoints(fn, sequences=[], outputs_info=None, non_sequences=[],
if n_steps is None: if n_steps is None:
n_steps = sequences[0].shape[0] n_steps = sequences[0].shape[0]
# Compute the number of steps of the inner and of the outer scan # Compute the number of steps of the outer scan
o_n_steps = theano.tensor.cast(n_steps / save_every_N, 'int64') o_n_steps = theano.tensor.cast(theano.tensor.ceil(n_steps / save_every_N),
i_n_steps = save_every_N 'int64')
# Compute the number of steps of the inner scan
i_n_steps = save_every_N * theano.tensor.ones((o_n_steps,), 'int64')
mod = n_steps % save_every_N
last_n_steps = theano.tensor.switch(theano.tensor.eq(mod, 0),
save_every_N, mod)
i_n_steps = theano.tensor.set_subtensor(i_n_steps[-1], last_n_steps)
# Pad the sequences if needed
if padding:
for i, s in enumerate(sequences):
n = s.shape[0] % save_every_N
z = theano.tensor.zeros((n, s.shape[1:]), dtype=s.dtype)
sequences[i] = theano.tensor.concatenate([s, z], axis=0)
# Establish the input variables of the outer scan # Establish the input variables of the outer scan
o_sequences = [s.reshape([s.shape[0] / save_every_N, save_every_N] + o_sequences = [s.reshape([s.shape[0] / save_every_N, save_every_N] +
[s.shape[i] for i in range(1, s.ndim)], [s.shape[i] for i in range(1, s.ndim)],
s.ndim + 1) for s in sequences] s.ndim + 1) for s in sequences]
o_sequences.append(i_n_steps)
new_nitsots = [i for i in outputs_info if i is None] new_nitsots = [i for i in outputs_info if i is None]
o_nonsequences = non_sequences + [i_n_steps] o_nonsequences = non_sequences
def outer_step(*args): def outer_step(*args):
# Separate the received arguments into their respective (seq, outputs # Separate the received arguments into their respective (seq, outputs
...@@ -117,11 +137,11 @@ def scan_checkpoints(fn, sequences=[], outputs_info=None, non_sequences=[], ...@@ -117,11 +137,11 @@ def scan_checkpoints(fn, sequences=[], outputs_info=None, non_sequences=[],
# Call the user-provided function with the proper arguments # Call the user-provided function with the proper arguments
results, updates = theano.scan(fn=fn, results, updates = theano.scan(fn=fn,
sequences=i_sequences, sequences=i_sequences[:-1],
outputs_info=i_outputs_infos, outputs_info=i_outputs_infos,
non_sequences=i_non_sequences[:-1], non_sequences=i_non_sequences,
name=name + "_inner", name=name + "_inner",
n_steps=i_non_sequences[-1]) n_steps=i_sequences[-1])
if not isinstance(results, list): if not isinstance(results, list):
results = [results] results = [results]
......
...@@ -38,14 +38,14 @@ class TestScanCheckpoint(unittest.TestCase): ...@@ -38,14 +38,14 @@ class TestScanCheckpoint(unittest.TestCase):
"""Test forward computation of A**k.""" """Test forward computation of A**k."""
f = theano.function(inputs=[self.A, self.k], f = theano.function(inputs=[self.A, self.k],
outputs=[self.result, self.result_check]) outputs=[self.result, self.result_check])
out, out_check = f(range(10), 100) out, out_check = f(range(10), 101)
assert numpy.allclose(out, out_check) assert numpy.allclose(out, out_check)
def test_backward_pass(self): def test_backward_pass(self):
"""Test gradient computation of A**k.""" """Test gradient computation of A**k."""
f = theano.function(inputs=[self.A, self.k], f = theano.function(inputs=[self.A, self.k],
outputs=[self.grad_A, self.grad_A_check]) outputs=[self.grad_A, self.grad_A_check])
out, out_check = f(range(10), 100) out, out_check = f(range(10), 101)
assert numpy.allclose(out, out_check) assert numpy.allclose(out, out_check)
@unittest.skipUnless(PYGPU_AVAILABLE, 'Requires pygpu.') @unittest.skipUnless(PYGPU_AVAILABLE, 'Requires pygpu.')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论