提交 c683c976 authored 作者: Frederic Bastien's avatar Frederic Bastien

Select the good dtype when using multiple output.

上级 f3ab7ed8
...@@ -214,17 +214,16 @@ class IfElse(PureOp): ...@@ -214,17 +214,16 @@ class IfElse(PureOp):
name=nw_name_f) name=nw_name_f)
# The grads can have a different dtype then the inputs. # The grads can have a different dtype then the inputs.
# As all inputs except the condition must have the same dtype, # As inputs true/false pair must have the same dtype,
# we must cast the zeros to the grad dtype and not the input dtype. # we must cast the zeros to the corresponding grad dtype
# We hope that each grads have the same dtype and none had its # and not the input dtype.
# dtype changed differently then the others. This could happen
# in theory.
dtype = grads[0].dtype
if_true = ([ins[0]] + if_true = ([ins[0]] +
grads + grads +
[theano.tensor.zeros_like(t, dtype=dtype) for t in ts]) [theano.tensor.zeros_like(t, dtype=grads[i].dtype)
for i, t in enumerate(ts)])
if_false = ([ins[0]] + if_false = ([ins[0]] +
[theano.tensor.zeros_like(f, dtype=dtype) for f in fs] + [theano.tensor.zeros_like(f, dtype=grads[i].dtype)
for i, f in enumerate(fs)] +
grads) grads)
condition = ins[0] condition = ins[0]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论