Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
cf91f745
提交
cf91f745
authored
5月 18, 2011
作者:
Frederic Bastien
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
A first optimized implementation of conv2d on the with subsamble. Work only for some shape.
上级
899d98b6
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
46 行增加
和
28 行删除
+46
-28
blas.py
theano/sandbox/cuda/blas.py
+1
-1
conv.cu
theano/sandbox/cuda/conv.cu
+0
-0
conv_kernel.cu
theano/sandbox/cuda/conv_kernel.cu
+18
-8
test_conv_cuda_ndarray.py
theano/sandbox/cuda/tests/test_conv_cuda_ndarray.py
+27
-19
没有找到文件。
theano/sandbox/cuda/blas.py
浏览文件 @
cf91f745
...
...
@@ -363,7 +363,7 @@ class GpuConv(Op):
return
[
'cuda_ndarray.cuh'
,
'<stdio.h>'
]
def
c_code_cache_version
(
self
):
return
(
0
,
1
3
)
# raise this whenever modifying any of the support_code_files
return
(
0
,
1
4
)
# raise this whenever modifying any of the support_code_files
def
c_support_code_apply
(
self
,
node
,
nodename
):
# REMEMBER TO RAISE c_code_cache_version when changing any of these files
...
...
theano/sandbox/cuda/conv.cu
浏览文件 @
cf91f745
差异被折叠。
点击展开。
theano/sandbox/cuda/conv_kernel.cu
浏览文件 @
cf91f745
...
...
@@ -280,6 +280,8 @@ conv_patch( float* img, float* kern, float* out,
*
* nkern: the number of kernel, used to compute the output image to store the result
* nstack: the size of the stack, used to compute the image to load.
* dx: patch stride rows(1 for normal convolution)
* dy: patch stride cols(1 for normal convolution)
* template flipped_kern: if true, we "flip" the kernel as in a real convolution, else we don't
* template accumulate: if true, we add the result, else we override the result
* template KERN_WIDTH: if 0, will work for any kern_wid, else it specialyse to this kern_wid as an optimization
...
...
@@ -287,19 +289,19 @@ conv_patch( float* img, float* kern, float* out,
* template kern_c_contiguous_2d: if true, the kernel have are collon and row contiguous
* template split: if true, each thread generate more then 1 output pixel, but use more registers.
* template preload_full_kern: if true, we load the full kernel in shared memory, else, we load 1 row at a time.
* template subsample: if false, remove some computation needed when dx or dy!=1.
*/
template<bool flipped_kern, bool accumulate, int KERN_WIDTH, bool img_c_contiguous_2d, bool kern_c_contiguous_2d, bool split, bool preload_full_kern>
template<bool flipped_kern, bool accumulate, int KERN_WIDTH, bool img_c_contiguous_2d, bool kern_c_contiguous_2d, bool split, bool preload_full_kern
, bool subsample
>
__global__ void
conv_patch_stack( float* img, float* kern, float* out,
int img_len, int img_wid, int kern_len, int kern_wid,
int out_len, int out_wid,
int nkern, int nstack, int img_stride_col,int img_stride_row,
int img_stride_stack, int img_stride_batch,
int kern_stride_col, int kern_stride_row,
int kern_stride_stack, int kern_stride_nkern)
int kern_stride_stack, int kern_stride_nkern
, int dx, int dy
)
{
int __shared__ out_len, out_wid, nb_thread_id;
out_len = img_len - kern_len + 1;
out_wid = img_wid - kern_wid + 1;
int __shared__ nb_thread_id;
nb_thread_id = blockDim.z*blockDim.y*blockDim.x;
extern __shared__ float s_data[];
...
...
@@ -346,7 +348,11 @@ conv_patch_stack( float* img, float* kern, float* out,
const float* idx_kern;
if(preload_full_kern) idx_kern=&d_kern[row*kern_wid];
else idx_kern=d_kern;
const float* idx_in=&d_img[(row+out_row)*img_wid+out_col];
const float* idx_in;
if(subsample)
idx_in=&d_img[(row+out_row*dx)*img_wid+out_col*dy];
else
idx_in=&d_img[(row+out_row)*img_wid+out_col];
convolutionRowNoFlip<KERN_WIDTH>(sum,idx_in,idx_kern,kern_wid);
}
...
...
@@ -368,7 +374,7 @@ conv_patch_stack( float* img, float* kern, float* out,
//TODO: inverse the out_row and stack loop to don't load the date as frequently!
//TODO: do this happen elsewhere?
for(
int out_row=ty
;out_row<out_len_max;out_row+=blockDim.y){
for(;out_row<out_len_max;out_row+=blockDim.y){
float sum = 0.0f;
for (int stack = 0;stack<nstack;stack++){
//TODO: load only the part of the image needed or put the partial result in shared memory
...
...
@@ -397,7 +403,11 @@ conv_patch_stack( float* img, float* kern, float* out,
const float* idx_kern;
if(preload_full_kern) idx_kern=&d_kern[row*kern_wid];
else idx_kern=d_kern;
const float* idx_in=&d_img[(row+out_row)*img_wid+out_col];
const float* idx_in;
if(subsample)
idx_in=&d_img[(row+out_row*dx)*img_wid+out_col*dy];
else
idx_in=&d_img[(row+out_row)*img_wid+out_col];
//if needed as on Fermi as reading out of bound index from shared memory generate an error.
//Not needed on generation before as they worked anyway. Removing the if generate the good code
...
...
theano/sandbox/cuda/tests/test_conv_cuda_ndarray.py
浏览文件 @
cf91f745
...
...
@@ -282,8 +282,7 @@ def get_valid_shapes():
shapes
+=
get_shapes2
(
scales_img
=
(
2
,
2
),
img_stride
=
(
-
1
,
-
1
))
shapes
+=
get_shapes2
(
scales_img
=
(
2
,
2
),
kern_stride
=
(
-
1
,
-
1
))
#test subsample
shapes
+=
get_shapes2
(
scales_img
=
(
2
,
2
),
subsample
=
(
2
,
2
))
#test subsample done in a separate fct
shapes
+=
[
#other test
...
...
@@ -502,8 +501,7 @@ def test_full():
shapes
+=
get_shapes2
(
scales_img
=
(
2
,
2
),
img_stride
=
(
-
1
,
-
1
))
shapes
+=
get_shapes2
(
scales_img
=
(
2
,
2
),
kern_stride
=
(
-
1
,
-
1
))
#test subsample
shapes
+=
get_shapes2
(
scales_img
=
(
2
,
2
),
subsample
=
(
2
,
2
))
#test subsample done in a separate fct
shapes
+=
[
#other test
...
...
@@ -552,22 +550,32 @@ def test_full():
def
test_subsample
():
# implement when
shapes
=
[
((
1
,
1
,
1
,
1
),
(
1
,
1
,
1
,
1
),
(
1
,
1
))
,
((
1
,
1
,
1
,
1
),
(
1
,
1
,
1
,
1
),
(
2
,
2
))
,
((
4
,
2
,
10
,
10
),
(
3
,
2
,
2
,
2
),
(
1
,
3
))
,
((
4
,
2
,
10
,
10
),
(
3
,
2
,
2
,
2
),
(
3
,
3
))
,
((
4
,
2
,
10
,
10
),
(
3
,
2
,
2
,
2
),
(
3
,
1
))
shapes
=
[
((
1
,
1
,
1
,
1
),
(
1
,
1
,
1
,
1
),
(
1
,
1
)
,
(
1
,
1
),
(
1
,
1
)
)
,
((
1
,
1
,
1
,
1
),
(
1
,
1
,
1
,
1
),
(
2
,
2
)
,
(
1
,
1
),
(
1
,
1
)
)
,
((
4
,
2
,
10
,
10
),
(
3
,
2
,
2
,
2
),
(
1
,
3
)
,
(
1
,
1
),
(
1
,
1
)
)
,
((
4
,
2
,
10
,
10
),
(
3
,
2
,
2
,
2
),
(
3
,
3
)
,
(
1
,
1
),
(
1
,
1
)
)
,
((
4
,
2
,
10
,
10
),
(
3
,
2
,
2
,
2
),
(
3
,
1
)
,
(
1
,
1
),
(
1
,
1
)
)
]
all_good
=
True
_params_allgood_header
()
for
ishape
,
kshape
,
ds
in
shapes
:
if
not
_params_allgood
(
ishape
,
kshape
,
'full'
,
subsample
=
ds
):
all_good
=
False
if
not
_params_allgood
(
ishape
,
kshape
,
'valid'
,
subsample
=
ds
):
all_good
=
False
assert
all_good
shapes
+=
get_shapes2
(
scales_img
=
(
2
,
2
),
subsample
=
(
1
,
1
))
shapes
+=
get_shapes2
(
scales_img
=
(
2
,
2
),
subsample
=
(
1
,
2
))
shapes
+=
get_shapes2
(
scales_img
=
(
2
,
2
),
subsample
=
(
2
,
1
))
shapes
+=
get_shapes2
(
scales_img
=
(
2
,
2
),
subsample
=
(
2
,
2
))
#We put only the version that implement the subsample to make the test faster.
version_valid
=
[
-
2
,
-
1
,
1
,
3
,
11
,
12
]
version_full
=
[
-
2
,
-
1
]
verbose
=
0
random
=
True
print_
=
False
ones
=
False
if
ones
:
random
=
False
#test
random
=
False
exec_conv
(
version_valid
,
shapes
,
verbose
,
random
,
'valid'
,
print_
=
print_
,
ones
=
ones
)
exec_conv
(
version_full
,
shapes
,
verbose
,
random
,
'full'
,
print_
=
print_
,
ones
=
ones
)
## See #616
#def test_logical_shapes():
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论