Commit 3965d55a, authored by Cesar Laurent

Fixed n_steps and added testing.

Parent: 56ab770a
@@ -53,13 +53,10 @@ def scan_checkpoints(fn, sequences=[], outputs_info=None, non_sequences=[],
     n_steps
         ``n_steps`` is the number of steps to iterate given as an int
-        or Theano scalar. If any of the input sequences do not have
-        enough elements, scan will raise an error. If the **value is 0**
-        the outputs will have **0 rows**. If the value is negative,
-        ``scan`` will run backwards in time. If the ``go_backwards`` flag
-        is already set and also ``n_steps`` is negative, ``scan`` will run
-        forward in time. If n_steps is not provided, ``scan`` will figure
-        out the amount of steps it should run given its input sequences.
+        or Theano scalar (> 0). If any of the input sequences do not have
+        enough elements, scan will raise an error. If n_steps is not provided,
+        ``scan`` will figure out the amount of steps it should run given its
+        input sequences.
     save_every_N
         ``save_every_N`` is the number of steps to go without storing
@@ -108,8 +105,10 @@ def scan_checkpoints(fn, sequences=[], outputs_info=None, non_sequences=[],
     o_n_steps = theano.tensor.cast(theano.tensor.ceil(n_steps / save_every_N),
                                    'int64')
     i_n_steps = save_every_N * theano.tensor.ones((o_n_steps,), 'int64')
-    i_n_steps = theano.tensor.inc_subtensor(i_n_steps[-1],
-                                            - n_steps % save_every_N)
+    mod = n_steps % save_every_N
+    sign = theano.tensor.sgn(mod)
+    rest = (1 - sign) * save_every_N + sign * mod
+    i_n_steps = theano.tensor.set_subtensor(i_n_steps[-1], rest)
     # Pad the sequences if needed
     if not no_padding:
...
@@ -38,14 +38,14 @@ class TestScanCheckpoint(unittest.TestCase):
         """Test forward computation of A**k."""
         f = theano.function(inputs=[self.A, self.k],
                             outputs=[self.result, self.result_check])
-        out, out_check = f(range(10), 100)
+        out, out_check = f(range(10), 101)
         assert numpy.allclose(out, out_check)

     def test_backward_pass(self):
         """Test gradient computation of A**k."""
         f = theano.function(inputs=[self.A, self.k],
                             outputs=[self.grad_A, self.grad_A_check])
-        out, out_check = f(range(10), 100)
+        out, out_check = f(range(10), 101)
         assert numpy.allclose(out, out_check)

     @unittest.skipUnless(PYGPU_AVAILABLE, 'Requires pygpu.')
...
Markdown formatting is supported
0%
You are adding 0 people to this discussion. Proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment