Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
13df5be7
提交
13df5be7
authored
3月 09, 2012
作者:
James Bergstra
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'gpu_conv_segfault_070312'
上级
cc62c3a8
c5424eea
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
116 行增加
和
68 行删除
+116
-68
blas.py
theano/sandbox/cuda/blas.py
+1
-1
conv.cu
theano/sandbox/cuda/conv.cu
+13
-4
conv_kernel.cu
theano/sandbox/cuda/conv_kernel.cu
+102
-63
没有找到文件。
theano/sandbox/cuda/blas.py
浏览文件 @
13df5be7
...
@@ -677,7 +677,7 @@ class GpuConv(GpuOp):
...
@@ -677,7 +677,7 @@ class GpuConv(GpuOp):
return
[
'cuda_ndarray.cuh'
,
'<stdio.h>'
]
return
[
'cuda_ndarray.cuh'
,
'<stdio.h>'
]
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
(
0
,
1
6
)
# raise this whenever modifying any of the support_code_files
return
(
0
,
1
7
)
# raise this whenever modifying any of the support_code_files
def
c_support_code_apply
(
self
,
node
,
nodename
):
def
c_support_code_apply
(
self
,
node
,
nodename
):
# REMEMBER TO RAISE c_code_cache_version when changing any of these files
# REMEMBER TO RAISE c_code_cache_version when changing any of these files
...
...
theano/sandbox/cuda/conv.cu
浏览文件 @
13df5be7
...
@@ -344,10 +344,19 @@ CudaNdarray_conv_valid(const CudaNdarray *img, const CudaNdarray * kern,
...
@@ -344,10 +344,19 @@ CudaNdarray_conv_valid(const CudaNdarray *img, const CudaNdarray * kern,
int, int, int, int,
int, int, int, int,
int, int);
int, int);
#define CONV_ROWS_STACK_SPECIAL(kern_wid) \
if (0)
if(!img_contiguous_2d || !kern_contiguous_2d) f = conv_rows_stack<kern_wid, false>;\
fprintf(stderr, "IMG CONTIG %i KERN_CONTIG %i (%i %i %i) (%i %i %i)\n",
else f = conv_rows_stack<kern_wid, true>;
img_contiguous_2d, kern_contiguous_2d,
CONV_ROWS_STACK_SPECIAL(THEANO_KERN_WID);
threads.x, threads.y, threads.z,
grid.x, grid.y, grid.z);
if(!img_contiguous_2d || !kern_contiguous_2d) {
//fprintf(stderr, "using false version\n");
f = conv_rows_stack<THEANO_KERN_WID, false>;
} else {
//fprintf(stderr, "using true version\n");
f = conv_rows_stack<THEANO_KERN_WID, true>;
}
f<<< grid, threads, shared_size >>>
f<<< grid, threads, shared_size >>>
(img->devdata,
(img->devdata,
...
...
theano/sandbox/cuda/conv_kernel.cu
浏览文件 @
13df5be7
...
@@ -41,8 +41,10 @@ for (int iter_m=0; iter_m < Os[0]; iter_m++) {
...
@@ -41,8 +41,10 @@ for (int iter_m=0; iter_m < Os[0]; iter_m++) {
#endif
#endif
*/
*/
#define MIN(a, b) ((a) < (b) ? (a) : (b) )
#define MAX(a, b) ((a) < (b) ? (b) : (a) )
const unsigned long int COALESCED_ALIGN = 0xFFFFFFFFFFFFFF00; // zero-out the trailing bits of pointers
const unsigned long int COALESCED_ALIGN = 0xFFFFFFFFFFFFFF00; // zero-out the trailing bits of pointers
#define MASKED_OFFSET(src) (((int)((unsigned long int)src - (((unsigned long int)src) & COALESCED_ALIGN))) / sizeof(float))
__device__ void load_to_shared(float * dst, const float * src, const int thread_id, int nb_thread, const int N, const bool flipped=false){
__device__ void load_to_shared(float * dst, const float * src, const int thread_id, int nb_thread, const int N, const bool flipped=false){
if (nb_thread < 64)
if (nb_thread < 64)
...
@@ -54,30 +56,41 @@ __device__ void load_to_shared(float * dst, const float * src, const int thread_
...
@@ -54,30 +56,41 @@ __device__ void load_to_shared(float * dst, const float * src, const int thread_
dst[i]=src[N - 1 - i];
dst[i]=src[N - 1 - i];
//dst[N-1-i]=src[i];
//dst[N-1-i]=src[i];
else
else
for(int i=thread_id;i<N;i+=nb_thread)
{
dst[i]=src[i];
for(int i = thread_id; i < N; i += nb_thread)
{
dst[i] = src[i];
}
}
}
}
else
else
{
{
nb_thread = nb_thread & 0xFFFFFFE0; //make nb_thread a multiple of 32
nb_thread = nb_thread & 0xFFFFFFE0; //make nb_thread a multiple of 32
// Global memory:
// Global memory:
// <-------------------------------------->
// <-------------------------------------->
// A A A A A // points of
128-bit
alignment
// A A A A A // points of
256-byte
alignment
// dddddddddddddddddddddd // layout of src in global memory
// dddddddddddddddddddddd // layout of src in global memory
// |--| // masked_src_offset
//
if (thread_id < nb_thread)
if (thread_id < nb_thread)
{
{
const int masked_src_offset = MASKED_OFFSET(src);
const float * my_src_ptr = (const float *)(
for(int masked_i=thread_id; masked_i<N + masked_src_offset; masked_i+=nb_thread)
((unsigned long int)src) & COALESCED_ALIGN);
{
my_src_ptr += thread_id;
int i = masked_i - masked_src_offset;
while (my_src_ptr < src + N)
if (i >= 0)
{
if (flipped)
if (my_src_ptr >= src)
dst[N-1-i] = src[i];
{
else
int i = my_src_ptr - src;
dst[i]=src[i];
if (flipped)
}
{
dst[N - 1 - i] = *my_src_ptr;
}
else
{
dst[i] = *my_src_ptr;
}
}
my_src_ptr += nb_thread;
}
}
}
}
}
}
}
...
@@ -89,40 +102,35 @@ __device__ void load_to_shared(float * dst, const float * src, const int thread_
...
@@ -89,40 +102,35 @@ __device__ void load_to_shared(float * dst, const float * src, const int thread_
int nb_thread, const int nb_col, const int nb_row,
int nb_thread, const int nb_col, const int nb_row,
const int stride_col, const int stride_row,
const int stride_col, const int stride_row,
const bool flipped=false, const bool c_contiguous=true){
const bool flipped=false, const bool c_contiguous=true){
if(flipped && ! c_contiguous){
if (c_contiguous)
for(int i=thread_id;i<nb_row*nb_col;i+=nb_thread)
{
dst[nb_row*nb_col-1-i]=src[(i/nb_col)*stride_row+(i%nb_col)*stride_col];
load_to_shared(dst, src, thread_id, nb_thread, nb_col*nb_row, flipped);
}else if(c_contiguous){
}
load_to_shared(dst, src, thread_id, nb_thread, nb_col*nb_row, flipped);
else
}else if(flipped){//c_contiguous==true
{
//TODO very slow on device before 1.3. make access to kern sequential and access to d_kern flipped.
if (flipped)
int N=nb_col*nb_row;
{
for(int i=thread_id;i<N;i+=nb_thread)
int LAST = nb_row * nb_col - 1;
dst[i]=src[N - 1 - i];
for(int i=thread_id;i<nb_row*nb_col;i+=nb_thread)
//dst[N-1-i]=src[i];
{
}else if(c_contiguous){//flipped==false
// XXX
for(int i=thread_id;i<nb_col*nb_row;i+=nb_thread)
// THIS IS SLOW - use whatever blocks are in the the
dst[i]=src[i];
// threads to avoid division and modulo
}else{ // !flipped && !c_contiguous
dst[LAST - i] \
/*
= src[(i/nb_col)*stride_row+(i%nb_col)*stride_col];
for(int i=thread_id;i<nb_row;i+=nb_thread){
}
float* s=&src[i*stride_row];
}
float* d=&dst[i*nb_col];
else
for(int j=thread_id;j<nb_col;i+=nb_thread)
{
// dst[i*nb_col+j]=src[i*stride_row+j*stride_col];//dst[i]=src[i];
for(int i=thread_id;i<nb_row*nb_col;i+=nb_thread)
d[j]=s[j*stride_col];
{
}*/
// XXX
/* We don't do this if as nvcc 2.3 take 2 more registers when we add the if
// THIS IS SLOW - use whatever blocks are in the the
Why it do this?
// threads to avoid division and modulo
if(stride_col==1 && stride_row==nb_col)
dst[i]=src[i/nb_col*stride_row+i%nb_col*stride_col];
for(int i=thread_id;i<nb_row*nb_col;i+=nb_thread)
}
dst[i]=src[i];
}
else*/
}
for(int i=thread_id;i<nb_row*nb_col;i+=nb_thread)
dst[i]=src[i/nb_col*stride_row+i%nb_col*stride_col];
}
}
}
__device__ void fill(float * dst, int N, float value, int thread_id, int nb_thread){
__device__ void fill(float * dst, int N, float value, int thread_id, int nb_thread){
...
@@ -639,6 +647,19 @@ conv_rows_stack( float* img, float* kern, float* out,
...
@@ -639,6 +647,19 @@ conv_rows_stack( float* img, float* kern, float* out,
kern_id = blockIdx.y%nkern;
kern_id = blockIdx.y%nkern;
nb_rows = blockDim.y;
nb_rows = blockDim.y;
int rows_to_read = MIN(
kern_len + nb_rows - 1,
img_len - blockIdx.x * nb_rows);
/**
* Every thread ultimately computes one value in the output, at coordinates
* out[ batch_id, kern_id, out_row, out_col]
*
* The batch_id and kern_id are packed into blockIdx.y. out_row and out_col
* are the threadIdx.x and threadIdx.y.
*
* Every thread block deals only with one image, and one filter kernel.
*/
extern __shared__ float s_data[];
extern __shared__ float s_data[];
const int out_col = threadIdx.x;//output col
const int out_col = threadIdx.x;//output col
...
@@ -646,31 +667,49 @@ conv_rows_stack( float* img, float* kern, float* out,
...
@@ -646,31 +667,49 @@ conv_rows_stack( float* img, float* kern, float* out,
const int shared_row = threadIdx.y;
const int shared_row = threadIdx.y;
const int thread_id = threadIdx.y*blockDim.x+threadIdx.x;
const int thread_id = threadIdx.y*blockDim.x+threadIdx.x;
/*
* The kernel works by looping over channels (aka colours, aka the stack).
* On each iteration, a thread block loads one channel of all the image rows that
* it needs to use, and one channel slice of one kernel.
*/
d_img=&s_data[0];//size of [(KERN_LEN+block_len-1) * IMAGE_WID];
d_img=&s_data[0];//size of [(KERN_LEN+block_len-1) * IMAGE_WID];
d_kern=&s_data[(kern_len+nb_rows-1) * img_wid];//size of [KERNEL_LEN * KERNEL_WID];
d_kern=&s_data[(kern_len+nb_rows-1) * img_wid];//size of [KERNEL_LEN * KERNEL_WID];
float sum = 0.0f;
for (int stack = 0;stack<nstack;stack++){
int _idx=img_stride_batch*batch_id+img_stride_stack*stack;//selection the good image from the batch
float sum = 0.0f;
_idx+=(blockIdx.x*nb_rows)*img_stride_row;//select the good top row for the block of threads
for (int stack = 0; stack < nstack; stack++){
load_to_shared(d_img,img+_idx,thread_id,nb_thread_id,img_wid,kern_len+nb_rows-1,
int offset =
img_stride_col, img_stride_row, false, c_contiguous);
img_stride_batch * batch_id
_idx=kern_stride_nkern*kern_id+kern_stride_stack*stack;
+ img_stride_stack * stack
load_to_shared(d_kern, kern+_idx, thread_id, nb_thread_id, kern_wid,kern_len,
//blockIdx.x is which chunk of nb_rows this thread block deals with
+ img_stride_row * (blockIdx.x * nb_rows);
load_to_shared(
d_img, // dst
img+offset, // src
thread_id, // linear position in block
nb_thread_id, // number of threads
img_wid, // cols in image to read
rows_to_read, // number of rows to read
img_stride_col, // img[i, j, k, l] to img[i, j, k, l + 1]
img_stride_row, // img[i, j, k, l] to img[i, j, k + 1, l]
false, // flip while reading
c_contiguous);
offset = kern_stride_nkern * kern_id + kern_stride_stack * stack;
load_to_shared(d_kern, kern+offset, thread_id, nb_thread_id, kern_wid,kern_len,
kern_stride_col, kern_stride_row, true, c_contiguous);
kern_stride_col, kern_stride_row, true, c_contiguous);
__syncthreads();
__syncthreads();
for (int row=0; row < kern_len
&& out_row<out_len
; row++) {//loop over row
for (int row=0; row < kern_len; row++) {//loop over row
const float* idx_kern=&d_kern[row*kern_wid];
const float* idx_kern=&d_kern[row*kern_wid];
const float* idx_in=&d_img[(row+shared_row)*img_wid+out_col];
const float* idx_in=&d_img[(row+shared_row)*img_wid+out_col];
convolutionRowNoFlip<KERN_WIDTH>(sum,idx_in,idx_kern,kern_wid);
convolutionRowNoFlip<KERN_WIDTH>(sum,idx_in,idx_kern,kern_wid);
}
}
__syncthreads();//to be sure all thread have finished before we modif the shared memory.
__syncthreads();//to be sure all thread have finished before we modif the shared memory.
}
}
if
(out_row<
out_len)
if
(out_row <
out_len)
out[batch_id*out_wid*out_len*nkern+//the good batch
out[batch_id*out_wid*out_len*nkern+//the good batch
kern_id*out_wid*out_len+//the output image
kern_id*out_wid*out_len+//the output image
out_row*out_wid+out_col] = sum;
out_row*out_wid+out_col] = sum;
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论