提交 cc2d57bf authored 作者: Cesar Laurent

Fixed as_while when truncating the gradient.

上级 08ed1ba0
......@@ -1981,6 +1981,8 @@ class Scan(PureOp):
self.mintaps[self.n_mit_mot]
else:
grad_steps = inputs[0]
if self.as_while:
n_steps = outs[0].shape[0]
# Restrict the number of grad steps according to
# self.truncate_gradient
......@@ -2198,7 +2200,10 @@ class Scan(PureOp):
dC_dinps_t[dx + self.n_seqs] += dC_dXtm1
# Construct scan op
# Seqs
outer_inp_seqs = [x[grad_steps - 1::-1] for x in inputs[1:1 + self.n_seqs]]
if self.as_while:
outer_inp_seqs = [x[n_steps - 1::-1] for x in inputs[1:1 + self.n_seqs]]
else:
outer_inp_seqs = [x[::-1] for x in inputs[1:1 + self.n_seqs]]
for idx in xrange(self.n_mit_mot + self.n_mit_sot):
mintap = np.min(self.tap_array[idx])
if idx < self.n_mit_mot:
......@@ -2216,7 +2221,10 @@ class Scan(PureOp):
x[:-1][::-1] for x in self.outer_sitsot_outs(outs)]
for x in self.outer_nitsot_outs(dC_douts):
if not isinstance(x.type, DisconnectedType):
outer_inp_seqs.append(x[grad_steps - 1::-1])
if self.as_while:
outer_inp_seqs.append(x[n_steps - 1::-1])
else:
outer_inp_seqs.append(x[::-1])
if hasattr(inputs[0].tag, 'test_value'):
# Here we tests that the new scan input sequence all have
......@@ -2536,7 +2544,7 @@ class Scan(PureOp):
outer_inp_seqs +
outer_inp_mitmot +
outer_inp_sitsot +
[grad_steps if self.as_while else inputs[0]
[n_steps if self.as_while else inputs[0]
for _ in xrange(n_nit_sot)] +
self.outer_shared(inputs) +
self.outer_non_seqs(inputs))
......@@ -2569,7 +2577,7 @@ class Scan(PureOp):
# the gradients, so that they match the size of the input
# sequences.
if self.as_while:
n_zeros = inputs[0] - grad_steps
n_zeros = inputs[0] - n_steps
shp = (n_zeros,)
if x.ndim > 1:
shp = shp + x.shape[1:]
......
......@@ -5533,21 +5533,37 @@ class TestMissingInputError(unittest.TestCase):
class TestGradUntil(unittest.TestCase):
def setUp(self):
self.x = tensor.vector(name='x')
self.until = tensor.scalar(name='until', dtype='int64')
self.seq = numpy.arange(15, dtype=theano.config.floatX)
self.numpy_output = self.seq[:7]**2
z = numpy.zeros(8, dtype=theano.config.floatX)
self.numpy_gradient = 2 * numpy.concatenate([self.seq[:7], z], axis=0)
def test_grad_until(self):
x = tensor.vector(name='x')
until = tensor.scalar(name='until', dtype='int64')
r, _ = theano.scan(lambda x, u: (x * x,
theano.scan_module.until(x > u)),
sequences=x,
non_sequences=[until])
g = theano.grad(r.sum(), x)
f = theano.function([x, until], [r, g])
x_num = numpy.arange(15, dtype=theano.config.floatX)
theano_sequence, theano_gradient = f(x_num, 5)
sequences=self.x,
non_sequences=[self.until])
g = theano.grad(r.sum(), self.x)
f = theano.function([self.x, self.until], [r, g])
theano_output, theano_gradient = f(self.seq, 5)
numpy_sequence = numpy.arange(7, dtype=theano.config.floatX)
z = numpy.zeros(8, dtype=theano.config.floatX)
numpy_gradient = numpy.concatenate([numpy_sequence, z], axis=0)
utt.assert_allclose(theano_output, self.numpy_output)
utt.assert_allclose(theano_gradient, self.numpy_gradient)
utt.assert_allclose(theano_sequence, numpy_sequence**2)
utt.assert_allclose(theano_gradient, 2 * numpy_gradient)
def test_grad_until_and_truncate(self):
n = 3
r, _ = theano.scan(lambda x, u: (x * x,
theano.scan_module.until(x > u)),
sequences=self.x,
non_sequences=[self.until],
truncate_gradient=n)
g = theano.grad(r.sum(), self.x)
f = theano.function([self.x, self.until], [r, g])
theano_output, theano_gradient = f(self.seq, 5)
self.numpy_gradient[:7 - n] = 0
utt.assert_allclose(theano_output, self.numpy_output)
utt.assert_allclose(theano_gradient, self.numpy_gradient)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论