提交 60c82e69 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Do not call gemm with a stride of 0; some BLAS implementations refuse it.

上级 92248832
......@@ -478,8 +478,9 @@ class GemmRelated(Op):
(long int)Ny[1], (long int)Nz[1]);
%(fail)s;
}
// We must not raise an error when Nx[1] == 0. This would disable case
// that numpy.dot accept.
// We must not raise an error when Nx[1] == 0. This would disable cases
// that numpy.dot accept.
"""
check_strides = """
......@@ -526,14 +527,18 @@ class GemmRelated(Op):
compute_strides = """
/* create appropriate strides for malformed matrices that are row or column
* vectors
* vectors, or empty matrices.
* In that case, the value of the stride does not really matter, but
* some versions of BLAS insist that:
* - they are not smaller than the number of elements in the array,
* - they are not 0.
*/
sx_0 = (Nx[0] > 1) ? Sx[0]/type_size : Nx[1];
sx_1 = (Nx[1] > 1) ? Sx[1]/type_size : Nx[0];
sy_0 = (Ny[0] > 1) ? Sy[0]/type_size : Ny[1];
sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : Ny[0];
sz_0 = (Nz[0] > 1) ? Sz[0]/type_size : Nz[1];
sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : Nz[0];
sx_0 = (Nx[0] > 1) ? Sx[0]/type_size : (Nx[1] + 1);
sx_1 = (Nx[1] > 1) ? Sx[1]/type_size : (Nx[0] + 1);
sy_0 = (Ny[0] > 1) ? Sy[0]/type_size : (Ny[1] + 1);
sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : (Ny[0] + 1);
sz_0 = (Nz[0] > 1) ? Sz[0]/type_size : (Nz[1] + 1);
sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : (Nz[0] + 1);
"""
begin_switch_typenum = """
......@@ -639,7 +644,7 @@ class GemmRelated(Op):
self.end_switch_typenum), '')
def build_gemm_version(self):
return (8,)
return (9,)
class Gemm(GemmRelated):
"""In-place version of matrix-matrix multiplication (with accumulation):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论