Commit 4baf3ece, authored by Frederic Bastien

Fix the opt crash regression introduced by gh-5950.

Parent 1029066e
@@ -1213,10 +1213,10 @@ def local_gpua_gemmbatch(op, context_name, inputs, outputs):
     # them from outputs
     output_dims = [0, 1, 2]
     if a.ndim == 2:
-        a = GpuDimShuffle(a.broadcastable, (0, 1, 'x'))(a)
+        a = GpuDimShuffle(a.broadcastable, (0, 'x', 1))(a)
         del output_dims[1]
     if b.ndim == 2:
-        b = GpuDimShuffle(b.broadcastable, (0, 'x', 1))(b)
+        b = GpuDimShuffle(b.broadcastable, (0, 1, 'x'))(b)
         del output_dims[-1]
     # In case of mismatched dtypes, we also have to upcast
     out_dtype = outputs[0].dtype
......
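The swap of the two dimshuffle patterns can be checked with shapes alone. The sketch below is not Theano code: `dimshuffle_shape` and `batched_gemm_shape` are hypothetical helpers written for this note, and it assumes the usual batched-dot convention in which a 2D `a` of shape (batch, k) is one row vector per batch item and a 2D `b` of shape (batch, k) is one column vector per batch item. Under that assumption, only the post-fix patterns give shapes that a 3D batched GEMM accepts.

```python
def dimshuffle_shape(shape, pattern):
    """Shape after a dimshuffle: 'x' inserts a length-1 axis, ints pick input axes."""
    return tuple(1 if p == 'x' else shape[p] for p in pattern)


def batched_gemm_shape(a_shape, b_shape):
    """Output shape of a 3D batched matrix product, or ValueError if incompatible."""
    (batch_a, n, k_a), (batch_b, k_b, m) = a_shape, b_shape
    if batch_a != batch_b or k_a != k_b:
        raise ValueError("incompatible shapes %r and %r" % (a_shape, b_shape))
    return (batch_a, n, m)


batch, n, k, m = 4, 3, 5, 2

# 2D `a` (batch, k) against 3D `b` (batch, k, m), post-fix pattern (0, 'x', 1).
a_fixed = dimshuffle_shape((batch, k), (0, 'x', 1))
out_a = batched_gemm_shape(a_fixed, (batch, k, m))
# Axis 1 of out_a has length 1; it is the one dropped by `del output_dims[1]`.

# 3D `a` (batch, n, k) against 2D `b` (batch, k), post-fix pattern (0, 1, 'x').
b_fixed = dimshuffle_shape((batch, k), (0, 1, 'x'))
out_b = batched_gemm_shape((batch, n, k), b_fixed)
# The last axis of out_b has length 1; it is the one dropped by `del output_dims[-1]`.

# Pre-fix pattern for `a`: the inner dimensions no longer line up.
a_old = dimshuffle_shape((batch, k), (0, 1, 'x'))
try:
    batched_gemm_shape(a_old, (batch, k, m))
    old_pattern_ok = True
except ValueError:
    old_pattern_ok = False
```

With the pre-fix pattern, `a` becomes (batch, k, 1), so its inner dimension is 1 while `b` expects k. That is one plausible reading of the crash this commit fixes, not something the commit message itself states.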