提交 8f65b441 authored 作者: Tim Cooijmans's avatar Tim Cooijmans

BatchedDot: leave first dimension broadcastable if it was broadcastable for either of the inputs

上级 e527643a
......@@ -3401,7 +3401,9 @@ class BatchedDot(Op):
dtype = scal.upcast(*[input.type.dtype for input in inputs])
# upcast inputs to common dtype if needed
upcasted_inputs = [cast(input, dtype) for input in inputs]
broadcastable = (inputs[0].type.broadcastable[:-1] +
broadcastable = ((inputs[0].type.broadcastable[0] or
inputs[1].type.broadcastable[0],) +
inputs[0].type.broadcastable[1:-1] +
inputs[1].type.broadcastable[2:])
return Apply(self, upcasted_inputs, [tensor(dtype, broadcastable)])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论