提交 41b83184 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Use zeros when the output is an int, even if dtypes match

上级 8a07ee00
...@@ -788,11 +788,10 @@ def _populate_grad_dict(var_to_node_to_idx, ...@@ -788,11 +788,10 @@ def _populate_grad_dict(var_to_node_to_idx,
for o, og in zip(node.outputs, output_grads): for o, og in zip(node.outputs, output_grads):
o_dt = getattr(o.type, 'dtype', None) o_dt = getattr(o.type, 'dtype', None)
og_dt = getattr(og.type, 'dtype', None) og_dt = getattr(og.type, 'dtype', None)
if o_dt and og_dt and o_dt != og_dt: if og_dt and o_dt in theano.tensor.discrete_dtypes:
if o_dt in theano.tensor.float_dtypes: new_output_grads.append(o.zeros_like())
new_output_grads.append(og.astype(o_dt)) elif o_dt and og_dt and o_dt != og_dt:
else: new_output_grads.append(og.astype(o_dt))
new_output_grads.append(o.zeros_like())
else: else:
new_output_grads.append(og) new_output_grads.append(og)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论