提交 6fb0300f authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #4340 from lamblin/fix_broadcast

Add broadcastable dimensions in opt where needed
......@@ -2835,7 +2835,16 @@ def local_subtensor_merge(node):
# and stacktrace from previous slicing operation.
# Why? Because, the merged slicing operation could have failed
# because of either of the two original slicing operations
copy_stack_trace([node.outputs[0], node.inputs[0]], out)
orig_out = node.outputs[0]
copy_stack_trace([orig_out, node.inputs[0]], out)
# Restore original broadcastable dimensions that `subtens()` may
# have been unable to infer again
if out.type != orig_out.type:
assert out.dtype == orig_out.dtype
assert out.ndim == orig_out.ndim
out = T.patternbroadcast(out, orig_out.broadcastable)
copy_stack_trace([orig_out, node.inputs[0]], out)
return [out]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论