提交 4e1e00cc authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add broadcastable dimensions in opt where needed

上级 3cea20e0
...@@ -2835,7 +2835,16 @@ def local_subtensor_merge(node): ...@@ -2835,7 +2835,16 @@ def local_subtensor_merge(node):
# and stacktrace from previous slicing operation. # and stacktrace from previous slicing operation.
# Why? Because, the merged slicing operation could have failed # Why? Because, the merged slicing operation could have failed
# because of either of the two original slicing operations # because of either of the two original slicing operations
copy_stack_trace([node.outputs[0], node.inputs[0]], out) orig_out = node.outputs[0]
copy_stack_trace([orig_out, node.inputs[0]], out)
# Restore original broadcastable dimensions that `subtens()` may
# have been unable to infer again
if out.type != orig_out.type:
assert out.dtype == orig_out.dtype
assert out.ndim == orig_out.ndim
out = T.patternbroadcast(out, orig_out.broadcastable)
copy_stack_trace([orig_out, node.inputs[0]], out)
return [out] return [out]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论