提交 9d413f18 作者: Alexander Matyasko

Use the divide-and-conquer algorithm (magma_sgesdd instead of magma_sgesvd) to compute the SVD (faster)

上级 5f1e372d
......@@ -13,8 +13,9 @@ int APPLY_SPECIFIC(magma_svd)(PyGpuArrayObject *A,
PyGpuArrayObject **VT,
#endif
PyGpuContextObject *c) {
magma_int_t *iwork = NULL, iunused[1];
magma_int_t M, N, K, ldu, ldv, M_U, N_VT, info;
magma_vec_t jobu, jobv;
magma_vec_t jobz;
size_t s_dims[1], u_dims[2], vt_dims[2];
float *a_data = NULL, *s_data = NULL, *u_data = NULL, *vt_data = NULL,
*work = NULL;
......@@ -64,14 +65,12 @@ int APPLY_SPECIFIC(magma_svd)(PyGpuArrayObject *A,
#ifdef COMPUTE_UV
#ifdef FULL_MATRICES
jobu = MagmaAllVec;
jobv = MagmaAllVec;
jobz = MagmaAllVec;
#else
jobu = MagmaSomeVec;
jobv = MagmaSomeVec;
jobz = MagmaSomeVec;
#endif
M_U = (jobu == MagmaAllVec ? M : K);
N_VT = (jobv == MagmaAllVec ? N : K);
M_U = (jobz == MagmaAllVec ? M : K);
N_VT = (jobz == MagmaAllVec ? N : K);
ldu = M;
ldv = N_VT;
......@@ -86,15 +85,14 @@ int APPLY_SPECIFIC(magma_svd)(PyGpuArrayObject *A,
goto fail;
}
#else
jobu = MagmaNoVec;
jobv = MagmaNoVec;
jobz = MagmaNoVec;
ldu = M;
ldv = N;
#endif
// query for workspace size
magma_sgesvd(jobu, jobv, M, N, NULL, M, NULL, NULL, ldu, NULL, ldv,
dummy, -1, &info);
magma_sgesdd(jobz, M, N, NULL, M, NULL, NULL, ldu, NULL, ldv,
dummy, -1, iunused, &info);
lwork = (magma_int_t) MAGMA_S_REAL(dummy[0]);
if (MAGMA_SUCCESS != magma_smalloc_pinned(&work, lwork)) {
......@@ -103,13 +101,19 @@ int APPLY_SPECIFIC(magma_svd)(PyGpuArrayObject *A,
goto fail;
}
if (MAGMA_SUCCESS != magma_imalloc_cpu(&iwork, 8*K)) {
PyErr_SetString(PyExc_RuntimeError,
"GpuMagmaSVD: failed to allocate working memory");
goto fail;
}
// compute svd
magma_sgesvd(jobu, jobv, M, N, a_data, M, s_data,
u_data, ldu, vt_data, ldv, work, lwork, &info);
magma_sgesdd(jobz, M, N, a_data, M, s_data,
u_data, ldu, vt_data, ldv, work, lwork, iwork, &info);
if (info > 0) {
PyErr_Format(
PyExc_RuntimeError,
"GpuMagmaSVD: magma_sgesvd_gpu %d superdiagonals failed to converge",
"GpuMagmaSVD: the updating process of SBDSDC did not converge (error: %s).",
info);
goto fail;
} else if (info < 0) {
......@@ -163,6 +167,8 @@ fail:
magma_free_pinned(vt_data);
if (work != NULL)
magma_free_pinned(work);
if (iwork != NULL)
magma_free_cpu(iwork);
magma_finalize();
cuda_exit(c->ctx);
return res;
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论