提交 f6700ff8 作者: Pascal Lamblin

Fix the case in gpuger where the matrix A is a row or column vector

上级 8485ebd0
...@@ -3255,6 +3255,12 @@ int CudaNdarray_sger(float alpha, const CudaNdarray * x, const CudaNdarray * y, ...@@ -3255,6 +3255,12 @@ int CudaNdarray_sger(float alpha, const CudaNdarray * x, const CudaNdarray * y,
y_strides = CudaNdarray_HOST_STRIDES(y_)[0]; y_strides = CudaNdarray_HOST_STRIDES(y_)[0];
} }
// Create appropriate strides if A is a row or column vector
int sa_0 = (CudaNdarray_HOST_DIMS(A)[0] > 1) ? CudaNdarray_HOST_STRIDES(A)[0]
: CudaNdarray_HOST_DIMS(A)[1];
int sa_1 = (CudaNdarray_HOST_DIMS(A)[1] > 1) ? CudaNdarray_HOST_STRIDES(A)[1]
: CudaNdarray_HOST_DIMS(A)[0];
if(CudaNdarray_SIZE(A)){ if(CudaNdarray_SIZE(A)){
// If A is in col-major // If A is in col-major
if ((CudaNdarray_HOST_DIMS(A)[0] <= 1) if ((CudaNdarray_HOST_DIMS(A)[0] <= 1)
...@@ -3264,7 +3270,7 @@ int CudaNdarray_sger(float alpha, const CudaNdarray * x, const CudaNdarray * y, ...@@ -3264,7 +3270,7 @@ int CudaNdarray_sger(float alpha, const CudaNdarray * x, const CudaNdarray * y,
cublasSger(CudaNdarray_HOST_DIMS(x)[0], CudaNdarray_HOST_DIMS(y)[0], alpha, cublasSger(CudaNdarray_HOST_DIMS(x)[0], CudaNdarray_HOST_DIMS(y)[0], alpha,
CudaNdarray_DEV_DATA(x_), x_strides, CudaNdarray_DEV_DATA(x_), x_strides,
CudaNdarray_DEV_DATA(y_), y_strides, CudaNdarray_DEV_DATA(y_), y_strides,
CudaNdarray_DEV_DATA(A), CudaNdarray_HOST_STRIDES(A)[1]); CudaNdarray_DEV_DATA(A), sa_1);
} }
// Since Sger expects A in col-major, we invert x and y to fake this. // Since Sger expects A in col-major, we invert x and y to fake this.
else if ((CudaNdarray_HOST_DIMS(A)[1] <= 1) else if ((CudaNdarray_HOST_DIMS(A)[1] <= 1)
...@@ -3274,7 +3280,7 @@ int CudaNdarray_sger(float alpha, const CudaNdarray * x, const CudaNdarray * y, ...@@ -3274,7 +3280,7 @@ int CudaNdarray_sger(float alpha, const CudaNdarray * x, const CudaNdarray * y,
cublasSger(CudaNdarray_HOST_DIMS(y)[0], CudaNdarray_HOST_DIMS(x)[0], alpha, cublasSger(CudaNdarray_HOST_DIMS(y)[0], CudaNdarray_HOST_DIMS(x)[0], alpha,
CudaNdarray_DEV_DATA(y_), y_strides, CudaNdarray_DEV_DATA(y_), y_strides,
CudaNdarray_DEV_DATA(x_), x_strides, CudaNdarray_DEV_DATA(x_), x_strides,
CudaNdarray_DEV_DATA(A), CudaNdarray_HOST_STRIDES(A)[0]); CudaNdarray_DEV_DATA(A), sa_0);
} }
// A has to be either c- or f-contiguous, with no negative strides // A has to be either c- or f-contiguous, with no negative strides
else else
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论