Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
c1458cc1
提交
c1458cc1
authored
7月 23, 2014
作者:
abergeron
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1988 from nouiz/gpu_red
[ENH] speed up reduction on the colums of matries.
上级
91a00e6f
c4ff9674
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
143 行增加
和
10 行删除
+143
-10
basic_ops.py
theano/sandbox/cuda/basic_ops.py
+62
-5
test_basic_ops.py
theano/sandbox/cuda/tests/test_basic_ops.py
+9
-0
elemwise.py
theano/sandbox/gpuarray/elemwise.py
+62
-5
test_elemwise.py
theano/sandbox/gpuarray/tests/test_elemwise.py
+10
-0
没有找到文件。
theano/sandbox/cuda/basic_ops.py
浏览文件 @
c1458cc1
...
@@ -1265,8 +1265,64 @@ class GpuCAReduce(GpuOp):
...
@@ -1265,8 +1265,64 @@ class GpuCAReduce(GpuOp):
def
c_code_reduce_10
(
self
,
sio
,
node
,
name
,
x
,
z
,
fail
):
def
c_code_reduce_10
(
self
,
sio
,
node
,
name
,
x
,
z
,
fail
):
print
>>
sio
,
"""
print
>>
sio
,
"""
{
{
int verbose = 0;
int verbose = 0;
if(CudaNdarray_HOST_STRIDES(
%(x)
s)[0] >
CudaNdarray_HOST_STRIDES(
%(x)
s)[1]){
// If there are a lot of summations to do, then we can use simple parallelization -
// use each thread to do one sum.
// we might as well launch blocks of 32 threads because that's the warp size.
// we could schedule more threads if we were maxing out the gridsize below, but
// the gridsize is way more than the physical hardware and I think 32 threads
// on a huge grid is enough to fully use the hardware.
dim3 n_threads(32,1,1);
// We kindof reshape the input implicitly to something 4D:
// the shape A,B,C -> A, B, D, E
// where C <= D*E < C+32
// where E==32
int A = 1;
int B = CudaNdarray_HOST_DIMS(
%(x)
s)[0];
int C = CudaNdarray_HOST_DIMS(
%(x)
s)[1];
int D = C/32;
if (32*D < C) D+= 1;
assert ((C <= 32*D) && (32*D < C+32));
// The gridsize would ideally be (A, D). But we do the following logic to make
// sure we don't ask for a grid that is too big.
dim3 n_blocks(A,D);
if (n_blocks.x > NUM_VECTOR_OP_BLOCKS) n_blocks.x = NUM_VECTOR_OP_BLOCKS;
if (n_blocks.x*n_blocks.y > NUM_VECTOR_OP_BLOCKS) n_blocks.y = NUM_VECTOR_OP_BLOCKS/n_blocks.x;
kernel_reduce_010_AD_
%(name)
s<<<n_blocks, n_threads>>>(
A,B,C,D,
CudaNdarray_DEV_DATA(
%(x)
s),
1,
CudaNdarray_HOST_STRIDES(
%(x)
s)[0],
CudaNdarray_HOST_STRIDES(
%(x)
s)[1],
CudaNdarray_DEV_DATA(
%(z)
s),
1,
CudaNdarray_HOST_STRIDES(
%(z)
s)[0]
);
CNDA_THREAD_SYNC;
cudaError_t sts = cudaGetLastError();
if (cudaSuccess != sts)
{
PyErr_Format(PyExc_RuntimeError,
"Cuda error:
%%
s:
%%
s."
" (grid:
%%
i x
%%
i; block:
%%
i x
%%
i x
%%
i)
\\
n",
"kernel_reduce_10_AD
%(name)
s",
cudaGetErrorString(sts),
n_blocks.x,
n_blocks.y,
n_threads.x,
n_threads.y,
n_threads.z);
%(fail)
s;
}
}else{
dim3 n_threads(
dim3 n_threads(
std::min(CudaNdarray_HOST_DIMS(
%(x)
s)[0],
std::min(CudaNdarray_HOST_DIMS(
%(x)
s)[0],
NUM_VECTOR_OP_THREADS_PER_BLOCK));
NUM_VECTOR_OP_THREADS_PER_BLOCK));
...
@@ -1279,7 +1335,7 @@ class GpuCAReduce(GpuOp):
...
@@ -1279,7 +1335,7 @@ class GpuCAReduce(GpuOp):
n_blocks.x,
n_blocks.x,
n_blocks.y);
n_blocks.y);
}
}
assert(
CudaNdarray_HOST_DIMS(
%(x)
s)[1] == CudaNdarray_HOST_DIMS(
%(z)
s)[0]);
assert(CudaNdarray_HOST_DIMS(
%(x)
s)[1] == CudaNdarray_HOST_DIMS(
%(z)
s)[0]);
int n_shared = sizeof(float) * n_threads.x;
int n_shared = sizeof(float) * n_threads.x;
kernel_reduce_010_
%(name)
s<<<n_blocks, n_threads, n_shared>>>(
kernel_reduce_010_
%(name)
s<<<n_blocks, n_threads, n_shared>>>(
1,
1,
...
@@ -1310,6 +1366,7 @@ class GpuCAReduce(GpuOp):
...
@@ -1310,6 +1366,7 @@ class GpuCAReduce(GpuOp):
%(fail)
s;
%(fail)
s;
}
}
}
}
}
"""
%
locals
()
"""
%
locals
()
def
c_code_reduce_010
(
self
,
sio
,
node
,
name
,
x
,
z
,
fail
):
def
c_code_reduce_010
(
self
,
sio
,
node
,
name
,
x
,
z
,
fail
):
...
@@ -1640,7 +1697,7 @@ class GpuCAReduce(GpuOp):
...
@@ -1640,7 +1697,7 @@ class GpuCAReduce(GpuOp):
"""
%
locals
()
"""
%
locals
()
def
c_code_cache_version_apply
(
self
,
node
):
def
c_code_cache_version_apply
(
self
,
node
):
version
=
[
9
]
# the version corresponding to the c code in this Op
version
=
[
11
]
# the version corresponding to the c code in this Op
# now we insert versions for the ops on which we depend...
# now we insert versions for the ops on which we depend...
scalar_node
=
Apply
(
self
.
scalar_op
,
scalar_node
=
Apply
(
self
.
scalar_op
,
...
@@ -1874,7 +1931,7 @@ class GpuCAReduce(GpuOp):
...
@@ -1874,7 +1931,7 @@ class GpuCAReduce(GpuOp):
}
}
"""
%
locals
()
"""
%
locals
()
if
self
.
reduce_mask
==
(
0
,
1
,
0
):
if
self
.
reduce_mask
==
(
0
,
1
,
0
)
or
self
.
reduce_mask
==
(
1
,
0
)
:
reduce_fct
=
self
.
_assign_reduce
(
node
,
nodename
,
"myresult"
,
reduce_fct
=
self
.
_assign_reduce
(
node
,
nodename
,
"myresult"
,
"X[a * sX0 + b * sX1 + c * sX2]"
,
"X[a * sX0 + b * sX1 + c * sX2]"
,
{},
True
)
{},
True
)
...
...
theano/sandbox/cuda/tests/test_basic_ops.py
浏览文件 @
c1458cc1
...
@@ -1269,6 +1269,15 @@ def speed_adv_sub1():
...
@@ -1269,6 +1269,15 @@ def speed_adv_sub1():
print
"ProfileMode with batch size"
,
batch_size
print
"ProfileMode with batch size"
,
batch_size
mode_with_gpu
.
print_summary
()
mode_with_gpu
.
print_summary
()
def
speed_reduce10
():
data
=
numpy
.
random
.
rand
(
1000
,
1000
)
.
astype
(
"float32"
)
m
=
theano
.
tensor
.
fmatrix
()
f
=
theano
.
function
([
m
],
[
m
.
sum
(
axis
=
0
),
m
.
T
.
sum
(
axis
=
0
)],
mode
=
mode_with_gpu
)
f
(
data
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
test_many_arg_elemwise
()
test_many_arg_elemwise
()
test_gpujoin_assert_cndas
()
test_gpujoin_assert_cndas
()
theano/sandbox/gpuarray/elemwise.py
浏览文件 @
c1458cc1
...
@@ -1409,8 +1409,64 @@ class GpuCAReduceCuda(HideC, CAReduceDtype):
...
@@ -1409,8 +1409,64 @@ class GpuCAReduceCuda(HideC, CAReduceDtype):
if
config
.
gpuarray
.
sync
:
if
config
.
gpuarray
.
sync
:
sync
=
"""GpuArray_sync(&
%(z)
s->ga);"""
%
locals
()
sync
=
"""GpuArray_sync(&
%(z)
s->ga);"""
%
locals
()
print
>>
sio
,
"""
print
>>
sio
,
"""
{
{
int verbose = 0;
int verbose = 0;
if(PyGpuArray_STRIDES(
%(x)
s)[0]>
PyGpuArray_STRIDES(
%(x)
s)[1]){
// If there are a lot of summations to do, then we can use simple parallelization -
// use each thread to do one sum.
// we might as well launch blocks of 32 threads because that's the warp size.
// we could schedule more threads if we were maxing out the gridsize below, but
// the gridsize is way more than the physical hardware and I think 32 threads
// on a huge grid is enough to fully use the hardware.
dim3 n_threads(32,1,1);
// We kindof reshape the input implicitly to something 4D:
// the shape A,B,C -> A, B, D, E
// where C <= D*E < C+32
// where E==32
int A = 1;
int B = PyGpuArray_DIMS(
%(x)
s)[0];
int C = PyGpuArray_DIMS(
%(x)
s)[1];
int D = C/32;
if (32*D < C) D+= 1;
assert ((C <= 32*D) && (32*D < C+32));
// The gridsize would ideally be (A, D). But we do the following logic to make
// sure we don't ask for a grid that is too big.
dim3 n_blocks(A,D);
if (n_blocks.x > 4096) n_blocks.x = 4096;
if (n_blocks.x*n_blocks.y > 4096) n_blocks.y = 4096/n_blocks.x;
kernel_reduce_010_AD_
%(name)
s<<<n_blocks, n_threads>>>(
A,B,C,D,
(
%(in_dtype)
s *)(((char *)cuda_get_ptr(
%(x)
s->ga.data))+
%(x)
s->ga.offset),
1,
PyGpuArray_STRIDES(
%(x)
s)[0]/sizeof(
%(in_dtype)
s),
PyGpuArray_STRIDES(
%(x)
s)[1]/sizeof(
%(in_dtype)
s),
(
%(out_dtype)
s *)(((char *)cuda_get_ptr(
%(z)
s->ga.data))+
%(z)
s->ga.offset),
1,
PyGpuArray_STRIDES(
%(z)
s)[0]/sizeof(
%(out_dtype)
s)
);
%(sync)
s
cudaError_t sts = cudaGetLastError();
if (cudaSuccess != sts)
{
PyErr_Format(PyExc_RuntimeError,
"Cuda error:
%%
s:
%%
s."
" (grid:
%%
i x
%%
i; block:
%%
i x
%%
i x
%%
i)
\\
n",
"kernel_reduce_10_AD
%(name)
s",
cudaGetErrorString(sts),
n_blocks.x,
n_blocks.y,
n_threads.x,
n_threads.y,
n_threads.z);
%(fail)
s;
}
}else{
dim3 n_threads(
dim3 n_threads(
std::min(PyGpuArray_DIMS(
%(x)
s)[0],
std::min(PyGpuArray_DIMS(
%(x)
s)[0],
(size_t) 256));
(size_t) 256));
...
@@ -1423,7 +1479,7 @@ class GpuCAReduceCuda(HideC, CAReduceDtype):
...
@@ -1423,7 +1479,7 @@ class GpuCAReduceCuda(HideC, CAReduceDtype):
n_blocks.x,
n_blocks.x,
n_blocks.y);
n_blocks.y);
}
}
assert(
PyGpuArray_DIMS(
%(x)
s)[1] == PyGpuArray_DIMS(
%(z)
s)[0]);
assert(PyGpuArray_DIMS(
%(x)
s)[1] == PyGpuArray_DIMS(
%(z)
s)[0]);
int n_shared = sizeof(
%(acc_dtype)
s) * n_threads.x;
int n_shared = sizeof(
%(acc_dtype)
s) * n_threads.x;
kernel_reduce_010_
%(name)
s<<<n_blocks, n_threads, n_shared>>>(
kernel_reduce_010_
%(name)
s<<<n_blocks, n_threads, n_shared>>>(
1,
1,
...
@@ -1454,6 +1510,7 @@ class GpuCAReduceCuda(HideC, CAReduceDtype):
...
@@ -1454,6 +1510,7 @@ class GpuCAReduceCuda(HideC, CAReduceDtype):
%(fail)
s;
%(fail)
s;
}
}
}
}
}
"""
%
locals
()
"""
%
locals
()
def
c_code_reduce_010
(
self
,
sio
,
node
,
name
,
x
,
z
,
fail
):
def
c_code_reduce_010
(
self
,
sio
,
node
,
name
,
x
,
z
,
fail
):
...
@@ -1795,7 +1852,7 @@ class GpuCAReduceCuda(HideC, CAReduceDtype):
...
@@ -1795,7 +1852,7 @@ class GpuCAReduceCuda(HideC, CAReduceDtype):
"""
%
locals
()
"""
%
locals
()
def
c_code_cache_version_apply
(
self
,
node
):
def
c_code_cache_version_apply
(
self
,
node
):
version
=
[
1
1
]
# the version corresponding to the c code in this Op
version
=
[
1
2
]
# the version corresponding to the c code in this Op
# now we insert versions for the ops on which we depend...
# now we insert versions for the ops on which we depend...
scalar_node
=
Apply
(
self
.
scalar_op
,
scalar_node
=
Apply
(
self
.
scalar_op
,
...
@@ -2032,7 +2089,7 @@ class GpuCAReduceCuda(HideC, CAReduceDtype):
...
@@ -2032,7 +2089,7 @@ class GpuCAReduceCuda(HideC, CAReduceDtype):
}
}
"""
%
locals
()
"""
%
locals
()
if
self
.
reduce_mask
==
(
0
,
1
,
0
):
if
self
.
reduce_mask
==
(
0
,
1
,
0
)
or
self
.
reduce_mask
==
(
1
,
0
)
:
reduce_fct
=
self
.
_assign_reduce
(
node
,
nodename
,
"myresult"
,
reduce_fct
=
self
.
_assign_reduce
(
node
,
nodename
,
"myresult"
,
"X[a * sX0 + b * sX1 + c * sX2]"
,
"X[a * sX0 + b * sX1 + c * sX2]"
,
{},
True
)
{},
True
)
...
...
theano/sandbox/gpuarray/tests/test_elemwise.py
浏览文件 @
c1458cc1
...
@@ -167,3 +167,13 @@ class T_gpureduce_dtype(T_reduce_dtype):
...
@@ -167,3 +167,13 @@ class T_gpureduce_dtype(T_reduce_dtype):
op
=
GpuCAReduceCuda
op
=
GpuCAReduceCuda
#Currently we don't support reduction on 0 axis
#Currently we don't support reduction on 0 axis
axes
=
[
None
,
0
,
1
,
1
,
[
0
],
[
1
],
[
0
,
1
]]
axes
=
[
None
,
0
,
1
,
1
,
[
0
],
[
1
],
[
0
,
1
]]
def
speed_reduce10
():
import
numpy
import
theano
data
=
numpy
.
random
.
rand
(
1000
,
1000
)
.
astype
(
"float32"
)
m
=
theano
.
tensor
.
fmatrix
()
f
=
theano
.
function
([
m
],
[
m
.
sum
(
axis
=
0
),
m
.
T
.
sum
(
axis
=
0
)],
mode
=
mode_with_gpu
)
f
(
data
)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论