提交 ebaef5af authored 作者: Pascal Lamblin's avatar Pascal Lamblin 提交者: GitHub

Merge pull request #5775 from Thrandis/ccw

Correct gradient of scan until.
...@@ -1955,6 +1955,8 @@ class Scan(PureOp): ...@@ -1955,6 +1955,8 @@ class Scan(PureOp):
self.mintaps[self.n_mit_mot] self.mintaps[self.n_mit_mot]
else: else:
grad_steps = inputs[0] grad_steps = inputs[0]
if self.as_while:
n_steps = outs[0].shape[0]
# Restrict the number of grad steps according to # Restrict the number of grad steps according to
# self.truncate_gradient # self.truncate_gradient
...@@ -2172,6 +2174,10 @@ class Scan(PureOp): ...@@ -2172,6 +2174,10 @@ class Scan(PureOp):
dC_dinps_t[dx + self.n_seqs] += dC_dXtm1 dC_dinps_t[dx + self.n_seqs] += dC_dXtm1
# Construct scan op # Construct scan op
# Seqs # Seqs
if self.as_while:
# equivalent to x[:n_steps][::-1]
outer_inp_seqs = [x[n_steps - 1::-1] for x in inputs[1:1 + self.n_seqs]]
else:
outer_inp_seqs = [x[::-1] for x in inputs[1:1 + self.n_seqs]] outer_inp_seqs = [x[::-1] for x in inputs[1:1 + self.n_seqs]]
for idx in xrange(self.n_mit_mot + self.n_mit_sot): for idx in xrange(self.n_mit_mot + self.n_mit_sot):
mintap = np.min(self.tap_array[idx]) mintap = np.min(self.tap_array[idx])
...@@ -2190,6 +2196,10 @@ class Scan(PureOp): ...@@ -2190,6 +2196,10 @@ class Scan(PureOp):
x[:-1][::-1] for x in self.outer_sitsot_outs(outs)] x[:-1][::-1] for x in self.outer_sitsot_outs(outs)]
for x in self.outer_nitsot_outs(dC_douts): for x in self.outer_nitsot_outs(dC_douts):
if not isinstance(x.type, DisconnectedType): if not isinstance(x.type, DisconnectedType):
if self.as_while:
# equivalent to x[:n_steps][::-1]
outer_inp_seqs.append(x[n_steps - 1::-1])
else:
outer_inp_seqs.append(x[::-1]) outer_inp_seqs.append(x[::-1])
if hasattr(inputs[0].tag, 'test_value'): if hasattr(inputs[0].tag, 'test_value'):
...@@ -2198,20 +2208,25 @@ class Scan(PureOp): ...@@ -2198,20 +2208,25 @@ class Scan(PureOp):
# fct add and we want to keep it for all Scan op. This is # fct add and we want to keep it for all Scan op. This is
# used in T_Scan.test_grad_multiple_outs_taps to test # used in T_Scan.test_grad_multiple_outs_taps to test
# that. # that.
if self.as_while:
n = n_steps.tag.test_value
else:
n = inputs[0].tag.test_value
for taps, x in zip(self.mitsot_taps(), for taps, x in zip(self.mitsot_taps(),
self.outer_mitsot_outs(outs)): self.outer_mitsot_outs(outs)):
mintap = np.min(taps) mintap = np.min(taps)
if hasattr(x[::-1][:mintap], 'test_value'): if hasattr(x[::-1][:mintap], 'test_value'):
assert (x[::-1][:mintap].tag.test_value.shape[0] == assert (x[::-1][:mintap].tag.test_value.shape[0] == n)
inputs[0].tag.test_value)
for x in self.outer_sitsot_outs(outs): for x in self.outer_sitsot_outs(outs):
if hasattr(x[::-1][:-1].tag, 'test_value'): if hasattr(x[::-1][:-1].tag, 'test_value'):
assert (x[::-1][:-1].tag.test_value.shape[0] == assert (x[::-1][:-1].tag.test_value.shape[0] == n)
inputs[0].tag.test_value)
for x in self.outer_nitsot_outs(outs): for x in self.outer_nitsot_outs(outs):
if hasattr(x[::-1].tag, 'test_value'): if hasattr(x[::-1].tag, 'test_value'):
assert (x[::-1].tag.test_value.shape[0] == if self.as_while:
inputs[0].tag.test_value) assert (x[n_steps - 1::-1].tag.test_value.shape[0] ==
n)
else:
assert (x[::-1].tag.test_value.shape[0] == n)
outer_inp_seqs += [x[::-1][:np.min(taps)] outer_inp_seqs += [x[::-1][:np.min(taps)]
for taps, x in zip(self.mitsot_taps(), for taps, x in zip(self.mitsot_taps(),
self.outer_mitsot_outs(outs))] self.outer_mitsot_outs(outs))]
...@@ -2510,7 +2525,8 @@ class Scan(PureOp): ...@@ -2510,7 +2525,8 @@ class Scan(PureOp):
outer_inp_seqs + outer_inp_seqs +
outer_inp_mitmot + outer_inp_mitmot +
outer_inp_sitsot + outer_inp_sitsot +
[inputs[0] for _ in xrange(n_nit_sot)] + [n_steps if self.as_while else inputs[0]
for _ in xrange(n_nit_sot)] +
self.outer_shared(inputs) + self.outer_shared(inputs) +
self.outer_non_seqs(inputs)) self.outer_non_seqs(inputs))
...@@ -2538,6 +2554,18 @@ class Scan(PureOp): ...@@ -2538,6 +2554,18 @@ class Scan(PureOp):
zip(outputs[offset:offset + self.n_seqs], zip(outputs[offset:offset + self.n_seqs],
type_outs[offset:offset + self.n_seqs])): type_outs[offset:offset + self.n_seqs])):
if t == 'connected': if t == 'connected':
# If the forward scan is in as_while mode, we need to pad
# the gradients, so that they match the size of the input
# sequences.
if self.as_while:
n_zeros = inputs[0] - n_steps
shp = (n_zeros,)
if x.ndim > 1:
shp = shp + x.shape[1:]
z = tensor.zeros(shp, dtype=x.dtype)
x = tensor.concatenate([x[::-1], z], axis=0)
gradients.append(x)
else:
gradients.append(x[::-1]) gradients.append(x[::-1])
elif t == 'disconnected': elif t == 'disconnected':
gradients.append(DisconnectedType()()) gradients.append(DisconnectedType()())
...@@ -2554,6 +2582,18 @@ class Scan(PureOp): ...@@ -2554,6 +2582,18 @@ class Scan(PureOp):
end = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot end = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
for p, (x, t) in enumerate(zip(outputs[:end], type_outs[:end])): for p, (x, t) in enumerate(zip(outputs[:end], type_outs[:end])):
if t == 'connected': if t == 'connected':
# If the forward scan is in as_while mode, we need to pad
# the gradients, so that they match the size of the input
# sequences.
if self.as_while:
n_zeros = inputs[0] - grad_steps
shp = (n_zeros,)
if x.ndim > 1:
shp = shp + x.shape[1:]
z = tensor.zeros(shp, dtype=x.dtype)
x = tensor.concatenate([x[::-1], z], axis=0)
gradients.append(x)
else:
gradients.append(x[::-1]) gradients.append(x[::-1])
elif t == 'disconnected': elif t == 'disconnected':
gradients.append(DisconnectedType()()) gradients.append(DisconnectedType()())
...@@ -2585,7 +2625,7 @@ class Scan(PureOp): ...@@ -2585,7 +2625,7 @@ class Scan(PureOp):
start = len(gradients) start = len(gradients)
gradients += [DisconnectedType()() gradients += [DisconnectedType()()
for x in xrange(self.n_nit_sot)] for _ in xrange(self.n_nit_sot)]
begin = end begin = end
end = begin + n_sitsot_outs end = begin + n_sitsot_outs
......
...@@ -5465,3 +5465,57 @@ class TestMissingInputError(unittest.TestCase): ...@@ -5465,3 +5465,57 @@ class TestMissingInputError(unittest.TestCase):
_, updates = theano.scan(count_up, n_steps=20) _, updates = theano.scan(count_up, n_steps=20)
func = theano.function(inputs=[inc], outputs=[], updates=updates) func = theano.function(inputs=[inc], outputs=[], updates=updates)
class TestGradUntil(unittest.TestCase):
    """Check ``theano.grad`` through a scan that stops early via ``until``.

    The scanned step squares the sequence elements and stops once an
    element exceeds the threshold.  With threshold 5 on ``arange(15)``
    the scan runs for 7 steps, so the gradient must be ``2 * x`` on the
    first 7 entries and zero on the remaining (never-visited) ones.
    """

    def setUp(self):
        # Symbolic inputs shared by every test case.
        self.x = tensor.vector(name='x')
        self.threshold = tensor.scalar(name='threshold', dtype='int64')
        # Reference data computed with plain numpy.
        self.seq = np.arange(15, dtype=theano.config.floatX)
        self.numpy_output = self.seq[:7] ** 2
        padding = np.zeros(8, dtype=theano.config.floatX)
        self.numpy_gradient = 2 * np.concatenate([self.seq[:7], padding],
                                                 axis=0)

    def test_grad_until(self):
        # Plain scan-until: gradient is 2*x where the scan actually ran,
        # zero-padded up to the full input length.
        def step(x, u):
            return x * x, theano.scan_module.until(x > u)
        r, _ = theano.scan(step,
                           sequences=self.x,
                           non_sequences=[self.threshold])
        g = theano.grad(r.sum(), self.x)
        f = theano.function([self.x, self.threshold], [r, g])
        out_val, grad_val = f(self.seq, 5)
        utt.assert_allclose(out_val, self.numpy_output)
        utt.assert_allclose(grad_val, self.numpy_gradient)

    def test_grad_until_and_truncate(self):
        # With truncate_gradient=n the earliest (7 - n) steps are not
        # back-propagated, so their gradient entries must be zero.
        truncate = 3

        def step(x, u):
            return x * x, theano.scan_module.until(x > u)
        r, _ = theano.scan(step,
                           sequences=self.x,
                           non_sequences=[self.threshold],
                           truncate_gradient=truncate)
        g = theano.grad(r.sum(), self.x)
        f = theano.function([self.x, self.threshold], [r, g])
        out_val, grad_val = f(self.seq, 5)
        self.numpy_gradient[:7 - truncate] = 0
        utt.assert_allclose(out_val, self.numpy_output)
        utt.assert_allclose(grad_val, self.numpy_gradient)

    def test_grad_until_and_truncate_sequence_taps(self):
        # Same scenario, but the sequence is consumed through taps
        # [-2, 0]; the expected gradient below was computed by hand.
        truncate = 3

        def step(x, y, u):
            return x * y, theano.scan_module.until(y > u)
        r, _ = theano.scan(step,
                           sequences=dict(input=self.x, taps=[-2, 0]),
                           non_sequences=[self.threshold],
                           truncate_gradient=truncate)
        g = theano.grad(r.sum(), self.x)
        f = theano.function([self.x, self.threshold], [r, g])
        _, grad_val = f(self.seq, 6)
        expected = np.array([0, 0, 0, 5, 6, 10, 4, 5, 0, 0, 0, 0, 0, 0, 0])
        expected = expected.astype(theano.config.floatX)
        utt.assert_allclose(grad_val, expected)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论