提交 25ca839d authored 作者: Tim Cooijmans's avatar Tim Cooijmans

BatchedDot: upcast inputs if they are of mixed types

上级 8cd678da
......@@ -3399,9 +3399,11 @@ class BatchedDot(Op):
'theano.tensor.batched_dot instead.' % inputs[1].ndim)
dtype = scal.upcast(*[input.type.dtype for input in inputs])
# upcast inputs to common dtype if needed
upcasted_inputs = [cast(input, dtype) for input in inputs]
broadcastable = (inputs[0].type.broadcastable[:-1] +
inputs[1].type.broadcastable[2:])
return Apply(self, inputs, [tensor(dtype, broadcastable)])
return Apply(self, upcasted_inputs, [tensor(dtype, broadcastable)])
def perform(self, node, inp, out):
x, y = inp
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论