提交 6174405c 作者: Frederic

Make CudaNdarray_sger work with negative strides.

上级 e12965aa
...@@ -2999,29 +2999,38 @@ int CudaNdarray_sger(float alpha, CudaNdarray * x, CudaNdarray * y, CudaNdarray ...@@ -2999,29 +2999,38 @@ int CudaNdarray_sger(float alpha, CudaNdarray * x, CudaNdarray * y, CudaNdarray
PyErr_SetString(PyExc_NotImplementedError, "non-c continugous A in sger"); PyErr_SetString(PyExc_NotImplementedError, "non-c continugous A in sger");
return -1; return -1;
} }
// Same for this, be safe
assert (CudaNdarray_HOST_STRIDES(x)[0] >= 0);
assert (CudaNdarray_HOST_STRIDES(y)[0] >= 0);
// Since Sger expects A in col-major, we invert x and y to fake this. // Since Sger expects A in col-major, we invert x and y to fake this.
int x_strides = CudaNdarray_HOST_STRIDES(x)[0]; int x_strides = CudaNdarray_HOST_STRIDES(x)[0];
CudaNdarray * x_ = x;
if(x_strides == 0){ if(x_strides == 0){
assert(CudaNdarray_HOST_DIMS(x)[0] == 1); assert(CudaNdarray_HOST_DIMS(x)[0] == 1);
x_strides = 4; x_strides = 4;
} else if(x_strides < 0){
x_ = (CudaNdarray*)CudaNdarray_Copy(x);
x_strides = CudaNdarray_HOST_STRIDES(x_)[0];
} }
int y_strides = CudaNdarray_HOST_STRIDES(y)[0]; int y_strides = CudaNdarray_HOST_STRIDES(y)[0];
CudaNdarray * y_ = y;
if(y_strides == 0){ if(y_strides == 0){
assert(CudaNdarray_HOST_DIMS(y)[0] == 1); assert(CudaNdarray_HOST_DIMS(y)[0] == 1);
y_strides = 4; y_strides = 4;
} else if(y_strides < 0){
y_ = (CudaNdarray*)CudaNdarray_Copy(y);
y_strides = CudaNdarray_HOST_STRIDES(y_)[0];
} }
if(CudaNdarray_SIZE(A)) if(CudaNdarray_SIZE(A)){
cublasSger(CudaNdarray_HOST_DIMS(y)[0], CudaNdarray_HOST_DIMS(x)[0], alpha, cublasSger(CudaNdarray_HOST_DIMS(y)[0], CudaNdarray_HOST_DIMS(x)[0], alpha,
CudaNdarray_DEV_DATA(y), y_strides, CudaNdarray_DEV_DATA(y_), y_strides,
CudaNdarray_DEV_DATA(x), x_strides, CudaNdarray_DEV_DATA(x_), x_strides,
CudaNdarray_DEV_DATA(A), CudaNdarray_HOST_DIMS(A)[1]); CudaNdarray_DEV_DATA(A), CudaNdarray_HOST_DIMS(A)[1]);
}
CNDA_THREAD_SYNC; CNDA_THREAD_SYNC;
if(x_ != x)
Py_DECREF(x_);
if(y_ != y)
Py_DECREF(y_);
cudaError_t err = cudaGetLastError(); cudaError_t err = cudaGetLastError();
if (CUBLAS_STATUS_SUCCESS != err) if (CUBLAS_STATUS_SUCCESS != err)
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论