提交 418c8d8e authored 作者: Alexander Matyasko's avatar Alexander Matyasko

Init magma once per gpu context

上级 e40ad6b7
......@@ -375,8 +375,18 @@ class GpuMagmaBase(COp):
return [config.magma.library_path]
return []
def prepare_node(self, node, storage_map, compute_map, impl):
from skcuda.magma import magma_init
ctx = node.inputs[0].type.context
if not getattr(ctx, 'is_magma_initialized', False):
magma_init()
ctx.is_magma_initialized = True
def get_params(self, node):
return node.inputs[0].type.context
class GpuMagmaSVD(COp):
class GpuMagmaSVD(GpuMagmaBase):
"""Computes the svd of a matrix :math:`A` using magma library.
.. warning::
......@@ -418,6 +428,7 @@ class GpuMagmaSVD(COp):
context_name=ctx_name)()])
def prepare_node(self, node, storage_map, compute_map, impl):
super(GpuMagmaSVD, self).prepare_node(node, storage_map, compute_map, impl)
# Check node to prevent eventual errors with old pickled nodes.
if self.compute_uv:
A, B, C = node.outputs
......
......@@ -65,7 +65,6 @@ int APPLY_SPECIFIC(magma_cholesky)(PyGpuArrayObject *A, PyGpuArrayObject **L,
// This is early to match the exit() in the fail label.
cuda_enter(c->ctx);
magma_init();
#ifdef INPLACE
Py_XDECREF(*L);
......@@ -125,7 +124,6 @@ int APPLY_SPECIFIC(magma_cholesky)(PyGpuArrayObject *A, PyGpuArrayObject **L,
#endif
res = 0;
fail:
magma_finalize();
cuda_exit(c->ctx);
return res;
}
......@@ -47,7 +47,6 @@ int APPLY_SPECIFIC(magma_eigh)(PyGpuArrayObject *A_,
// This is early to match the exit() in the fail label.
cuda_enter(c->ctx);
magma_init();
// magma matrix eigen decomposition of a symmetric matrix
N = PyGpuArray_DIM(A, 0);
......@@ -133,7 +132,6 @@ fail:
magma_free_pinned(work_data);
if (iwork_data != NULL)
magma_free_cpu(iwork_data);
magma_finalize();
cuda_exit(c->ctx);
return res;
}
......@@ -93,7 +93,6 @@ fail:
magma_free(piv);
if (dwork != NULL)
gpudata_release(dwork);
magma_finalize();
cuda_exit(params->context->ctx);
return res;
}
......@@ -67,7 +67,6 @@ int APPLY_SPECIFIC(magma_qr)(PyGpuArrayObject *A_,
// This is early to match the exit() in the fail label.
cuda_enter(c->ctx);
magma_init();
// magma matrix qr
M = PyGpuArray_DIM(A, 0);
......@@ -148,7 +147,6 @@ fail:
magma_free_pinned(tau_data);
if (work_data != NULL)
gpudata_release(work_data);
magma_finalize();
cuda_exit(c->ctx);
return res;
}
......@@ -166,7 +166,6 @@ fail:
magma_free_pinned(work);
if (iwork != NULL)
magma_free_cpu(iwork);
magma_finalize();
cuda_exit(params->context->ctx);
return res;
}
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论