Commit 1704b94d authored by Razvan Pascanu

fix grad to not expect same dtype of the grad as input

parent c8216133
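The commit relaxes an assumption in Scan's gradient code: the gradient flowing into an input need not share that input's dtype, so gradient buffers must be allocated from the gradient's type rather than the input's. A minimal numpy sketch of the failure mode this avoids (`state` and `upstream_grad` are illustrative names, not from the Theano source):

    import numpy as np

    state = np.zeros(3, dtype=np.float32)          # forward-pass input
    upstream_grad = np.ones(3, dtype=np.float64)   # gradient arriving from outside

    # Wrong: allocating the accumulator from the input silently
    # downcasts the incoming float64 gradient to float32.
    acc_bad = np.zeros_like(state)                 # dtype float32
    acc_bad += upstream_grad                       # precision lost

    # Right (the direction of this patch): shape may come from the
    # input, but dtype must come from the gradient itself.
    acc_good = np.zeros(state.shape, dtype=upstream_grad.dtype)
    acc_good += upstream_grad                      # stays float64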
@@ -1291,12 +1291,6 @@ class Scan(PureOp):
         offset = len(args) - len(other_args) - pos
         # 7.2. generate variables to represent previous steps of g_outs
         for idx, diff_in in enumerate(diff_inputs):
-            prev_gfn_out = safe_new(diff_in)
-            if hasattr(diff_in, 'name') and diff_in.name:
-                prev_gfn_out.name = 'g_prev_' + diff_in.name
-            else:
-                prev_gfn_out.name = 'g_prev_' + str(idx)
-            prev_inner_gfn_outs.append(prev_gfn_out)
             if idx < pos:
                 zeros_like_diff_ins.append(tensor.zeros_like(diff_in))
             else:
@@ -1328,10 +1322,7 @@ class Scan(PureOp):
         grad_outs = compute_gradient(out, _g_out)
         if not inner_gfn_outs:
             for idx, gfn_out in enumerate(grad_outs):
-                if idx >= self.n_seqs:
-                    inner_gfn_outs.append(prev_inner_gfn_outs[idx])
-                else:
-                    inner_gfn_outs.append(None)
+                inner_gfn_outs.append(None)
         # 7.4 Sum the gradients
         # safety check, some of this inputs might still not be
         # differentiable, for those we don't add them to the mix
@@ -1344,6 +1335,10 @@ class Scan(PureOp):
             else:
                 inner_gfn_outs[i] = x
+        prev_inner_gfn_outs = [x.type() for x in inner_gfn_outs]
+        for dx in xrange(self.n_seqs, len(inner_gfn_outs)):
+            inner_gfn_outs[dx] = inner_gfn_outs[dx] + \
+                prev_inner_gfn_outs[dx]
         ## 8. Mask the outputs that are not differentiable
         # backwards pass
         for i in xrange(len(inner_gfn_outs)):
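The hunk above replaces the old per-input `safe_new` placeholders with variables built via `x.type()`, so each placeholder now inherits its type (including dtype) from the computed gradient expression rather than from the forward-pass input, and previous-step gradients are folded in by plain addition. A rough sketch of that accumulation pattern, with hypothetical stand-ins (`make_placeholder` plays the role of `x.type()`):

    def accumulate_recurrent_grads(inner_gfn_outs, n_seqs, make_placeholder):
        # One placeholder per current gradient, typed after the gradient
        # (not after the forward-pass input), carrying the previous step.
        prev = [make_placeholder(g) for g in inner_gfn_outs]
        # Sequences (indices < n_seqs) get a fresh gradient each step;
        # everything else accumulates across the backward scan.
        for dx in range(n_seqs, len(inner_gfn_outs)):
            inner_gfn_outs[dx] = inner_gfn_outs[dx] + prev[dx]
        return inner_gfn_outs, prev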
@@ -1493,8 +1488,9 @@ class Scan(PureOp):
         n_sitsot_outs = len(prev_inner_gfn_outs[offset:])
         scan_sitsot_ins = prev_inner_gfn_outs[offset:]
         scan_sitsot_init = []
-        for x in zeros_like_diff_ins[offset:]:
-            shapes = [x.shape[i] for i in xrange(x.ndim)]
+        for x,y in zip(prev_inner_gfn_outs[offset:],
+                       zeros_like_diff_ins[offset:]):
+            shapes = [y.shape[i] for i in xrange(x.ndim)]
             empty = tensor.zeros([do_steps + 1] + shapes,
                                  dtype=x.dtype)
             scan_sitsot_init.append(empty)
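This last hunk makes the dtype fix concrete for the sit-sot initial states: the buffer's shape still comes from the zeros-like of the differentiable input (`y`), but its dtype now comes from the gradient variable (`x`). A numpy sketch of the resulting allocation (function and variable names are illustrative, not Theano's):

    import numpy as np

    def init_sitsot_buffer(grad_var, zeros_like_input, n_steps):
        # Shape comes from the forward-pass input's zeros-like tensor...
        shapes = list(zeros_like_input.shape)
        # ...but dtype comes from the gradient, which may differ.
        return np.zeros([n_steps + 1] + shapes, dtype=grad_var.dtype)

    g = np.ones((4, 2), dtype=np.float64)       # gradient (float64)
    z = np.zeros((4, 2), dtype=np.float32)      # zeros-like of the input
    buf = init_sitsot_buffer(g, z, n_steps=5)   # shape (6, 4, 2), dtype float64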