提交 50c7bbdf authored 作者: Razvan Pascanu's avatar Razvan Pascanu

check first if disconnected, afterwards if undefined

上级 90d9bb5f
...@@ -1690,36 +1690,46 @@ class Scan(PureOp): ...@@ -1690,36 +1690,46 @@ class Scan(PureOp):
for p, (x, t) in enumerate( for p, (x, t) in enumerate(
zip(outputs[offset:offset + self.n_seqs], zip(outputs[offset:offset + self.n_seqs],
type_outs[offset:offset + self.n_seqs])): type_outs[offset:offset + self.n_seqs])):
if t == 'undefined': if t == 'disconnected':
gradients.append(DisconnectedType()())
elif t == 'undefined':
gradients.append( gradients.append(
grad_undefined(self, grad_undefined(self,
p + 1, p + 1,
inputs[p + 1], inputs[p + 1],
'Depends on a shared variable')) 'Depends on a shared variable'))
elif t == 'disconnected':
gradients.append(DisconnectedType()())
else: else:
gradients.append(x[::-1]) gradients.append(x[::-1])
end = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot end = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
for p, (x, t) in enumerate( for p, (x, t) in enumerate(
zip(outputs[:end], type_outs[:end])): zip(outputs[:end], type_outs[:end])):
if t == 'undefined': if t == 'disconnected':
gradients.append(DisconnectedType()())
elif t == 'undefined':
gradients.append( gradients.append(
grad_undefined(self, grad_undefined(self,
p + 1 + self.n_seqs, p + 1 + self.n_seqs,
inputs[p + 1 + self.n_seqs], inputs[p + 1 + self.n_seqs],
'Depends on a shared variable')) 'Depends on a shared variable'))
elif t == 'disconnected':
gradients.append(DisconnectedType()())
else: else:
gradients.append(x[::-1]) gradients.append(x[::-1])
start = len(gradients) start = len(gradients)
gradients += [ node = outs[0].owner
grad_undefined(self, x + start, inputs[x + start], for idx in xrange(self.n_shared_outs):
'Shared Variable with update') disconnected = True
for x in xrange(self.n_shared_outs)] connected_flags = self.connection_pattern(node)[idx+start]
for dC_dout, connected in zip(dC_douts, connected_flags):
if (not isinstance(dC_dout.type, DisconnectedType) and
connected):
disconnected = False
if disconnected:
gradients.append(DisconnectedType()())
else:
gradients.append(grad_undefined(
self, idx, inputs[idx],
'Shared Variable with update'))
start = len(gradients) start = len(gradients)
gradients += [DisconnectedType()() gradients += [DisconnectedType()()
for x in xrange(self.n_nit_sot)] for x in xrange(self.n_nit_sot)]
...@@ -1728,14 +1738,14 @@ class Scan(PureOp): ...@@ -1728,14 +1738,14 @@ class Scan(PureOp):
end = begin + n_sitsot_outs end = begin + n_sitsot_outs
for p, (x, t) in enumerate( for p, (x, t) in enumerate(
zip(outputs[begin:end], type_outs[begin:end])): zip(outputs[begin:end], type_outs[begin:end])):
if t == 'undefined': if t == 'disconnected':
gradients.append(DisconnectedType()())
elif t == 'undefined':
gradients.append( gradients.append(
grad_undefined(self, grad_undefined(self,
p + begin + 1, p + begin + 1,
inputs[p + begin + 1], inputs[p + begin + 1],
'Depends on a shared variable')) 'Depends on a shared variable'))
elif t == 'disconnected':
gradients.append(DisconnectedType()())
else: else:
gradients.append(x[-1]) gradients.append(x[-1])
return gradients return gradients
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论