提交 5ba4af1e authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix confusion between strides in cnda and cublas

上级 f6700ff8
......@@ -3135,7 +3135,7 @@ int CudaNdarray_sgemv(float alpha, const CudaNdarray * A, const CudaNdarray * B,
Py_INCREF(B);
}
// I don't know if cudablas handles negative strides
// cudablas does not handle negative strides as expected
if ( (CudaNdarray_HOST_STRIDES(A)[0] < 0)
|| (CudaNdarray_HOST_STRIDES(A)[1] < 0))
{
......@@ -3155,8 +3155,15 @@ int CudaNdarray_sgemv(float alpha, const CudaNdarray * A, const CudaNdarray * B,
int sb_0 = (CudaNdarray_HOST_DIMS(B)[0] > 1) ? CudaNdarray_HOST_STRIDES(B)[0] : 1;
int sc_0 = (CudaNdarray_HOST_DIMS(C)[0] > 1) ? CudaNdarray_HOST_STRIDES(C)[0] : 1;
if (sa_0 == 0)
sa_0 = 1;
if (sa_1 == 0)
sa_1 = 1;
if (CudaNdarray_SIZE(C)) {
if ((sa_0 == 1) || (sa_0 == 0))
if ((CudaNdarray_HOST_DIMS(A)[0] <= 1)
|| ((CudaNdarray_HOST_STRIDES(A)[0] == 1)
&& (CudaNdarray_HOST_STRIDES(A)[1] > 0)))
{
cublasSgemv('N',
CudaNdarray_HOST_DIMS(A)[0], CudaNdarray_HOST_DIMS(A)[1],
......@@ -3166,7 +3173,9 @@ int CudaNdarray_sgemv(float alpha, const CudaNdarray * A, const CudaNdarray * B,
beta,
CudaNdarray_DEV_DATA(C), sc_0);
}
else if ((sa_1 == 1) || (sa_1 == 0))
else if ((CudaNdarray_HOST_DIMS(A)[1] <= 1)
|| ((CudaNdarray_HOST_STRIDES(A)[1] == 1)
&& (CudaNdarray_HOST_STRIDES(A)[0] > 0)))
{
cublasSgemv('T',
CudaNdarray_HOST_DIMS(A)[1], CudaNdarray_HOST_DIMS(A)[0],
......@@ -3179,13 +3188,16 @@ int CudaNdarray_sgemv(float alpha, const CudaNdarray * A, const CudaNdarray * B,
else
{
PyErr_Format(PyExc_AssertionError,
"Unexpected stride pattern in gemv: (%i, %i) x %i -> %i",
sa_0, sa_1, sb_0, sc_0);
printf("shapes are: (%i, %i) x %i -> %i\n",
CudaNdarray_HOST_DIMS(A)[0],
CudaNdarray_HOST_DIMS(A)[1],
CudaNdarray_HOST_DIMS(B)[0],
CudaNdarray_HOST_DIMS(C)[0]);
"Unexpected stride pattern in gemv: (%i, %i) x %i -> %i.\n"
"Shapes are: (%i, %i) x %i -> %i\n",
CudaNdarray_HOST_STRIDES(A)[0],
CudaNdarray_HOST_STRIDES(A)[1],
CudaNdarray_HOST_STRIDES(B)[0],
CudaNdarray_HOST_STRIDES(C)[0],
CudaNdarray_HOST_DIMS(A)[0],
CudaNdarray_HOST_DIMS(A)[1],
CudaNdarray_HOST_DIMS(B)[0],
CudaNdarray_HOST_DIMS(C)[0]);
Py_XDECREF(A);
Py_XDECREF(B);
return -1;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论