Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
a59d9a54
提交
a59d9a54
authored
1月 08, 2010
作者:
Frederic Bastien
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
simplified code.
上级
adb994dd
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
6 行增加
和
75 行删除
+6
-75
basic_ops.py
theano/sandbox/cuda/basic_ops.py
+6
-75
没有找到文件。
theano/sandbox/cuda/basic_ops.py
浏览文件 @
a59d9a54
...
...
@@ -1019,6 +1019,7 @@ class GpuSum(Op):
#TODO: This kernel is pretty inefficient in terms of reading, because if A is
# c_contiguous (typical case) then each warp is accessing non-contigous
# memory (a segment of a column).
reducebuf
=
self
.
_k_reduce_buf
(
'Z[blockIdx.x * sZ0]'
)
print
>>
sio
,
"""
static __global__ void kernel_reduce_sum_10_
%(nodename)
s(
const int d0,
...
...
@@ -1041,31 +1042,7 @@ class GpuSum(Op):
float Ai = A[i0 * sA0 + blockIdx.x * sA1];
mysum += Ai;
}
buf[threadNum] = mysum;
__syncthreads();
// rest of function is handled by one warp
if (threadNum < warpSize)
{
for (int i = threadNum + warpSize; i < threadCount; i += warpSize)
{
mysum += buf[i];
}
buf[threadNum] = mysum;
if (threadNum < 16)
{
//reduce so that threadNum 0 has the sum of everything
if(threadNum + 16 < threadCount) buf[threadNum] += buf[threadNum+16];
if(threadNum + 8 < threadCount) buf[threadNum] += buf[threadNum+8];
if(threadNum + 4 < threadCount) buf[threadNum] += buf[threadNum+4];
if(threadNum + 2 < threadCount) buf[threadNum] += buf[threadNum+2];
if(threadNum + 1 < threadCount) buf[threadNum] += buf[threadNum+1];
if (threadNum == 0)
{
Z[blockIdx.x * sZ0] = buf[0];
}
}
}
%(reducebuf)
s
}
"""
%
locals
()
if
self
.
reduce_mask
==
(
1
,
1
,
0
):
...
...
@@ -1155,6 +1132,7 @@ class GpuSum(Op):
if
self
.
reduce_mask
==
(
0
,
0
,
1
):
# this kernel uses one block for each row,
# threads per block for each element per row.
reducebuf
=
self
.
_k_reduce_buf
(
'Z[i0 * sZ0 + i1 * sZ1]'
)
print
>>
sio
,
"""
static __global__ void kernel_reduce_sum_001_
%(nodename)
s(
const int d0,
...
...
@@ -1181,31 +1159,7 @@ class GpuSum(Op):
{
mysum += A[i0 * sA0 + i1 * sA1 + i2 * sA2];
}
buf[threadNum] = mysum;
__syncthreads();
// rest of function is handled by one warp
if (threadNum < warpSize)
{
for (int i = threadNum + warpSize; i < threadCount; i += warpSize)
{
mysum += buf[i];
}
buf[threadNum] = mysum;
if (threadNum < 16)
{
//reduce so that threadNum 0 has the sum of everything
if(threadNum + 16 < threadCount) buf[threadNum] += buf[threadNum+16];
if(threadNum + 8 < threadCount) buf[threadNum] += buf[threadNum+8];
if(threadNum + 4 < threadCount) buf[threadNum] += buf[threadNum+4];
if(threadNum + 2 < threadCount) buf[threadNum] += buf[threadNum+2];
if(threadNum + 1 < threadCount) buf[threadNum] += buf[threadNum+1];
if (threadNum == 0)
{
Z[i0 * sZ0 + i1 * sZ1] = buf[0];
}
}
}
%(reducebuf)
s
}
}
}
...
...
@@ -1234,6 +1188,7 @@ class GpuSum(Op):
}
"""
%
locals
()
if
self
.
reduce_mask
==
(
1
,
0
,
1
,
1
):
reducebuf
=
self
.
_k_reduce_buf
(
'Z[blockIdx.x*sZ0]'
)
print
>>
sio
,
"""
static __global__ void kernel_reduce_sum_1011_
%(nodename)
s(
const unsigned int d0,
...
...
@@ -1264,31 +1219,7 @@ class GpuSum(Op):
}
}
}
buf[threadNum] = mysum;
__syncthreads();
// rest of function is handled by one warp
if (threadNum < warpSize)
{
for (int i = threadNum + warpSize; i < threadCount; i += warpSize)
{
mysum += buf[i];
}
buf[threadNum] = mysum;
if (threadNum < 16)
{
//reduce so that threadNum 0 has the sum of everything
if(threadNum + 16 < threadCount) buf[threadNum] += buf[threadNum+16];
if(threadNum + 8 < threadCount) buf[threadNum] += buf[threadNum+8];
if(threadNum + 4 < threadCount) buf[threadNum] += buf[threadNum+4];
if(threadNum + 2 < threadCount) buf[threadNum] += buf[threadNum+2];
if(threadNum + 1 < threadCount) buf[threadNum] += buf[threadNum+1];
if (threadNum == 0)
{
Z[blockIdx.x*sZ0] = buf[0];
}
}
}
%(reducebuf)
s
}
"""
%
locals
()
return
sio
.
getvalue
()
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论