提交 f31a4118 authored 作者: Frederic's avatar Frederic

Fix in Choose with mixed input dtype. We where specifiing the wrong output dtype.

上级 c7723227
......@@ -5216,9 +5216,12 @@ class Choose(Op):
a = as_tensor_variable(a)
if isinstance(choices, (tuple, list)):
choice = theano.typed_list.make_list(choices)
dtype = choice.ttype.dtype
else:
choice = as_tensor_variable(choices)
return Apply(self, [a, choice], [a.type()])
o = TensorType(choice.dtype, a.broadcastable)
return Apply(self, [a, choice], [o()])
def perform(self, node, inputs, (z, )):
a = inputs[0]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论