提交 1a6b4072 作者: Pascal Lamblin

Detect Null gradient wrt non-sequences in scan

上级 e34f6e59
...@@ -1875,13 +1875,27 @@ class Scan(PureOp): ...@@ -1875,13 +1875,27 @@ class Scan(PureOp):
type_outs.append('disconnected') type_outs.append('disconnected')
else: else:
type_outs.append('connected') type_outs.append('connected')
inner_inp_sitsot = dC_dXtm1s[ins_pos - self.n_seqs:] inner_inp_sitsot = dC_dXtm1s[ins_pos - self.n_seqs:]
outer_inp_sitsot = [ outer_inp_sitsot = []
for _idx, y in enumerate(inner_inp_sitsot):
x = self.outer_non_seqs(inputs)[_idx]
if isinstance(y.type, NullType):
# Cannot use dC_dXtm1s.dtype, so we use floatX instead.
outer_inp_sitsot.append(
tensor.zeros([grad_steps + 1] +
[x.shape[i] for i in xrange(x.ndim)],
dtype=theano.config.floatX))
# replace y by a zero tensor of the right shape
inner_inp_sitsot[_idx] = tensor.zeros(
diff_inputs[ins_pos + _idx].shape,
dtype=theano.config.floatX)
else:
outer_inp_sitsot.append(
tensor.zeros([grad_steps + 1] + tensor.zeros([grad_steps + 1] +
[x.shape[i] for i in xrange(x.ndim)], [x.shape[i] for i in xrange(x.ndim)],
dtype=y.dtype) dtype=y.dtype))
for y, x in zip(inner_inp_sitsot,
self.outer_non_seqs(inputs))]
n_sitsot_outs = len(outer_inp_sitsot) n_sitsot_outs = len(outer_inp_sitsot)
new_tap_array = mitmot_inp_taps + [[-1] for k in new_tap_array = mitmot_inp_taps + [[-1] for k in
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论