Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
e8b2f369
提交
e8b2f369
authored
5月 03, 2017
作者:
Alexander Matyasko
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add gpu magma qr decomposition
上级
1f9cc65c
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
211 行增加
和
0 行删除
+211
-0
linalg.py
theano/gpuarray/linalg.py
+55
-0
magma_qr.c
theano/gpuarray/magma_qr.c
+156
-0
没有找到文件。
theano/gpuarray/linalg.py
浏览文件 @
e8b2f369
...
@@ -581,3 +581,58 @@ class GpuMagmaCholesky(CGpuKernelBase):
...
@@ -581,3 +581,58 @@ class GpuMagmaCholesky(CGpuKernelBase):
def
infer_shape
(
self
,
node
,
shapes
):
def
infer_shape
(
self
,
node
,
shapes
):
return
[
shapes
[
0
]]
return
[
shapes
[
0
]]
class
GpuMagmaQR
(
CGpuKernelBase
):
"""Computes the qr decomposition of a matrix :math:`A` using magma
library.
Parameters
----------
complete : If `False`, returns only r.
"""
params_type
=
gpu_context_type
__props__
=
(
'complete'
,)
def
__init__
(
self
,
complete
=
True
):
self
.
complete
=
complete
COp
.
__init__
(
self
,
[
'magma_qr.c'
],
'APPLY_SPECIFIC(magma_qr)'
)
def
c_headers
(
self
):
return
[
'gpuarray/types.h'
,
'gpuarray/array.h'
,
'gpuarray/ext_cuda.h'
,
'gpuarray_helper.h'
,
'magma.h'
]
def
c_header_dirs
(
self
):
dirs
=
[
os
.
path
.
dirname
(
__file__
),
pygpu
.
get_include
()]
if
config
.
magma
.
include_path
:
dirs
.
append
(
config
.
magma
.
include_path
)
return
dirs
def
c_libraries
(
self
):
return
[
'magma'
]
def
c_lib_dirs
(
self
):
if
config
.
magma
.
library_path
:
return
[
config
.
magma
.
library_path
]
return
[]
def
make_node
(
self
,
A
):
ctx_name
=
infer_context_name
(
A
)
A
=
as_gpuarray_variable
(
A
,
ctx_name
)
A
=
gpu_contiguous
(
A
)
if
A
.
ndim
!=
2
:
raise
LinAlgError
(
"Matrix rank error"
)
if
self
.
complete
:
return
theano
.
Apply
(
self
,
[
A
],
[
A
.
type
(),
A
.
type
()])
else
:
return
theano
.
Apply
(
self
,
[
A
],
[
A
.
type
()])
def
get_params
(
self
,
node
):
return
node
.
inputs
[
0
]
.
type
.
context
def
get_op_params
(
self
):
params
=
[]
if
self
.
complete
:
params
.
append
((
'COMPLETE'
,
'1'
))
return
params
theano/gpuarray/magma_qr.c
0 → 100644
浏览文件 @
e8b2f369
#section kernels
#kernel triu_kernel : size, size, *:
KERNEL
void
triu_kernel
(
const
ga_size
nthreads
,
const
ga_size
ncols
,
GLOBAL_MEM
DTYPE_INPUT_0
*
a
)
{
// grid stride looping
for
(
ga_size
index
=
GID_0
*
LDIM_0
+
LID_0
;
index
<
nthreads
;
index
+=
LDIM_0
*
GDIM_0
)
{
const
ga_size
ix
=
index
/
ncols
;
const
ga_size
iy
=
index
%
ncols
;
if
(
ix
>
iy
)
{
a
[
index
]
=
0
.
0
;
}
}
}
#section init_code
setup_ext_cuda
();
#section support_code
static
PyGpuArrayObject
*
pygpu_narrow
(
PyGpuArrayObject
*
src
,
size_t
dim
,
size_t
size
)
{
PyGpuArrayObject
*
src_view
=
pygpu_view
(
src
,
Py_None
);
src_view
->
ga
.
dimensions
[
dim
]
=
size
;
return
pygpu_copy
(
src_view
,
GA_C_ORDER
);
}
#section support_code_struct
int
APPLY_SPECIFIC
(
magma_qr
)(
PyGpuArrayObject
*
A_
,
#ifdef COMPLETE
PyGpuArrayObject
**
Q
,
#endif
PyGpuArrayObject
**
R
,
PyGpuContextObject
*
c
)
{
PyGpuArrayObject
*
A
=
NULL
;
magma_int_t
M
,
N
,
K
,
nb
,
ldwork
;
size_t
n2
;
float
*
tau_data
=
NULL
;
gpudata
*
work_data
=
NULL
;
int
res
=
-
1
,
info
;
A
=
A_
;
if
(
A
->
ga
.
typecode
!=
GA_FLOAT
)
{
PyErr_SetString
(
PyExc_TypeError
,
"GpuMagmaQR: Unsupported data type"
);
return
-
1
;
}
if
(
!
GpuArray_IS_C_CONTIGUOUS
(
&
A
->
ga
))
{
PyErr_SetString
(
PyExc_ValueError
,
"GpuMagmaQR: requires data to be C-contiguous"
);
return
-
1
;
}
if
(
PyGpuArray_NDIM
(
A
)
!=
2
)
{
PyErr_SetString
(
PyExc_ValueError
,
"GpuMagmaQR: matrix rank error"
);
return
-
1
;
}
A
=
pygpu_copy
(
A_
,
GA_F_ORDER
);
if
(
A
==
NULL
)
{
PyErr_SetString
(
PyExc_RuntimeError
,
"GpuMagmaQR: failed to change to column-major order"
);
return
-
1
;
}
// This is early to match the exit() in the fail label.
cuda_enter
(
c
->
ctx
);
magma_init
();
// magma matrix qr
M
=
PyGpuArray_DIM
(
A
,
0
);
N
=
PyGpuArray_DIM
(
A
,
1
);
K
=
std
::
min
(
M
,
N
);
if
(
MAGMA_SUCCESS
!=
magma_smalloc_pinned
(
&
tau_data
,
N
*
N
))
{
PyErr_SetString
(
PyExc_RuntimeError
,
"GpuMagmaQR: failed to allocate working memory"
);
goto
fail
;
}
nb
=
magma_get_sgeqrf_nb
(
M
,
N
);
ldwork
=
(
2
*
K
+
magma_roundup
(
N
,
32
))
*
nb
;
work_data
=
gpudata_alloc
(
c
->
ctx
,
ldwork
*
sizeof
(
float
),
NULL
,
0
,
NULL
);
if
(
work_data
==
NULL
)
{
PyErr_SetString
(
PyExc_RuntimeError
,
"GpuMagmaQR: failed to allocate working memory"
);
goto
fail
;
}
// compute R
magma_sgeqrf2_gpu
(
M
,
N
,
(
float
*
)
PyGpuArray_DEV_DATA
(
A
),
M
,
tau_data
,
&
info
);
if
(
info
!=
0
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"GpuMagmaQR: magma_sgeqrf2_gpu argument %d has an illegal value"
,
-
info
);
goto
fail
;
}
*
R
=
pygpu_narrow
(
A
,
0
,
K
);
if
(
*
R
==
NULL
)
{
PyErr_SetString
(
PyExc_RuntimeError
,
"GpuMagmaQR: failed to narrow array"
);
goto
fail
;
}
n2
=
K
*
N
;
res
=
triu_kernel_scall
(
1
,
&
n2
,
0
,
n2
,
N
,
(
*
R
)
->
ga
.
data
);
if
(
res
!=
GA_NO_ERROR
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"GpuMagmaQR: triu_kernel %s."
,
GpuKernel_error
(
&
k_triu_kernel
,
res
));
goto
fail
;
}
#ifdef COMPLETE
// compute Q
Py_XDECREF
(
A
);
A
=
pygpu_copy
(
A_
,
GA_F_ORDER
);
if
(
A
==
NULL
)
{
PyErr_SetString
(
PyExc_RuntimeError
,
"GpuMagmaQR: failed to change to column-major order"
);
return
-
1
;
}
magma_sgeqrf_gpu
(
M
,
N
,
(
float
*
)
PyGpuArray_DEV_DATA
(
A
),
M
,
tau_data
,
*
(
float
**
)
work_data
,
&
info
);
if
(
info
!=
0
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"GpuMagmaQR: magma_sgeqrf_gpu argument %d has an illegal value"
,
-
info
);
goto
fail
;
}
magma_sorgqr_gpu
(
M
,
K
,
K
,
(
float
*
)
PyGpuArray_DEV_DATA
(
A
),
M
,
tau_data
,
*
(
float
**
)
work_data
,
nb
,
&
info
);
if
(
info
!=
0
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"GpuMagmaQR: magma_sorgqr_gpu argument %d has an illegal value"
,
-
info
);
goto
fail
;
}
*
Q
=
pygpu_narrow
(
A
,
1
,
K
);
if
(
*
Q
==
NULL
)
{
PyErr_SetString
(
PyExc_RuntimeError
,
"GpuMagmaQR: failed to narrow array"
);
goto
fail
;
}
#endif
res
=
0
;
fail:
if
(
tau_data
!=
NULL
)
magma_free_pinned
(
tau_data
);
if
(
work_data
!=
NULL
)
gpudata_release
(
work_data
);
magma_finalize
();
cuda_exit
(
c
->
ctx
);
return
res
;
}
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论