Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
c3e013da
提交
c3e013da
authored
3月 10, 2010
作者:
Frederic Bastien
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
added faster GpuSum case when we sum on all dimensions on a ccontiguous tensor.
上级
b2cccdd2
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
96 行增加
和
2 行删除
+96
-2
basic_ops.py
theano/sandbox/cuda/basic_ops.py
+73
-2
test_basic_ops.py
theano/sandbox/cuda/tests/test_basic_ops.py
+23
-0
没有找到文件。
theano/sandbox/cuda/basic_ops.py
浏览文件 @
c3e013da
...
...
@@ -473,7 +473,20 @@ class GpuSum(Op):
#
# Now perform the reduction
#
getattr
(
self
,
'c_code_reduce_
%
s'
%
(
''
.
join
(
str
(
i
)
for
i
in
self
.
reduce_mask
)))(
sio
,
node
,
name
,
x
,
z
,
fail
)
if
all
(
i
==
1
for
i
in
self
.
reduce_mask
):
#check if the tensor is ccontiguous, if true, use the c_c0de_reduce_ccontig code.
#TODO: check if we are ccontiguous when we un-dimshuffle
#TODO: if only some dims are ccontiguous, call version with less dims.
print
>>
sio
,
'if(CudaNdarray_is_c_contiguous(
%(x)
s)){'
%
locals
()
self
.
c_code_reduce_ccontig
(
sio
,
node
,
name
,
x
,
z
,
fail
)
print
>>
sio
,
"}else{"
getattr
(
self
,
'c_code_reduce_
%
s'
%
(
''
.
join
(
str
(
i
)
for
i
in
self
.
reduce_mask
)))(
sio
,
node
,
name
,
x
,
z
,
fail
)
print
>>
sio
,
"}"
else
:
getattr
(
self
,
'c_code_reduce_
%
s'
%
(
''
.
join
(
str
(
i
)
for
i
in
self
.
reduce_mask
)))(
sio
,
node
,
name
,
x
,
z
,
fail
)
return
sio
.
getvalue
()
...
...
@@ -639,6 +652,37 @@ class GpuSum(Op):
}
"""
%
locals
()
def
c_code_reduce_ccontig
(
self
,
sio
,
node
,
name
,
x
,
z
,
fail
):
print
>>
sio
,
"""
{
int verbose = 0;
dim3 n_threads(
std::min(CudaNdarray_SIZE(
%(x)
s),
NUM_VECTOR_OP_THREADS_PER_BLOCK));
dim3 n_blocks(1);
if (verbose) printf("running kernel_reduce_sum_ccontig_
%(name)
s
\\
n");
int n_shared = sizeof(float) * n_threads.x * n_threads.y * n_threads.z;
kernel_reduce_sum_ccontig_
%(name)
s<<<n_blocks, n_threads, n_shared>>>(
CudaNdarray_SIZE(
%(x)
s),//need SIZE here as we use this kernel for ccontiguous tensor
CudaNdarray_DEV_DATA(
%(x)
s),
CudaNdarray_DEV_DATA(
%(z)
s));
CNDA_THREAD_SYNC;
cudaError_t sts = cudaGetLastError();
if (cudaSuccess != sts)
{
PyErr_Format(PyExc_RuntimeError, "Cuda error:
%%
s:
%%
s. (grid:
%%
i x
%%
i; block:
%%
i x
%%
i x
%%
i)
\\
n",
"kernel_reduce_sum_ccontig_
%(name)
s",
cudaGetErrorString(sts),
n_blocks.x,
n_blocks.y,
n_threads.x,
n_threads.y,
n_threads.z);
%(fail)
s;
}
}
"""
%
locals
()
def
c_code_reduce_1
(
self
,
sio
,
node
,
name
,
x
,
z
,
fail
):
print
>>
sio
,
"""
{
...
...
@@ -935,11 +979,38 @@ class GpuSum(Op):
def
c_code_cache_version
(
self
):
#return ()
return
(
8
,)
return
(
9
,)
def
c_support_code_apply
(
self
,
node
,
nodename
):
sio
=
StringIO
.
StringIO
()
if
all
(
i
==
1
for
i
in
self
.
reduce_mask
):
#this kernel is ok for up to a few thousand elements, but
# it only runs on ONE multiprocessor
reducebuf
=
self
.
_k_reduce_buf
(
'Z[0]'
)
print
>>
sio
,
"""
static __global__ void kernel_reduce_sum_ccontig_
%(nodename)
s(
const unsigned int d0,
const float *A,
float * Z)
{
const int threadCount = blockDim.x;
const int threadNum = threadIdx.x;
extern __shared__ float buf[];
float mysum = 0.0f;
if (warpSize != 32)
{
return; //TODO: set error code
}
for (int i0 = threadIdx.x; i0 < d0; i0 += blockDim.x)
{
mysum += A[i0];
}
%(reducebuf)
s
}
"""
%
locals
()
if
self
.
reduce_mask
==
(
1
,):
#this kernel is ok for up to a few thousand elements, but
# it only runs on ONE multiprocessor
...
...
theano/sandbox/cuda/tests/test_basic_ops.py
浏览文件 @
c3e013da
...
...
@@ -52,6 +52,29 @@ def test_sum():
assert
numpy
.
allclose
(
f2
(
val
),
f
(
val
))
#test with dimshuffle
#we shuffle the 2 outer dims.
for
shape
,
pattern
in
[
#((5,),[0]),
((
5
,
4
),[
0
,
1
]),((
5
,
4
),[
0
]),
((
5
,
4
,
3
),[
0
]),((
5
,
4
,
3
),[
0
,
1
]),((
5
,
4
,
3
),[
2
]),((
5
,
4
,
3
),[
0
,
1
,
2
]),
((
5
,
4
,
3
,
2
),[
0
,
1
,
2
,
3
]),
((
5
,
4
,
3
,
2
),[
0
,
2
,
3
])]:
a
=
tensor
.
TensorType
(
'float32'
,(
False
,)
*
len
(
shape
))()
dim_pattern
=
range
(
len
(
shape
))
dim_pattern
[
0
]
=
1
dim_pattern
[
1
]
=
0
a
=
a
.
dimshuffle
(
dim_pattern
)
b
=
T
.
Sum
(
pattern
)(
a
)
val
=
numpy
.
random
.
rand
(
numpy
.
prod
(
shape
))
.
reshape
(
shape
)
# val = numpy.ones(shape)
# val = numpy.arange(numpy.prod(shape)).reshape(shape)
val
=
theano
.
_asarray
(
val
,
dtype
=
'float32'
)
f
=
theano
.
function
([
a
],
b
,
mode
=
mode_with_gpu
)
f2
=
theano
.
function
([
a
],
b
,
mode
=
mode_without_gpu
)
assert
tcn
.
GpuSum
in
[
x
.
op
.
__class__
for
x
in
f
.
maker
.
env
.
toposort
()]
assert
T
.
Sum
in
[
x
.
op
.
__class__
for
x
in
f2
.
maker
.
env
.
toposort
()]
assert
numpy
.
allclose
(
f2
(
val
),
f
(
val
))
#test with broadcast
for
shape
,
pattern
in
[((
5
,),[
0
]),
((
5
,
4
),[
0
,
1
]),((
5
,
4
),[
0
]),
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论