提交 08ed1ba0 authored 作者: Cesar Laurent's avatar Cesar Laurent

Fixed indexing and added a test.

上级 0eab5957
...@@ -2198,7 +2198,7 @@ class Scan(PureOp): ...@@ -2198,7 +2198,7 @@ class Scan(PureOp):
dC_dinps_t[dx + self.n_seqs] += dC_dXtm1 dC_dinps_t[dx + self.n_seqs] += dC_dXtm1
# Construct scan op # Construct scan op
# Seqs # Seqs
outer_inp_seqs = [x[:grad_steps][::-1] for x in inputs[1:1 + self.n_seqs]] outer_inp_seqs = [x[grad_steps - 1::-1] for x in inputs[1:1 + self.n_seqs]]
for idx in xrange(self.n_mit_mot + self.n_mit_sot): for idx in xrange(self.n_mit_mot + self.n_mit_sot):
mintap = np.min(self.tap_array[idx]) mintap = np.min(self.tap_array[idx])
if idx < self.n_mit_mot: if idx < self.n_mit_mot:
...@@ -2216,7 +2216,7 @@ class Scan(PureOp): ...@@ -2216,7 +2216,7 @@ class Scan(PureOp):
x[:-1][::-1] for x in self.outer_sitsot_outs(outs)] x[:-1][::-1] for x in self.outer_sitsot_outs(outs)]
for x in self.outer_nitsot_outs(dC_douts): for x in self.outer_nitsot_outs(dC_douts):
if not isinstance(x.type, DisconnectedType): if not isinstance(x.type, DisconnectedType):
outer_inp_seqs.append(x[:grad_steps][::-1]) outer_inp_seqs.append(x[grad_steps - 1::-1])
if hasattr(inputs[0].tag, 'test_value'): if hasattr(inputs[0].tag, 'test_value'):
# Here we tests that the new scan input sequence all have # Here we tests that the new scan input sequence all have
...@@ -2236,7 +2236,7 @@ class Scan(PureOp): ...@@ -2236,7 +2236,7 @@ class Scan(PureOp):
grad_steps.tag.test_value) grad_steps.tag.test_value)
for x in self.outer_nitsot_outs(outs): for x in self.outer_nitsot_outs(outs):
if hasattr(x[::-1].tag, 'test_value'): if hasattr(x[::-1].tag, 'test_value'):
assert (x[:grad_steps][::-1].tag.test_value.shape[0] == assert (x[grad_steps - 1::-1].tag.test_value.shape[0] ==
grad_steps.tag.test_value) grad_steps.tag.test_value)
outer_inp_seqs += [x[::-1][:np.min(taps)] outer_inp_seqs += [x[::-1][:np.min(taps)]
for taps, x in zip(self.mitsot_taps(), for taps, x in zip(self.mitsot_taps(),
...@@ -2536,7 +2536,8 @@ class Scan(PureOp): ...@@ -2536,7 +2536,8 @@ class Scan(PureOp):
outer_inp_seqs + outer_inp_seqs +
outer_inp_mitmot + outer_inp_mitmot +
outer_inp_sitsot + outer_inp_sitsot +
[grad_steps for _ in xrange(n_nit_sot)] + [grad_steps if self.as_while else inputs[0]
for _ in xrange(n_nit_sot)] +
self.outer_shared(inputs) + self.outer_shared(inputs) +
self.outer_non_seqs(inputs)) self.outer_non_seqs(inputs))
...@@ -2635,7 +2636,7 @@ class Scan(PureOp): ...@@ -2635,7 +2636,7 @@ class Scan(PureOp):
start = len(gradients) start = len(gradients)
gradients += [DisconnectedType()() gradients += [DisconnectedType()()
for x in xrange(self.n_nit_sot)] for _ in xrange(self.n_nit_sot)]
begin = end begin = end
end = begin + n_sitsot_outs end = begin + n_sitsot_outs
......
...@@ -5529,3 +5529,25 @@ class TestMissingInputError(unittest.TestCase): ...@@ -5529,3 +5529,25 @@ class TestMissingInputError(unittest.TestCase):
_, updates = theano.scan(count_up, n_steps=20) _, updates = theano.scan(count_up, n_steps=20)
func = theano.function(inputs=[inc], outputs=[], updates=updates) func = theano.function(inputs=[inc], outputs=[], updates=updates)
class TestGradUntil(unittest.TestCase):
    """Regression test: gradients through a scan terminated early by an
    ``until`` condition must be correct up to the truncation point and
    zero for the steps that were never executed."""

    def test_grad_until(self):
        seq = tensor.vector(name='x')
        threshold = tensor.scalar(name='until', dtype='int64')
        # Square each element; stop scanning once an element exceeds the
        # threshold (the step that triggers the condition is still emitted).
        squared, _ = theano.scan(
            lambda elem, limit: (elem * elem,
                                 theano.scan_module.until(elem > limit)),
            sequences=seq,
            non_sequences=[threshold])
        grad_wrt_seq = theano.grad(squared.sum(), seq)
        fn = theano.function([seq, threshold], [squared, grad_wrt_seq])

        data = numpy.arange(15, dtype=theano.config.floatX)
        out_seq, out_grad = fn(data, 5)

        # With inputs 0..14 and the condition x > 5, the scan emits 7 steps
        # (elements 0..6); the remaining 8 input positions get zero gradient.
        executed = numpy.arange(7, dtype=theano.config.floatX)
        tail = numpy.zeros(8, dtype=theano.config.floatX)
        padded = numpy.concatenate([executed, tail], axis=0)
        utt.assert_allclose(out_seq, executed**2)
        utt.assert_allclose(out_grad, 2 * padded)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论