提交 50c7bbdf authored 作者: Razvan Pascanu's avatar Razvan Pascanu

check first if disconnected, afterwards if undefined

上级 90d9bb5f
......@@ -1690,36 +1690,46 @@ class Scan(PureOp):
for p, (x, t) in enumerate(
zip(outputs[offset:offset + self.n_seqs],
type_outs[offset:offset + self.n_seqs])):
if t == 'undefined':
if t == 'disconnected':
gradients.append(DisconnectedType()())
elif t == 'undefined':
gradients.append(
grad_undefined(self,
p + 1,
inputs[p + 1],
'Depends on a shared variable'))
elif t == 'disconnected':
gradients.append(DisconnectedType()())
else:
gradients.append(x[::-1])
end = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
for p, (x, t) in enumerate(
zip(outputs[:end], type_outs[:end])):
if t == 'undefined':
if t == 'disconnected':
gradients.append(DisconnectedType()())
elif t == 'undefined':
gradients.append(
grad_undefined(self,
p + 1 + self.n_seqs,
inputs[p + 1 + self.n_seqs],
'Depends on a shared variable'))
elif t == 'disconnected':
gradients.append(DisconnectedType()())
else:
gradients.append(x[::-1])
start = len(gradients)
gradients += [
grad_undefined(self, x + start, inputs[x + start],
'Shared Variable with update')
for x in xrange(self.n_shared_outs)]
node = outs[0].owner
for idx in xrange(self.n_shared_outs):
disconnected = True
connected_flags = self.connection_pattern(node)[idx+start]
for dC_dout, connected in zip(dC_douts, connected_flags):
if (not isinstance(dC_dout.type, DisconnectedType) and
connected):
disconnected = False
if disconnected:
gradients.append(DisconnectedType()())
else:
gradients.append(grad_undefined(
self, idx, inputs[idx],
'Shared Variable with update'))
start = len(gradients)
gradients += [DisconnectedType()()
for x in xrange(self.n_nit_sot)]
......@@ -1728,14 +1738,14 @@ class Scan(PureOp):
end = begin + n_sitsot_outs
for p, (x, t) in enumerate(
zip(outputs[begin:end], type_outs[begin:end])):
if t == 'undefined':
if t == 'disconnected':
gradients.append(DisconnectedType()())
elif t == 'undefined':
gradients.append(
grad_undefined(self,
p + begin + 1,
inputs[p + begin + 1],
'Depends on a shared variable'))
elif t == 'disconnected':
gradients.append(DisconnectedType()())
else:
gradients.append(x[-1])
return gradients
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论