Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
6c3d8b43
提交
6c3d8b43
authored
4月 13, 2017
作者:
Frédéric Bastien
提交者:
GitHub
4月 13, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #5789 from shawntan/issue-2763
Changes to GpuEye to enable `k` offset parameter.
上级
bc897190
dd9e38e1
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
54 行增加
和
23 行删除
+54
-23
basic_ops.py
theano/gpuarray/basic_ops.py
+30
-16
test_basic_ops.py
theano/gpuarray/tests/test_basic_ops.py
+24
-7
没有找到文件。
theano/gpuarray/basic_ops.py
浏览文件 @
6c3d8b43
...
...
@@ -1583,9 +1583,7 @@ class GpuEye(GpuKernelBase, Op):
broadcastable
=
(
False
,
False
),
context_name
=
self
.
context_name
)
# k != 0 isn't implemented on the GPU yet.
assert
tensor
.
get_scalar_constant_value
(
k
)
==
0
return
Apply
(
self
,
[
n
,
m
],
[
otype
()])
return
Apply
(
self
,
[
n
,
m
,
k
],
[
otype
()])
def
infer_shape
(
self
,
node
,
in_shapes
):
out_shape
=
[
node
.
inputs
[
0
],
node
.
inputs
[
1
]]
...
...
@@ -1597,21 +1595,28 @@ class GpuEye(GpuKernelBase, Op):
def
gpu_kernels
(
self
,
node
,
name
):
code
=
"""
KERNEL void eye(GLOBAL_MEM
%(ctype)
s *a, ga_size n, ga_size m) {
ga_size nb = n < m ? n : m;
KERNEL void eye(GLOBAL_MEM
%(ctype)
s *a, ga_size n, ga_size m, ga_ssize k) {
ga_ssize coff = max(k, (ga_ssize) 0);
ga_ssize roff = -min(k, (ga_ssize) 0);
ga_size nb = (ga_size) min(n - roff, m - coff);
for (ga_size i = LID_0; i < nb; i += LDIM_0) {
a[
i*m + i
] =
%(write_a)
s(1);
a[
(i + roff)*m + i + coff
] =
%(write_a)
s(1);
}
}"""
%
dict
(
ctype
=
pygpu
.
gpuarray
.
dtype_to_ctype
(
self
.
dtype
),
name
=
name
,
write_a
=
write_w
(
self
.
dtype
))
return
[
Kernel
(
code
=
code
,
name
=
"eye"
,
params
=
[
gpuarray
.
GpuArray
,
gpuarray
.
SIZE
,
gpuarray
.
SIZE
],
params
=
[
gpuarray
.
GpuArray
,
gpuarray
.
SIZE
,
gpuarray
.
SIZE
,
gpuarray
.
SSIZE
],
flags
=
Kernel
.
get_flags
(
self
.
dtype
),
objvar
=
'k_eye_'
+
name
)]
def
c_code
(
self
,
node
,
name
,
inp
,
out
,
sub
):
n
,
m
=
inp
if
len
(
inp
)
==
2
:
n
,
m
=
inp
k
=
0
elif
len
(
inp
)
==
3
:
n
,
m
,
k
=
inp
z
,
=
out
fail
=
sub
[
'fail'
]
ctx
=
sub
[
'params'
]
...
...
@@ -1621,10 +1626,15 @@ KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m) {
s
=
"""
size_t dims[2] = {0, 0};
size_t ls, gs;
ssize_t k;
size_t col_off;
size_t row_off;
int err;
dims[0] = ((dtype_
%(n)
s*)PyArray_DATA(
%(n)
s))[0];
dims[1] = ((dtype_
%(m)
s*)PyArray_DATA(
%(m)
s))[0];
k = ((dtype_
%(k)
s*)PyArray_DATA(
%(k)
s))[0];
Py_CLEAR(
%(z)
s);
%(z)
s = pygpu_zeros(2, dims,
...
...
@@ -1637,13 +1647,17 @@ KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m) {
ls = 1;
gs = 256;
err = eye_call(1, &gs, &ls, 0,
%(z)
s->ga.data, dims[0], dims[1]);
if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError,
"gpuarray error: kEye:
%%
s. n
%%
lu, m=
%%
lu.",
GpuKernel_error(&
%(kname)
s, err),
(unsigned long)dims[0], (unsigned long)dims[1]);
%(fail)
s;
col_off = (size_t) (k > 0?k:0);
row_off = (size_t) (k < 0?-k:0);
if (row_off < dims[0] && col_off < dims[1]) {
err = eye_call(1, &gs, &ls, 0,
%(z)
s->ga.data, dims[0], dims[1], k);
if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError,
"gpuarray error: kEye:
%%
s. n
%%
lu, m=
%%
lu.",
GpuKernel_error(&
%(kname)
s, err),
(unsigned long)dims[0], (unsigned long)dims[1]);
%(fail)
s;
}
}
if(
%(sync)
d)
...
...
@@ -1653,4 +1667,4 @@ KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m) {
return
s
def
c_code_cache_version
(
self
):
return
(
6
,)
return
(
7
,)
theano/gpuarray/tests/test_basic_ops.py
浏览文件 @
6c3d8b43
...
...
@@ -392,7 +392,7 @@ def test_gpujoin_gpualloc():
def
test_gpueye
():
def
check
(
dtype
,
N
,
M_
=
None
):
def
check
(
dtype
,
N
,
M_
=
None
,
k
=
0
):
# Theano does not accept None as a tensor.
# So we must use a real value.
M
=
M_
...
...
@@ -402,13 +402,14 @@ def test_gpueye():
M
=
N
N_symb
=
T
.
iscalar
()
M_symb
=
T
.
iscalar
()
k_symb
=
np
.
asarray
(
0
)
out
=
T
.
eye
(
N_symb
,
M_symb
,
k_symb
,
dtype
=
dtype
)
f
=
theano
.
function
([
N_symb
,
M_symb
],
T
.
stack
(
out
)
,
k_symb
=
T
.
iscalar
(
)
out
=
T
.
eye
(
N_symb
,
M_symb
,
k_symb
,
dtype
=
dtype
)
+
np
.
array
(
1
)
.
astype
(
dtype
)
f
=
theano
.
function
([
N_symb
,
M_symb
,
k_symb
],
out
,
mode
=
mode_with_gpu
)
result
=
np
.
asarray
(
f
(
N
,
M
))
assert
np
.
allclose
(
result
,
np
.
eye
(
N
,
M_
,
dtype
=
dtype
))
result
=
np
.
asarray
(
f
(
N
,
M
,
k
))
-
np
.
array
(
1
)
.
astype
(
dtype
)
assert
np
.
allclose
(
result
,
np
.
eye
(
N
,
M_
,
k
,
dtype
=
dtype
))
assert
result
.
dtype
==
np
.
dtype
(
dtype
)
assert
any
([
isinstance
(
node
.
op
,
GpuEye
)
for
node
in
f
.
maker
.
fgraph
.
toposort
()])
...
...
@@ -418,6 +419,22 @@ def test_gpueye():
# M != N, k = 0
yield
check
,
dtype
,
3
,
5
yield
check
,
dtype
,
5
,
3
# N == M, k != 0
yield
check
,
dtype
,
3
,
3
,
1
yield
check
,
dtype
,
3
,
3
,
-
1
# N < M, k != 0
yield
check
,
dtype
,
3
,
5
,
1
yield
check
,
dtype
,
3
,
5
,
-
1
# N > M, k != 0
yield
check
,
dtype
,
5
,
3
,
1
yield
check
,
dtype
,
5
,
3
,
-
1
# k > M, -k > N, k > M, k > N
yield
check
,
dtype
,
5
,
3
,
3
yield
check
,
dtype
,
3
,
5
,
3
yield
check
,
dtype
,
5
,
3
,
-
3
yield
check
,
dtype
,
3
,
5
,
-
3
yield
check
,
dtype
,
5
,
3
,
6
yield
check
,
dtype
,
3
,
5
,
-
6
def
test_hostfromgpu_shape_i
():
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论