提交 41b83184 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Use zeros when the output is an int, even if dtypes match

上级 8a07ee00
......@@ -788,11 +788,10 @@ def _populate_grad_dict(var_to_node_to_idx,
for o, og in zip(node.outputs, output_grads):
o_dt = getattr(o.type, 'dtype', None)
og_dt = getattr(og.type, 'dtype', None)
if o_dt and og_dt and o_dt != og_dt:
if o_dt in theano.tensor.float_dtypes:
new_output_grads.append(og.astype(o_dt))
else:
new_output_grads.append(o.zeros_like())
if og_dt and o_dt in theano.tensor.discrete_dtypes:
new_output_grads.append(o.zeros_like())
elif o_dt and og_dt and o_dt != og_dt:
new_output_grads.append(og.astype(o_dt))
else:
new_output_grads.append(og)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论