提交 5ba4af1e authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix confusion between strides in cnda and cublas

上级 f6700ff8
...@@ -3135,7 +3135,7 @@ int CudaNdarray_sgemv(float alpha, const CudaNdarray * A, const CudaNdarray * B, ...@@ -3135,7 +3135,7 @@ int CudaNdarray_sgemv(float alpha, const CudaNdarray * A, const CudaNdarray * B,
Py_INCREF(B); Py_INCREF(B);
} }
// I don't know if cudablas handles negative strides // cudablas does not handle negative strides as expected
if ( (CudaNdarray_HOST_STRIDES(A)[0] < 0) if ( (CudaNdarray_HOST_STRIDES(A)[0] < 0)
|| (CudaNdarray_HOST_STRIDES(A)[1] < 0)) || (CudaNdarray_HOST_STRIDES(A)[1] < 0))
{ {
...@@ -3155,8 +3155,15 @@ int CudaNdarray_sgemv(float alpha, const CudaNdarray * A, const CudaNdarray * B, ...@@ -3155,8 +3155,15 @@ int CudaNdarray_sgemv(float alpha, const CudaNdarray * A, const CudaNdarray * B,
int sb_0 = (CudaNdarray_HOST_DIMS(B)[0] > 1) ? CudaNdarray_HOST_STRIDES(B)[0] : 1; int sb_0 = (CudaNdarray_HOST_DIMS(B)[0] > 1) ? CudaNdarray_HOST_STRIDES(B)[0] : 1;
int sc_0 = (CudaNdarray_HOST_DIMS(C)[0] > 1) ? CudaNdarray_HOST_STRIDES(C)[0] : 1; int sc_0 = (CudaNdarray_HOST_DIMS(C)[0] > 1) ? CudaNdarray_HOST_STRIDES(C)[0] : 1;
if (sa_0 == 0)
sa_0 = 1;
if (sa_1 == 0)
sa_1 = 1;
if (CudaNdarray_SIZE(C)) { if (CudaNdarray_SIZE(C)) {
if ((sa_0 == 1) || (sa_0 == 0)) if ((CudaNdarray_HOST_DIMS(A)[0] <= 1)
|| ((CudaNdarray_HOST_STRIDES(A)[0] == 1)
&& (CudaNdarray_HOST_STRIDES(A)[1] > 0)))
{ {
cublasSgemv('N', cublasSgemv('N',
CudaNdarray_HOST_DIMS(A)[0], CudaNdarray_HOST_DIMS(A)[1], CudaNdarray_HOST_DIMS(A)[0], CudaNdarray_HOST_DIMS(A)[1],
...@@ -3166,7 +3173,9 @@ int CudaNdarray_sgemv(float alpha, const CudaNdarray * A, const CudaNdarray * B, ...@@ -3166,7 +3173,9 @@ int CudaNdarray_sgemv(float alpha, const CudaNdarray * A, const CudaNdarray * B,
beta, beta,
CudaNdarray_DEV_DATA(C), sc_0); CudaNdarray_DEV_DATA(C), sc_0);
} }
else if ((sa_1 == 1) || (sa_1 == 0)) else if ((CudaNdarray_HOST_DIMS(A)[1] <= 1)
|| ((CudaNdarray_HOST_STRIDES(A)[1] == 1)
&& (CudaNdarray_HOST_STRIDES(A)[0] > 0)))
{ {
cublasSgemv('T', cublasSgemv('T',
CudaNdarray_HOST_DIMS(A)[1], CudaNdarray_HOST_DIMS(A)[0], CudaNdarray_HOST_DIMS(A)[1], CudaNdarray_HOST_DIMS(A)[0],
...@@ -3179,13 +3188,16 @@ int CudaNdarray_sgemv(float alpha, const CudaNdarray * A, const CudaNdarray * B, ...@@ -3179,13 +3188,16 @@ int CudaNdarray_sgemv(float alpha, const CudaNdarray * A, const CudaNdarray * B,
else else
{ {
PyErr_Format(PyExc_AssertionError, PyErr_Format(PyExc_AssertionError,
"Unexpected stride pattern in gemv: (%i, %i) x %i -> %i", "Unexpected stride pattern in gemv: (%i, %i) x %i -> %i.\n"
sa_0, sa_1, sb_0, sc_0); "Shapes are: (%i, %i) x %i -> %i\n",
printf("shapes are: (%i, %i) x %i -> %i\n", CudaNdarray_HOST_STRIDES(A)[0],
CudaNdarray_HOST_DIMS(A)[0], CudaNdarray_HOST_STRIDES(A)[1],
CudaNdarray_HOST_DIMS(A)[1], CudaNdarray_HOST_STRIDES(B)[0],
CudaNdarray_HOST_DIMS(B)[0], CudaNdarray_HOST_STRIDES(C)[0],
CudaNdarray_HOST_DIMS(C)[0]); CudaNdarray_HOST_DIMS(A)[0],
CudaNdarray_HOST_DIMS(A)[1],
CudaNdarray_HOST_DIMS(B)[0],
CudaNdarray_HOST_DIMS(C)[0]);
Py_XDECREF(A); Py_XDECREF(A);
Py_XDECREF(B); Py_XDECREF(B);
return -1; return -1;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论