提交 6e5d160c 作者: Pascal Lamblin

Treat length-0 dimensions as unit strides in gemm

上级 7a515454
......@@ -2990,24 +2990,26 @@ int CudaNdarray_gemm(float alpha, const CudaNdarray * A, const CudaNdarray * B,
// a stride of 0 implies a dimension of 1 - so we can actually define
// a stride of 0 as a 'unit' stride because gemm will never use it.
// If a dimension is 0, its stride will not be used either, so we can
// consider it a 'unit' stride too.
int unit = 0;
if (CudaNdarray_HOST_STRIDES(A)[1] == 1 || CudaNdarray_HOST_STRIDES(A)[1] == 0) {
if (CudaNdarray_HOST_STRIDES(A)[1] == 1 || CudaNdarray_HOST_DIMS(A)[1] <= 1) {
unit |= (0x0 << 8);
} else if (CudaNdarray_HOST_STRIDES(A)[0] == 1 || CudaNdarray_HOST_STRIDES(A)[0] == 0) {
} else if (CudaNdarray_HOST_STRIDES(A)[0] == 1 || CudaNdarray_HOST_DIMS(A)[0] <= 1) {
unit |= (0x1 << 8);
} else {
unit |= (0x2 << 8);
}
if (CudaNdarray_HOST_STRIDES(B)[1] == 1 || CudaNdarray_HOST_STRIDES(B)[1] == 0) {
if (CudaNdarray_HOST_STRIDES(B)[1] == 1 || CudaNdarray_HOST_DIMS(B)[1] <= 1) {
unit |= (0x0 << 4);
} else if (CudaNdarray_HOST_STRIDES(B)[0] == 1 || CudaNdarray_HOST_STRIDES(B)[0] == 0) {
} else if (CudaNdarray_HOST_STRIDES(B)[0] == 1 || CudaNdarray_HOST_DIMS(B)[0] <= 1) {
unit |= (0x1 << 4);
} else {
unit |= (0x2 << 4);
}
if (CudaNdarray_HOST_STRIDES(C)[1] == 1 || CudaNdarray_HOST_STRIDES(C)[1] == 0) {
if (CudaNdarray_HOST_STRIDES(C)[1] == 1 || CudaNdarray_HOST_DIMS(C)[1] <= 1) {
unit |= (0x0 << 0);
} else if (CudaNdarray_HOST_STRIDES(C)[0] == 1 || CudaNdarray_HOST_STRIDES(C)[0] == 0) {
} else if (CudaNdarray_HOST_STRIDES(C)[0] == 1 || CudaNdarray_HOST_DIMS(C)[0] <= 1) {
unit |= (0x1 << 0);
} else {
unit |= (0x2 << 0);
......@@ -3053,7 +3055,7 @@ int CudaNdarray_gemm(float alpha, const CudaNdarray * A, const CudaNdarray * B,
case 0x101: CHK_STRIDE_SGEMM(N, T, CudaNdarray_HOST_DIMS(C)[0], CudaNdarray_HOST_DIMS(C)[1], CudaNdarray_HOST_DIMS(A)[1], alpha, a, sa_1, b, sb_0, beta, c, sc_1); break;
case 0x011: CHK_STRIDE_SGEMM(T, N, CudaNdarray_HOST_DIMS(C)[0], CudaNdarray_HOST_DIMS(C)[1], CudaNdarray_HOST_DIMS(A)[1], alpha, a, sa_0, b, sb_1, beta, c, sc_1); break;
case 0x111: CHK_STRIDE_SGEMM(N, N, CudaNdarray_HOST_DIMS(C)[0], CudaNdarray_HOST_DIMS(C)[1], CudaNdarray_HOST_DIMS(A)[1], alpha, a, sa_1, b, sb_1, beta, c, sc_1); break;
default: PyErr_Format(PyExc_ValueError, "some matrix has no unit stride (unit=%i)", unit);
default: PyErr_Format(PyExc_ValueError, "some matrix has no unit stride (unit=%x)", unit);
return -1;
};
CNDA_THREAD_SYNC;
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论