提交 4838cea1 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

instead of a connection matrix

上级 d1258bcf
......@@ -1291,6 +1291,7 @@ class Scan(PureOp):
if x in diff_inputs])
return [gmp.get(p, None) for p in diff_inputs]
dC_dinps_t = [None for inp in diff_inputs]
disconnected_dC_dinps_t = [True for inp in diff_inputs]
dC_dXts = []
Xts = []
for idx, Xt in enumerate(diff_outputs):
......@@ -1325,13 +1326,11 @@ class Scan(PureOp):
elif _dC_dinps_t[jdx]:
dC_dinps_t[jdx] += _dC_dinps_t[jdx]
# mask inputs that get no gradients
disconnected_dC_dinps_t = []
for dx in xrange(len(dC_dinps_t)):
if not dC_dinps_t[dx]:
dC_dinps_t[dx] = tensor.zeros_like(diff_inputs[dx])
disconnected_dC_dinps_t.append(True)
else:
disconnected_dC_dinps_t.append(False)
disconnected_dC_dinps_t[dx] = False
for Xt, Xt_placeholder in zip(
diff_outputs[self.n_mit_mot_outs:],
Xts):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论