Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
e0cca017
提交
e0cca017
authored
9月 20, 2012
作者:
Ian Goodfellow
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
made GpuCAReduce use scalar_op.c_code instead of cuda_assign_reduce
上级
9411f9e8
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
53 行增加
和
24 行删除
+53
-24
basic_ops.py
theano/sandbox/cuda/basic_ops.py
+53
-24
没有找到文件。
theano/sandbox/cuda/basic_ops.py
浏览文件 @
e0cca017
...
@@ -523,6 +523,9 @@ class GpuCAReduce(GpuOp):
...
@@ -523,6 +523,9 @@ class GpuCAReduce(GpuOp):
def
__init__
(
self
,
reduce_mask
,
scalar_op
):
def
__init__
(
self
,
reduce_mask
,
scalar_op
):
self
.
reduce_mask
=
tuple
(
reduce_mask
)
self
.
reduce_mask
=
tuple
(
reduce_mask
)
self
.
scalar_op
=
scalar_op
self
.
scalar_op
=
scalar_op
# used to make sure that calls to scalar op
# have unique name arguments
self
.
_n_scalar_op_calls
=
0
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
)
and
return
(
type
(
self
)
==
type
(
other
)
and
...
@@ -843,19 +846,38 @@ class GpuCAReduce(GpuOp):
...
@@ -843,19 +846,38 @@ class GpuCAReduce(GpuOp):
"""
"""
def
_assign_reduce
(
self
,
left
,
right
):
def
_assign_reduce
(
self
,
node
,
name
,
left
,
right
,
sub
):
"""
"""
node: the node argument to this op's c_code
name: the name argument to this op's c_code
left: a C code string identifying an lvalue
left: a C code string identifying an lvalue
right: a C code string identifying an expression
right: a C code string identifying an expression
sub: the sub argument to this op's c_code
returns C code to reduce left and right, assigning the
returns C code to reduce left and right, assigning the
result to left."""
result to left."""
return
self
.
scalar_op
.
cuda_assign_reduce
(
left
,
right
)
x
,
=
node
.
inputs
def
_k_reduce_buf
(
self
,
z_pos
):
dtype
=
x
.
dtype
dummy_left
=
scal
.
Scalar
(
dtype
=
dtype
)()
dummy_right
=
scal
.
Scalar
(
dtype
=
dtype
)()
dummy_node
=
self
.
scalar_op
.
make_node
(
dummy_left
,
dummy_right
)
dummy_name
=
name
+
'_scalar_op'
+
str
(
self
.
_n_scalar_op_calls
)
self
.
_n_scalar_op_calls
+=
1
return
self
.
scalar_op
.
c_code
(
node
,
name
,
(
left
,
right
),
(
left
,
),
sub
)
def
_k_reduce_buf
(
self
,
z_pos
,
node
,
name
,
sub
):
"""
"""
WRITEME
WRITEME
node, name, sub: these should be passed through from the original
call to c_code
"""
"""
# This code (the code in new_version) is currently ignored.
# This code (the code in new_version) is currently ignored.
...
@@ -872,7 +894,7 @@ class GpuCAReduce(GpuOp):
...
@@ -872,7 +894,7 @@ class GpuCAReduce(GpuOp):
{
{
int idx = threadNum - (threadCount >> 1) * 2;"""
int idx = threadNum - (threadCount >> 1) * 2;"""
new_version
+=
self
.
_assign_reduce
(
'buf[idx]'
,
'buf[threadNum]'
)
new_version
+=
self
.
_assign_reduce
(
node
,
name
,
'buf[idx]'
,
'buf[threadNum]'
,
sub
)
new_version
+=
"""
new_version
+=
"""
}
}
...
@@ -891,7 +913,7 @@ class GpuCAReduce(GpuOp):
...
@@ -891,7 +913,7 @@ class GpuCAReduce(GpuOp):
float temp = buf[threadNum + halfPoint];
float temp = buf[threadNum + halfPoint];
"""
"""
new_version
+=
self
.
_assign_reduce
(
'buf[threadNum]'
,
'temp'
)
new_version
+=
self
.
_assign_reduce
(
node
,
name
,
'buf[threadNum]'
,
'temp'
,
sub
)
new_version
+=
"""
new_version
+=
"""
}
}
...
@@ -921,7 +943,7 @@ class GpuCAReduce(GpuOp):
...
@@ -921,7 +943,7 @@ class GpuCAReduce(GpuOp):
for (int i = threadNum + warpSize; i < threadCount; i += warpSize)
for (int i = threadNum + warpSize; i < threadCount; i += warpSize)
{
{
"""
"""
current_version
+=
self
.
_assign_reduce
(
'myresult'
,
'buf[i]'
)
+
"""
current_version
+=
self
.
_assign_reduce
(
node
,
name
,
'myresult'
,
'buf[i]'
,
sub
)
+
"""
}
}
buf[threadNum] = myresult;
buf[threadNum] = myresult;
/*Comment this optimization as it don't work on Fermi GPU.
/*Comment this optimization as it don't work on Fermi GPU.
...
@@ -930,8 +952,8 @@ class GpuCAReduce(GpuOp):
...
@@ -930,8 +952,8 @@ class GpuCAReduce(GpuOp):
if(threadCount >32)
if(threadCount >32)
{"""
{"""
for
num
in
[
16
,
8
,
4
,
2
,
1
]:
for
num
in
[
16
,
8
,
4
,
2
,
1
]:
current_version
+=
self
.
_assign_reduce
(
'buf[threadNum]'
,
current_version
+=
self
.
_assign_reduce
(
node
,
name
,
'buf[threadNum]'
,
'buf[threadNum+
%
d]'
)
%
num
'buf[threadNum+
%
d]'
%
num
,
sub
)
current_version
+=
"""
current_version
+=
"""
if (threadNum == 0)
if (threadNum == 0)
{
{
...
@@ -945,9 +967,9 @@ class GpuCAReduce(GpuOp):
...
@@ -945,9 +967,9 @@ class GpuCAReduce(GpuOp):
//reduce so that threadNum 0 has the reduction of everything
//reduce so that threadNum 0 has the reduction of everything
"""
"""
for
num
in
[
16
,
8
,
4
,
2
,
1
]:
for
num
in
[
16
,
8
,
4
,
2
,
1
]:
this_if
=
"if (threadNum +
%
d < threadCount) "
+
\
this_if
=
"if (threadNum +
%
d < threadCount) "
%
num
+
\
self
.
_assign_reduce
(
'buf[threadNum]'
,
'buf[threadNum+
%
d]'
)
self
.
_assign_reduce
(
node
,
name
,
'buf[threadNum]'
,
'buf[threadNum+
%
d]'
%
num
,
sub
)
current_version
+=
this_if
%
(
num
,
num
)
current_version
+=
this_if
current_version
+=
"""
current_version
+=
"""
if (threadNum == 0)
if (threadNum == 0)
{
{
...
@@ -1528,7 +1550,7 @@ class GpuCAReduce(GpuOp):
...
@@ -1528,7 +1550,7 @@ class GpuCAReduce(GpuOp):
self
.
_op_guard
()
self
.
_op_guard
()
#this kernel is ok for up to a few thousand elements, but
#this kernel is ok for up to a few thousand elements, but
# it only runs on ONE multiprocessor
# it only runs on ONE multiprocessor
reducebuf
=
self
.
_k_reduce_buf
(
'Z[0]'
)
reducebuf
=
self
.
_k_reduce_buf
(
'Z[0]'
,
node
,
nodename
,
sub
=
{}
)
print
>>
sio
,
"""
print
>>
sio
,
"""
static __global__ void kernel_reduce_ccontig_
%(nodename)
s(
static __global__ void kernel_reduce_ccontig_
%(nodename)
s(
const unsigned int d0,
const unsigned int d0,
...
@@ -1556,7 +1578,7 @@ class GpuCAReduce(GpuOp):
...
@@ -1556,7 +1578,7 @@ class GpuCAReduce(GpuOp):
self
.
_op_guard
()
self
.
_op_guard
()
#this kernel is ok for up to a few thousand elements, but
#this kernel is ok for up to a few thousand elements, but
# it only runs on ONE multiprocessor
# it only runs on ONE multiprocessor
reducebuf
=
self
.
_k_reduce_buf
(
'Z[0]'
)
reducebuf
=
self
.
_k_reduce_buf
(
'Z[0]'
,
node
,
nodename
,
sub
=
{}
)
print
>>
sio
,
"""
print
>>
sio
,
"""
static __global__ void kernel_reduce_1_
%(nodename)
s(
static __global__ void kernel_reduce_1_
%(nodename)
s(
const unsigned int d0,
const unsigned int d0,
...
@@ -1585,7 +1607,7 @@ class GpuCAReduce(GpuOp):
...
@@ -1585,7 +1607,7 @@ class GpuCAReduce(GpuOp):
self
.
_op_guard
()
self
.
_op_guard
()
#this kernel is ok for up to a few thousand elements, but
#this kernel is ok for up to a few thousand elements, but
# it only runs on ONE multiprocessor
# it only runs on ONE multiprocessor
reducebuf
=
self
.
_k_reduce_buf
(
'Z[0]'
)
reducebuf
=
self
.
_k_reduce_buf
(
'Z[0]'
,
node
,
nodename
,
sub
=
{}
)
print
>>
sio
,
"""
print
>>
sio
,
"""
static __global__ void kernel_reduce_11_
%(nodename)
s(
static __global__ void kernel_reduce_11_
%(nodename)
s(
const int d0,
const int d0,
...
@@ -1656,7 +1678,7 @@ class GpuCAReduce(GpuOp):
...
@@ -1656,7 +1678,7 @@ class GpuCAReduce(GpuOp):
first_i3
=
'threadIdx.x'
first_i3
=
'threadIdx.x'
sA3
=
'sA3'
sA3
=
'sA3'
reducebuf
=
self
.
_k_reduce_buf
(
'Z[i0 * sZ0]'
)
reducebuf
=
self
.
_k_reduce_buf
(
'Z[i0 * sZ0]'
,
node
,
nodename
,
sub
=
{}
)
param_dim
=
","
.
join
([
"const int d
%(i)
s"
%
locals
()
param_dim
=
","
.
join
([
"const int d
%(i)
s"
%
locals
()
for
i
in
xrange
(
nd_in
)])
for
i
in
xrange
(
nd_in
)])
param_strides
=
","
.
join
([
"const int sA
%(i)
s"
%
locals
()
param_strides
=
","
.
join
([
"const int sA
%(i)
s"
%
locals
()
...
@@ -1730,7 +1752,7 @@ class GpuCAReduce(GpuOp):
...
@@ -1730,7 +1752,7 @@ class GpuCAReduce(GpuOp):
#TODO: This kernel is pretty inefficient in terms of reading, because if A is
#TODO: This kernel is pretty inefficient in terms of reading, because if A is
# c_contiguous (typical case) then each warp is accessing non-contigous
# c_contiguous (typical case) then each warp is accessing non-contigous
# memory (a segment of a column).
# memory (a segment of a column).
reducebuf
=
self
.
_k_reduce_buf
(
'Z[i0 * sZ0 + i2*sZ1]'
)
reducebuf
=
self
.
_k_reduce_buf
(
'Z[i0 * sZ0 + i2*sZ1]'
,
node
,
nodename
,
sub
=
{}
)
print
>>
sio
,
"""
print
>>
sio
,
"""
static __global__ void kernel_reduce_010_
%(nodename)
s(
static __global__ void kernel_reduce_010_
%(nodename)
s(
const int d0,
const int d0,
...
@@ -1856,7 +1878,7 @@ class GpuCAReduce(GpuOp):
...
@@ -1856,7 +1878,7 @@ class GpuCAReduce(GpuOp):
#TODO: This kernel is pretty inefficient in terms of reading, because if A is
#TODO: This kernel is pretty inefficient in terms of reading, because if A is
# c_contiguous (typical case) then each warp is accessing non-contigous
# c_contiguous (typical case) then each warp is accessing non-contigous
# memory (a segment of a column).
# memory (a segment of a column).
reducebuf
=
self
.
_k_reduce_buf
(
'Z[blockIdx.x * sZ0]'
)
reducebuf
=
self
.
_k_reduce_buf
(
'Z[blockIdx.x * sZ0]'
,
node
,
nodename
,
sub
=
{}
)
print
>>
sio
,
"""
print
>>
sio
,
"""
static __global__ void kernel_reduce_110_
%(nodename)
s(
static __global__ void kernel_reduce_110_
%(nodename)
s(
const int d0,
const int d0,
...
@@ -1892,7 +1914,8 @@ class GpuCAReduce(GpuOp):
...
@@ -1892,7 +1914,8 @@ class GpuCAReduce(GpuOp):
"""
%
locals
()
"""
%
locals
()
if
self
.
reduce_mask
==
(
1
,
0
,
0
):
if
self
.
reduce_mask
==
(
1
,
0
,
0
):
self
.
_op_guard
()
self
.
_op_guard
()
reducebuf
=
self
.
_k_reduce_buf
(
'Z[i1 * sZ0 + i2 * sZ1]'
)
reducebuf
=
self
.
_k_reduce_buf
(
'Z[i1 * sZ0 + i2 * sZ1]'
,
node
,
nodename
,
sub
=
{})
decl
=
self
.
_k_decl
(
node
,
nodename
)
decl
=
self
.
_k_decl
(
node
,
nodename
)
init
=
self
.
_k_init
(
node
,
nodename
)
init
=
self
.
_k_init
(
node
,
nodename
)
print
>>
sio
,
"""
print
>>
sio
,
"""
...
@@ -1915,7 +1938,8 @@ class GpuCAReduce(GpuOp):
...
@@ -1915,7 +1938,8 @@ class GpuCAReduce(GpuOp):
"""
%
locals
()
"""
%
locals
()
if
self
.
reduce_mask
==
(
1
,
1
,
1
):
if
self
.
reduce_mask
==
(
1
,
1
,
1
):
self
.
_op_guard
()
self
.
_op_guard
()
reducebuf
=
self
.
_k_reduce_buf
(
'Z[0]'
)
reducebuf
=
self
.
_k_reduce_buf
(
'Z[0]'
,
node
,
nodename
,
sub
=
{})
decl
=
self
.
_k_decl
(
node
,
nodename
)
decl
=
self
.
_k_decl
(
node
,
nodename
)
init
=
self
.
_k_init
(
node
,
nodename
)
init
=
self
.
_k_init
(
node
,
nodename
)
print
>>
sio
,
"""
print
>>
sio
,
"""
...
@@ -1940,7 +1964,8 @@ class GpuCAReduce(GpuOp):
...
@@ -1940,7 +1964,8 @@ class GpuCAReduce(GpuOp):
self
.
_op_guard
()
self
.
_op_guard
()
# this kernel uses one block for each row,
# this kernel uses one block for each row,
# threads per block for each element per row.
# threads per block for each element per row.
reducebuf
=
self
.
_k_reduce_buf
(
'Z[i0 * sZ0 + i1 * sZ1]'
)
reducebuf
=
self
.
_k_reduce_buf
(
'Z[i0 * sZ0 + i1 * sZ1]'
,
node
,
nodename
,
sub
=
{})
print
>>
sio
,
"""
print
>>
sio
,
"""
static __global__ void kernel_reduce_001_
%(nodename)
s(
static __global__ void kernel_reduce_001_
%(nodename)
s(
const int d0,
const int d0,
...
@@ -1977,7 +2002,8 @@ class GpuCAReduce(GpuOp):
...
@@ -1977,7 +2002,8 @@ class GpuCAReduce(GpuOp):
self
.
_op_guard
()
self
.
_op_guard
()
# this kernel uses one block for each row,
# this kernel uses one block for each row,
# threads per block for each element per row.
# threads per block for each element per row.
reducebuf
=
self
.
_k_reduce_buf
(
'Z[i0 * sZ0 + i1 * sZ1]'
)
reducebuf
=
self
.
_k_reduce_buf
(
'Z[i0 * sZ0 + i1 * sZ1]'
,
node
,
nodename
,
sub
=
{})
decl
=
self
.
_k_decl
(
node
,
nodename
)
decl
=
self
.
_k_decl
(
node
,
nodename
)
init
=
self
.
_k_init
(
node
,
nodename
)
init
=
self
.
_k_init
(
node
,
nodename
)
print
>>
sio
,
"""
print
>>
sio
,
"""
...
@@ -2006,7 +2032,8 @@ class GpuCAReduce(GpuOp):
...
@@ -2006,7 +2032,8 @@ class GpuCAReduce(GpuOp):
self
.
_op_guard
()
self
.
_op_guard
()
# this kernel uses one block for each row,
# this kernel uses one block for each row,
# threads per block for each element per row.
# threads per block for each element per row.
reducebuf
=
self
.
_k_reduce_buf
(
'Z[i0 * sZ0 + i2 * sZ1]'
)
reducebuf
=
self
.
_k_reduce_buf
(
'Z[i0 * sZ0 + i2 * sZ1]'
,
node
,
nodename
,
sub
=
{})
decl
=
self
.
_k_decl
(
node
,
nodename
)
decl
=
self
.
_k_decl
(
node
,
nodename
)
init
=
self
.
_k_init
(
node
,
nodename
)
init
=
self
.
_k_init
(
node
,
nodename
)
print
>>
sio
,
"""
print
>>
sio
,
"""
...
@@ -2033,7 +2060,8 @@ class GpuCAReduce(GpuOp):
...
@@ -2033,7 +2060,8 @@ class GpuCAReduce(GpuOp):
"""
%
locals
()
"""
%
locals
()
if
self
.
reduce_mask
==
(
1
,
1
,
1
,
1
):
if
self
.
reduce_mask
==
(
1
,
1
,
1
,
1
):
self
.
_op_guard
()
self
.
_op_guard
()
reducebuf
=
self
.
_k_reduce_buf
(
'Z[0]'
)
reducebuf
=
self
.
_k_reduce_buf
(
'Z[0]'
,
node
,
nodename
,
sub
=
{})
decl
=
self
.
_k_decl
(
node
,
nodename
)
decl
=
self
.
_k_decl
(
node
,
nodename
)
init
=
self
.
_k_init
(
node
,
nodename
)
init
=
self
.
_k_init
(
node
,
nodename
)
print
>>
sio
,
"""
print
>>
sio
,
"""
...
@@ -2057,7 +2085,8 @@ class GpuCAReduce(GpuOp):
...
@@ -2057,7 +2085,8 @@ class GpuCAReduce(GpuOp):
"""
%
locals
()
"""
%
locals
()
if
self
.
reduce_mask
==
(
1
,
0
,
1
,
1
):
if
self
.
reduce_mask
==
(
1
,
0
,
1
,
1
):
self
.
_op_guard
()
self
.
_op_guard
()
reducebuf
=
self
.
_k_reduce_buf
(
'Z[blockIdx.x*sZ0]'
)
reducebuf
=
self
.
_k_reduce_buf
(
'Z[blockIdx.x*sZ0]'
,
node
,
nodename
,
sub
=
{})
print
>>
sio
,
"""
print
>>
sio
,
"""
static __global__ void kernel_reduce_1011_
%(nodename)
s(
static __global__ void kernel_reduce_1011_
%(nodename)
s(
const unsigned int d0,
const unsigned int d0,
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论