Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
609593b4
提交
609593b4
authored
5月 31, 2010
作者:
Frederic Bastien
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix GpuSum pattern 01,011 and 0111 when the outer dimensions is bigger then 4096.
上级
791b8bcf
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
32 行增加
和
9 行删除
+32
-9
basic_ops.py
theano/sandbox/cuda/basic_ops.py
+11
-8
test_basic_ops.py
theano/sandbox/cuda/tests/test_basic_ops.py
+21
-1
没有找到文件。
theano/sandbox/cuda/basic_ops.py
浏览文件 @
609593b4
...
...
@@ -1062,7 +1062,7 @@ class GpuSum(Op):
"""
%
locals
()
def
c_code_cache_version
(
self
):
return
(
1
5
,)
return
(
1
6
,)
def
c_support_code_apply
(
self
,
node
,
nodename
):
...
...
@@ -1174,7 +1174,7 @@ class GpuSum(Op):
for_i2
=
"for (int i2 = threadIdx.y; i2 < d2; i2 += blockDim.y)"
for_i3
=
"for (int i3 = threadIdx.x; i3 < d3; i3 += blockDim.x)"
reducebuf
=
self
.
_k_reduce_buf
(
'Z[
blockIdx.x
* sZ0]'
)
reducebuf
=
self
.
_k_reduce_buf
(
'Z[
i0
* sZ0]'
)
param_dim
=
","
.
join
([
"const int d
%(i)
s"
%
locals
()
for
i
in
range
(
nd_in
)])
param_strides
=
","
.
join
([
"const int sA
%(i)
s"
%
locals
()
for
i
in
range
(
nd_in
)])
decl
=
self
.
_k_decl
(
node
,
nodename
)
...
...
@@ -1182,15 +1182,18 @@ class GpuSum(Op):
print
>>
sio
,
"""
%(decl)
s{
%(init)
s
%(for_i1)
s{
%(for_i2)
s{
%(for_i3)
s{
float Ai = A[i3 * sA3 + i2 * sA2 + i1 * sA1 + blockIdx.x * sA0];
mysum += Ai;
for (int i0 = blockIdx.x; i0 < d0; i0 += gridDim.x){
mysum = 0;
%(for_i1)
s{
%(for_i2)
s{
%(for_i3)
s{
float Ai = A[i3 * sA3 + i2 * sA2 + i1 * sA1 + i0 * sA0];
mysum += Ai;
}
}
}
%(reducebuf)
s
}
%(reducebuf)
s
}
"""
%
locals
()
if
self
.
reduce_mask
==
(
1
,
0
):
...
...
theano/sandbox/cuda/tests/test_basic_ops.py
浏览文件 @
609593b4
...
...
@@ -38,7 +38,27 @@ def test_sum():
((
0
,
0
),[
0
,
1
]),((
1
,
0
),[
0
,
1
]),((
5
,
4
),[
0
,
1
]),((
33
,
31
),[
0
,
1
]),((
5
,
4
),[
1
]),((
5
,
4
),[
0
]),
#need something bigger then 32 for some opt test.
((
5
,
4
,
3
),[
0
]),((
5
,
4
,
3
),[
0
,
1
]),((
5
,
4
,
3
),[
2
]),((
5
,
4
,
3
),[
1
,
2
]),((
5
,
4
,
3
),[
0
,
1
,
2
]),
((
0
,
0
,
0
,
0
),[
0
,
1
,
2
,
3
]),
((
5
,
4
,
3
,
20
),[
2
,
3
]),
((
5
,
4
,
3
,
2
),[
0
,
1
,
2
,
3
]),
((
5
,
4
,
3
,
2
),[
0
,
2
,
3
]),((
5
,
4
,
3
,
2
),[
1
,
2
,
3
])]:
((
5
,
4
,
3
,
20
),[
2
,
3
]),
((
5
,
4
,
3
,
2
),[
0
,
1
,
2
,
3
]),
((
5
,
4
,
3
,
2
),[
0
,
2
,
3
]),((
5
,
4
,
3
,
2
),[
1
,
2
,
3
]),
#test shape bigger then 4096 on each dimension to make sure that we work correctly when we don't have enought thread/block in each dimensions
((
4100
,
3
),[
0
]),((
3
,
4101
),[
0
]),
#10
((
4100
,
3
),[
1
]),((
3
,
4101
),[
1
]),
#01
((
4100
,
3
),[
0
,
1
]),((
3
,
4101
),[
0
,
1
]),
#11
((
4100
,
4
,
3
),[
0
]),((
5
,
4100
,
3
),[
0
]),((
5
,
4
,
4100
),[
0
]),
#100
#((4100,4,3),[1]),((5,4100,3),[1]),((5,4,4100),[1]),#010 ##not implemented
((
4100
,
4
,
3
),[
2
]),((
5
,
4100
,
3
),[
2
]),((
5
,
4
,
4100
),[
2
]),
#001
((
4100
,
4
,
3
),[
0
,
1
]),((
5
,
4100
,
3
),[
0
,
1
]),((
5
,
4
,
4100
),[
0
,
1
]),
#110
((
4100
,
4
,
3
),[
1
,
2
]),((
5
,
4100
,
3
),[
1
,
2
]),((
5
,
4
,
4100
),[
1
,
2
]),
#011
#((4100,4,3),[0,2]),((5,4100,3),[0,2]),((5,4,4100),[0,2]),#101 ##not implemented
((
4100
,
4
,
3
),[
0
,
1
,
2
]),((
5
,
4100
,
3
),[
0
,
1
,
2
]),((
5
,
4
,
4100
),[
0
,
1
,
2
]),
#111
((
4100
,
4
,
3
,
2
),[
2
,
3
]),((
4
,
4100
,
3
,
2
),[
2
,
3
]),((
4
,
3
,
4100
,
2
),[
2
,
3
]),((
4
,
3
,
2
,
4100
),[
2
,
3
]),
#0011
((
4100
,
4
,
3
,
2
),[
0
,
2
,
3
]),((
4
,
4100
,
3
,
2
),[
0
,
2
,
3
]),((
4
,
3
,
4100
,
2
),[
0
,
2
,
3
]),
#((4,3,2,4100),[0,2,3]),#1011
((
4100
,
4
,
3
,
2
),[
1
,
2
,
3
]),((
4
,
4100
,
3
,
2
),[
1
,
2
,
3
]),((
4
,
3
,
4100
,
2
),[
1
,
2
,
3
]),((
4
,
3
,
2
,
4100
),[
1
,
2
,
3
]),
#0111
((
4100
,
2
,
3
,
4
),[
0
,
1
,
2
,
3
]),((
2
,
4100
,
3
,
4
),[
0
,
1
,
2
,
3
]),((
2
,
3
,
4100
,
4
),[
0
,
1
,
2
,
3
]),((
2
,
3
,
4
,
4100
),[
0
,
1
,
2
,
3
]),
#1111
]:
a
=
tensor
.
TensorType
(
'float32'
,(
False
,)
*
len
(
shape
))()
b
=
T
.
Sum
(
pattern
)(
a
)
val
=
numpy
.
random
.
rand
(
numpy
.
prod
(
shape
))
.
reshape
(
shape
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论