提交 a6253cba 作者: Frederic

Make GPU ger work when a dimension has size 0 or 1.

上级 fc65d14a
......@@ -3005,10 +3005,22 @@ int CudaNdarray_sger(float alpha, CudaNdarray * x, CudaNdarray * y, CudaNdarray
assert (CudaNdarray_HOST_STRIDES(y)[0] >= 0);
// Since Sger expects A in col-major, we invert x and y to fake this.
cublasSger(CudaNdarray_HOST_DIMS(y)[0], CudaNdarray_HOST_DIMS(x)[0], alpha,
CudaNdarray_DEV_DATA(y), CudaNdarray_HOST_STRIDES(y)[0],
CudaNdarray_DEV_DATA(x), CudaNdarray_HOST_STRIDES(x)[0],
CudaNdarray_DEV_DATA(A), CudaNdarray_HOST_DIMS(A)[1]);
int x_strides = CudaNdarray_HOST_STRIDES(x)[0];
if(x_strides == 0){
assert(CudaNdarray_HOST_DIMS(x)[0] == 1);
x_strides = 4;
}
int y_strides = CudaNdarray_HOST_STRIDES(y)[0];
if(y_strides == 0){
assert(CudaNdarray_HOST_DIMS(y)[0] == 1);
y_strides = 4;
}
if(CudaNdarray_SIZE(A))
cublasSger(CudaNdarray_HOST_DIMS(y)[0], CudaNdarray_HOST_DIMS(x)[0], alpha,
CudaNdarray_DEV_DATA(y), y_strides,
CudaNdarray_DEV_DATA(x), x_strides,
CudaNdarray_DEV_DATA(A), CudaNdarray_HOST_DIMS(A)[1]);
CNDA_THREAD_SYNC;
cudaError_t err = cudaGetLastError();
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论