提交 0eab5957 作者: Cesar Laurent

Correct the gradient of scan when an `until` condition is used (as_while mode).

上级 72823c46
...@@ -2198,7 +2198,7 @@ class Scan(PureOp): ...@@ -2198,7 +2198,7 @@ class Scan(PureOp):
dC_dinps_t[dx + self.n_seqs] += dC_dXtm1 dC_dinps_t[dx + self.n_seqs] += dC_dXtm1
# Construct scan op # Construct scan op
# Seqs # Seqs
outer_inp_seqs = [x[::-1] for x in inputs[1:1 + self.n_seqs]] outer_inp_seqs = [x[:grad_steps][::-1] for x in inputs[1:1 + self.n_seqs]]
for idx in xrange(self.n_mit_mot + self.n_mit_sot): for idx in xrange(self.n_mit_mot + self.n_mit_sot):
mintap = np.min(self.tap_array[idx]) mintap = np.min(self.tap_array[idx])
if idx < self.n_mit_mot: if idx < self.n_mit_mot:
...@@ -2216,7 +2216,7 @@ class Scan(PureOp): ...@@ -2216,7 +2216,7 @@ class Scan(PureOp):
x[:-1][::-1] for x in self.outer_sitsot_outs(outs)] x[:-1][::-1] for x in self.outer_sitsot_outs(outs)]
for x in self.outer_nitsot_outs(dC_douts): for x in self.outer_nitsot_outs(dC_douts):
if not isinstance(x.type, DisconnectedType): if not isinstance(x.type, DisconnectedType):
outer_inp_seqs.append(x[::-1]) outer_inp_seqs.append(x[:grad_steps][::-1])
if hasattr(inputs[0].tag, 'test_value'): if hasattr(inputs[0].tag, 'test_value'):
# Here we tests that the new scan input sequence all have # Here we tests that the new scan input sequence all have
...@@ -2229,15 +2229,15 @@ class Scan(PureOp): ...@@ -2229,15 +2229,15 @@ class Scan(PureOp):
mintap = np.min(taps) mintap = np.min(taps)
if hasattr(x[::-1][:mintap], 'test_value'): if hasattr(x[::-1][:mintap], 'test_value'):
assert (x[::-1][:mintap].tag.test_value.shape[0] == assert (x[::-1][:mintap].tag.test_value.shape[0] ==
inputs[0].tag.test_value) grad_steps.tag.test_value)
for x in self.outer_sitsot_outs(outs): for x in self.outer_sitsot_outs(outs):
if hasattr(x[::-1][:-1].tag, 'test_value'): if hasattr(x[::-1][:-1].tag, 'test_value'):
assert (x[::-1][:-1].tag.test_value.shape[0] == assert (x[::-1][:-1].tag.test_value.shape[0] ==
inputs[0].tag.test_value) grad_steps.tag.test_value)
for x in self.outer_nitsot_outs(outs): for x in self.outer_nitsot_outs(outs):
if hasattr(x[::-1].tag, 'test_value'): if hasattr(x[::-1].tag, 'test_value'):
assert (x[::-1].tag.test_value.shape[0] == assert (x[:grad_steps][::-1].tag.test_value.shape[0] ==
inputs[0].tag.test_value) grad_steps.tag.test_value)
outer_inp_seqs += [x[::-1][:np.min(taps)] outer_inp_seqs += [x[::-1][:np.min(taps)]
for taps, x in zip(self.mitsot_taps(), for taps, x in zip(self.mitsot_taps(),
self.outer_mitsot_outs(outs))] self.outer_mitsot_outs(outs))]
...@@ -2536,7 +2536,7 @@ class Scan(PureOp): ...@@ -2536,7 +2536,7 @@ class Scan(PureOp):
outer_inp_seqs + outer_inp_seqs +
outer_inp_mitmot + outer_inp_mitmot +
outer_inp_sitsot + outer_inp_sitsot +
[inputs[0] for _ in xrange(n_nit_sot)] + [grad_steps for _ in xrange(n_nit_sot)] +
self.outer_shared(inputs) + self.outer_shared(inputs) +
self.outer_non_seqs(inputs)) self.outer_non_seqs(inputs))
...@@ -2564,6 +2564,18 @@ class Scan(PureOp): ...@@ -2564,6 +2564,18 @@ class Scan(PureOp):
zip(outputs[offset:offset + self.n_seqs], zip(outputs[offset:offset + self.n_seqs],
type_outs[offset:offset + self.n_seqs])): type_outs[offset:offset + self.n_seqs])):
if t == 'connected': if t == 'connected':
# If the forward scan is in as_while mode, we need to pad
# the gradients, so that they match the size of the input
# sequences.
if self.as_while:
n_zeros = inputs[0] - grad_steps
shp = (n_zeros,)
if x.ndim > 1:
shp = shp + x.shape[1:]
z = tensor.zeros(shp, dtype=x.dtype)
x = tensor.concatenate([x[::-1], z], axis=0)
gradients.append(x)
else:
gradients.append(x[::-1]) gradients.append(x[::-1])
elif t == 'disconnected': elif t == 'disconnected':
gradients.append(DisconnectedType()()) gradients.append(DisconnectedType()())
...@@ -2580,6 +2592,18 @@ class Scan(PureOp): ...@@ -2580,6 +2592,18 @@ class Scan(PureOp):
end = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot end = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
for p, (x, t) in enumerate(zip(outputs[:end], type_outs[:end])): for p, (x, t) in enumerate(zip(outputs[:end], type_outs[:end])):
if t == 'connected': if t == 'connected':
# If the forward scan is in as_while mode, we need to pad
# the gradients, so that they match the size of the input
# sequences.
if self.as_while:
n_zeros = inputs[0] - grad_steps
shp = (n_zeros,)
if x.ndim > 1:
shp = shp + x.shape[1:]
z = tensor.zeros(shp, dtype=x.dtype)
x = tensor.concatenate([x[::-1], z], axis=0)
gradients.append(x)
else:
gradients.append(x[::-1]) gradients.append(x[::-1])
elif t == 'disconnected': elif t == 'disconnected':
gradients.append(DisconnectedType()()) gradients.append(DisconnectedType()())
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论