Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
42927b2b
提交
42927b2b
authored
1月 31, 2011
作者:
James Bergstra
浏览文件
操作
浏览文件
下载
差异文件
merge
上级
af5a7fe6
c237dc5c
全部展开
显示空白字符变更
内嵌
并排
正在显示
7 个修改的文件
包含
112 行增加
和
51 行删除
+112
-51
basic_ops.py
theano/sandbox/cuda/basic_ops.py
+11
-8
blas.py
theano/sandbox/cuda/blas.py
+3
-3
conv.cu
theano/sandbox/cuda/conv.cu
+0
-0
conv_kernel.cu
theano/sandbox/cuda/conv_kernel.cu
+47
-30
cuda_ndarray.cu
theano/sandbox/cuda/cuda_ndarray.cu
+8
-2
cuda_ndarray.cuh
theano/sandbox/cuda/cuda_ndarray.cuh
+5
-0
test_conv_cuda_ndarray.py
theano/sandbox/cuda/tests/test_conv_cuda_ndarray.py
+38
-8
没有找到文件。
theano/sandbox/cuda/basic_ops.py
浏览文件 @
42927b2b
...
...
@@ -453,7 +453,6 @@ class GpuSum(Op):
PyErr_Format(PyExc_RuntimeError, "Failed to allocate output");
%(fail)
s;
}
}
"""
%
locals
()
...
...
@@ -472,12 +471,10 @@ class GpuSum(Op):
#TODO: check if we are ccontiguous when we un-dimshuffle
#TODO: if only some dims are ccontiguous, call version with less dims.
print
>>
sio
,
'if(CudaNdarray_is_c_contiguous(
%(x)
s)){'
%
locals
()
self
.
c_code_reduce_ccontig
(
sio
,
node
,
name
,
x
,
z
,
fail
)
print
>>
sio
,
"}else{"
getattr
(
self
,
'c_code_reduce_
%
s'
%
(
''
.
join
(
str
(
i
)
for
i
in
self
.
reduce_mask
)))(
sio
,
node
,
name
,
x
,
z
,
fail
)
print
>>
sio
,
"}"
else
:
getattr
(
self
,
'c_code_reduce_
%
s'
%
(
''
.
join
(
str
(
i
)
for
i
in
self
.
reduce_mask
)))(
sio
,
node
,
name
,
x
,
z
,
fail
)
...
...
@@ -826,8 +823,16 @@ class GpuSum(Op):
dim3 n_threads(
std::min(CudaNdarray_HOST_DIMS(
%(x)
s)[0],
NUM_VECTOR_OP_THREADS_PER_BLOCK));
dim3 n_blocks(1,CudaNdarray_HOST_DIMS(
%(x)
s)[1]);
if (verbose) printf("running kernel_reduce_sum_10_
%(name)
s
\\
n");
dim3 n_blocks(1,
std::min(CudaNdarray_HOST_DIMS(
%(x)
s)[1],
NUM_VECTOR_OP_BLOCKS));
if (verbose) {
fprintf(stderr,
"running kernel_reduce_sum_10_
%(name)
s n_blocks=(
%%
i,
%%
i)
\\
n",
n_blocks.x,
n_blocks.y);
}
assert( CudaNdarray_HOST_DIMS(
%(x)
s)[1] == CudaNdarray_HOST_DIMS(
%(z)
s)[0]);
int n_shared = sizeof(float) * n_threads.x;
kernel_reduce_sum_010_
%(name)
s<<<n_blocks, n_threads, n_shared>>>(
1,
...
...
@@ -1175,9 +1180,7 @@ class GpuSum(Op):
"""
%
locals
()
def
c_code_cache_version
(
self
):
#return ()
return
(
19
,)
return
(
20
,)
def
c_support_code_apply
(
self
,
node
,
nodename
):
sio
=
StringIO
.
StringIO
()
...
...
theano/sandbox/cuda/blas.py
浏览文件 @
42927b2b
...
...
@@ -363,9 +363,10 @@ class GpuConv(Op):
return
[
'cuda_ndarray.cuh'
,
'<stdio.h>'
]
def
c_code_cache_version
(
self
):
return
(
0
,
8
)
return
(
0
,
9
)
# raise this whenever modifying any of the support_code_files
def
c_support_code_apply
(
self
,
node
,
nodename
):
# REMEMBER TO RAISE c_code_cache_version when changing any of these files
return
open
(
os
.
path
.
join
(
os
.
path
.
split
(
__file__
)[
0
],
'conv_kernel.cu'
))
.
read
()
+
\
open
(
os
.
path
.
join
(
os
.
path
.
split
(
__file__
)[
0
],
'conv_full_kernel.cu'
))
.
read
()
+
\
open
(
os
.
path
.
join
(
os
.
path
.
split
(
__file__
)[
0
],
'conv.cu'
))
.
read
()
...
...
@@ -405,8 +406,7 @@ class GpuConv(Op):
CudaNdarray * out2 = (CudaNdarray *)CudaNdarray_Conv(
%(img)
s,
%(kern)
s,
%(out)
s,
mode, dx, dy, version, verbose);
if(
%(out)
s &&
%(out)
s==out2)
Py_DECREF(out2);//CudaNdarray_Conv incremented the count to out
Py_XDECREF(
%(out)
s);
%(out)
s = out2;
"""
%
sub
...
...
theano/sandbox/cuda/conv.cu
浏览文件 @
42927b2b
差异被折叠。
点击展开。
theano/sandbox/cuda/conv_kernel.cu
浏览文件 @
42927b2b
// REMEMBER TO RAISE c_code_cache_version when changing this file
//
//implement the valid convolution only
/*
...
...
@@ -38,6 +40,8 @@ for (int iter_m=0; iter_m < Os[0]; iter_m++) {
#define BS(i, j) Bs[i][j]
#endif
*/
#define MAX(a,b) ((a)>(b)?(a):(b))
#define MIN(a,b) ((a)<(b)?(a):(b))
const unsigned long int COALESCED_ALIGN = 0xFFFFFFFFFFFFFF00; // zero-out the trailing bits of pointers
#define MASKED_OFFSET(src) (((int)((unsigned long int)src - (((unsigned long int)src) & COALESCED_ALIGN))) / sizeof(float))
...
...
@@ -46,7 +50,8 @@ __device__ void load_to_shared(float * dst, const float * src, const int thread_
if (nb_thread < 64)
{
if(flipped)
//TODO very slow on device before 1.3. make access to kern sequential and access to d_kern flipped.
//TODO very slow on device before 1.3.
// make access to kern sequential and access to d_kern flipped.
for(int i=thread_id;i<N;i+=nb_thread)
dst[i]=src[N - 1 - i];
//dst[N-1-i]=src[i];
...
...
@@ -88,10 +93,9 @@ __device__ void load_to_shared(float * dst, const float * src, const int thread_
const bool flipped=false, const bool c_contiguous=true){
if(flipped && ! c_contiguous){
for(int i=thread_id;i<nb_row*nb_col;i+=nb_thread)
dst[nb_row*nb_col-1-i]=src[
i/nb_col*stride_row+i%nb_col
*stride_col];
dst[nb_row*nb_col-1-i]=src[
(i/nb_col)*stride_row+(i%nb_col)
*stride_col];
}else if(c_contiguous){
load_to_shared(dst, src, thread_id, nb_thread, nb_col*nb_row, flipped);
}else if(flipped){//c_contiguous==true
//TODO very slow on device before 1.3. make access to kern sequential and access to d_kern flipped.
int N=nb_col*nb_row;
...
...
@@ -440,10 +444,12 @@ conv_patch_stack_reduce( float* img, float* kern, float* out,
int kern_stride_col, int kern_stride_row,
int kern_stride_stack, int kern_stride_nkern)
{
int __shared__ out_len, out_wid, nb_thread_id;
out_len = img_len - kern_len + 1;
out_wid = img_wid - kern_wid + 1;
nb_thread_id = blockDim.z*blockDim.y*blockDim.x;
//int __shared__ out_len, out_wid, nb_thread_id;
//out_len = img_len - kern_len + 1;
//out_wid = img_wid - kern_wid + 1;
const int out_wid = blockDim.x;
const int out_len = blockDim.y;
const int nb_thread_id = blockDim.z*blockDim.y*blockDim.x;
extern __shared__ float s_data[];
...
...
@@ -458,9 +464,16 @@ conv_patch_stack_reduce( float* img, float* kern, float* out,
int out_row = ty;//output row
const int thread_id = tz*blockDim.y*blockDim.x+ty*blockDim.x+tx;
float * d_img=&s_data[0];//size of [IMAGE_LEN * IMAGE_WID];
float * d_kern=&s_data[img_len * img_wid];//size of [(preload_full_kern?KERNEL_LEN:blockDim.z) * KERNEL_WID];
float * d_reduce=&s_data[img_len*img_wid+(preload_full_kern?kern_len:blockDim.z)*kern_wid];
//d_img size [IMAGE_LEN * IMAGE_WID];
float * d_img=&s_data[0];
//d_kern size[(preload_full_kern?KERNEL_LEN:blockDim.z) * KERNEL_WID]
float * d_kern=&s_data[img_len * img_wid];
//d_reduce size [n_threads]
//N.B. this overlaps with d_img and d_kern!
float * d_reduce=&s_data[0];
float sum = 0.0f;
kern+=kern_stride_nkern*blockIdx.y;//the good nkern
...
...
@@ -471,30 +484,31 @@ conv_patch_stack_reduce( float* img, float* kern, float* out,
__syncthreads();
load_to_shared(d_img, img, thread_id, nb_thread_id, img_wid, img_len,
img_stride_col, img_stride_row, false, c_contiguous);
if(!(split && ! preload_full_kern))
load_to_shared(d_kern, kern, thread_id, nb_thread_id, kern_wid, kern_len,
kern_stride_col, kern_stride_row, flipped_kern, c_contiguous);
__syncthreads();
if(split && ! preload_full_kern){
for(int first_row=0, row=tz;first_row<kern_len;row+=blockDim.z, first_row+=blockDim.z){
int idx3;
//TODO: test/check for flipped_kern
if(flipped_kern)
idx3=(kern_len-(first_row)-blockDim.z);//the current last row flipped
else
idx3=first_row;
for(int first_row=0;first_row<kern_len;first_row+=blockDim.z){
//N.B. - Jan 30, 2011 with CUDA 3.2 I found that without the explicit cast to
// (int)blockDim.z, idx3 would sometimes be negative. I'm rusty on my signed vs. unsigned
// details, but that seemed really weird. tricky bug to find too.
int idx3 = flipped_kern
? max((kern_len - (int)blockDim.z - first_row),0)
: first_row;
int len3 = min(blockDim.z, kern_len - first_row);
__syncthreads();
load_to_shared(d_kern, kern+idx3*kern_stride_row, thread_id, nb_thread_id, kern_wid,
blockDim.z
,
load_to_shared(d_kern, kern+idx3*kern_stride_row, thread_id, nb_thread_id, kern_wid,
len3
,
kern_stride_col, kern_stride_row, flipped_kern, c_contiguous);
__syncthreads();
const float* idx_kern=&d_kern[tz*kern_
stride_row
];
const float* idx_in=&d_img[(
row
+out_row)*img_wid+out_col];
const float* idx_kern=&d_kern[tz*kern_
wid
];
const float* idx_in=&d_img[(
first_row+tz
+out_row)*img_wid+out_col];
float sum2 = 0;
if(
row<kern_len
)
if(
tz<len3
)
convolutionRowNoFlip<KERN_WIDTH>(sum2,idx_in,idx_kern,kern_wid);
sum+=sum2;
}
}else if(split){
load_to_shared(d_kern, kern, thread_id, nb_thread_id, kern_wid, kern_len,
kern_stride_col, kern_stride_row, flipped_kern, c_contiguous);
__syncthreads();
for(int row=tz;row<kern_len;row+=blockDim.z){
const float* idx_kern=&d_kern[row*kern_wid];
const float* idx_in=&d_img[(row+out_row)*img_wid+out_col];
...
...
@@ -504,18 +518,21 @@ conv_patch_stack_reduce( float* img, float* kern, float* out,
int row = tz;//The row of the kernel.
const float* idx_kern=&d_kern[row*kern_wid];
const float* idx_in=&d_img[(row+out_row)*img_wid+out_col];
load_to_shared(d_kern, kern, thread_id, nb_thread_id, kern_wid, kern_len,
kern_stride_col, kern_stride_row, flipped_kern, c_contiguous);
__syncthreads();
convolutionRowNoFlip<KERN_WIDTH>(sum,idx_in,idx_kern,kern_wid);
}
__syncthreads(); // ensure calculations have completed before any thread starts changing the shared memory
}
//reduce
//reduce
no sync because previous loop ends with sync
d_reduce[thread_id]=sum;
__syncthreads();
if(thread_id<out_len*out_wid){
sum=0;
for(int i=
0
;i<blockDim.z;i++){
sum+=d_reduce[thread_id+i*
blockDim.x*blockDim.y
];
if(thread_id<out_len*out_wid){
// blockDim.x==out_wid, blockDim.y==out_len
//
sum=0;
for(int i=
1
;i<blockDim.z;i++){
sum+=d_reduce[thread_id+i*
out_wid*out_len
];
}
out[batch_id*out_wid*out_len*nkern+//the good batch
out_wid*out_len*blockIdx.y+//the output image
...
...
theano/sandbox/cuda/cuda_ndarray.cu
浏览文件 @
42927b2b
...
...
@@ -134,7 +134,9 @@ CudaNdarray_uninit(CudaNdarray*self)
assert(self->devdata);
if (device_free(self->devdata))
{
std::cerr << "!!!! error freeing device memory\n";
fprintf(stderr,
"!!!! error freeing device memory %p (self=%p)\n",
self->devdata, self);
rval = -1;
}
self->devdata = NULL;
...
...
@@ -144,7 +146,9 @@ CudaNdarray_uninit(CudaNdarray*self)
{
if (device_free(self->dev_structure))
{
std::cerr << "!!!! error freeing device memory\n";
fprintf(stderr,
"!!!! error freeing dev_structure memory %p (self=%p)\n",
self->dev_structure, self);
rval = -1;
}
self->dev_structure = NULL;
...
...
@@ -1848,6 +1852,8 @@ CudaNdarray_ptr_int_size(PyObject* _unused, PyObject* args)
}
get_gpu_ptr_size<<<1,1>>>(gpu_data);
if (cudaSuccess != cublasGetError()){
device_free(gpu_data);
return PyErr_Format(PyExc_RuntimeError,
"CudaNdarray_ptr_int_size: error when calling the gpu code.");
}
...
...
theano/sandbox/cuda/cuda_ndarray.cuh
浏览文件 @
42927b2b
...
...
@@ -403,6 +403,11 @@ int CudaNdarray_alloc_contiguous(CudaNdarray *self, const int nd, const inttype
self->devdata = 0;
return -1;
}
if (0)
fprintf(stderr,
"Allocated devdata %p (self=%p)\n",
self->devdata,
self);
self->data_allocated = size;
}
return 0;
...
...
theano/sandbox/cuda/tests/test_conv_cuda_ndarray.py
浏览文件 @
42927b2b
...
...
@@ -84,14 +84,36 @@ def py_conv_scipy(img, kern, mode, subsample):
def
_params_allgood_header
():
print
"ishape kshape #Mflops CPU Mflops GPU Mflops Speedup"
def
_params_allgood
(
ishape
,
kshape
,
mode
,
subsample
=
(
1
,
1
),
img_stride
=
(
1
,
1
),
kern_stride
=
(
1
,
1
),
version
=-
1
,
verbose
=
0
,
random
=
True
,
print_
=
None
,
id
=
None
,
rtol
=
1e-5
,
atol
=
1e-8
,
nb_iter
=
0
,
ones
=
False
):
def
test_example
():
# Test a specific configuration that was failing in one of the big unit-tests
# This configuration information was read from one of the 'FAIL' lines printed by
# _params_allgood during a nosetest run
#
# now it can be tested directly by nosetests test_conv_cuda_ndarray.py:test_example
assert
_params_allgood
(
(
1
,
1
,
4
,
4
),
(
1
,
1
,
3
,
2
),
'valid'
,
version
=
13
,
random
=
False
)
def
_params_allgood
(
ishape
,
kshape
,
mode
,
subsample
=
(
1
,
1
),
img_stride
=
(
1
,
1
),
kern_stride
=
(
1
,
1
),
version
=-
1
,
verbose
=
0
,
random
=
True
,
print_
=
None
,
id
=
None
,
rtol
=
1e-5
,
atol
=
1e-8
,
nb_iter
=
0
,
ones
=
False
):
#
# This function is the core of several of the big unit-test drivers,
# but it can also be used very directly on its own to test a specific
# kind of convolution.
#
# See `test_example` (above) for an example of how to use this directly.
#
if
ones
:
assert
not
random
npy_img
=
theano
.
_asarray
(
numpy
.
ones
(
ishape
),
dtype
=
'float32'
)
npy_kern
=
-
theano
.
_asarray
(
numpy
.
ones
(
kshape
),
dtype
=
'float32'
)
elif
random
:
npy_img
=
theano
.
_asarray
(
numpy
.
random
.
rand
(
*
ishape
),
dtype
=
'float32'
)
npy_kern
=
theano
.
_asarray
(
numpy
.
random
.
rand
(
*
kshape
),
dtype
=
'float32'
)
npy_img
=
theano
.
_asarray
(
numpy
.
random
.
rand
(
*
ishape
)
+
1
,
dtype
=
'float32'
)
npy_kern
=
theano
.
_asarray
(
numpy
.
random
.
rand
(
*
kshape
)
-
2
,
dtype
=
'float32'
)
else
:
npy_img
=
theano
.
_asarray
(
numpy
.
arange
(
numpy
.
prod
(
ishape
))
.
reshape
(
ishape
),
dtype
=
'float32'
)
+
1
npy_kern
=
-
(
theano
.
_asarray
(
numpy
.
arange
(
numpy
.
prod
(
kshape
))
.
reshape
(
kshape
),
dtype
=
'float32'
)
+
1
)
...
...
@@ -155,8 +177,6 @@ def _params_allgood(ishape, kshape, mode, subsample=(1,1), img_stride=(1,1), ker
print
"max absolute diff:"
,
diffabs
.
max
(),
"avg abs diff:"
,
numpy
.
average
(
diffabs
)
print
"median abs diff:"
,
numpy
.
median
(
diffabs
),
"nb close:"
,
nb_close
,
"/"
,
diff
.
size
print
"max relatif diff:"
,
pr_diff
.
max
(),
"avg rel diff:"
,
numpy
.
average
(
pr_diff
)
print
rval
if
not
rval
and
print_
!=
False
:
if
npy_img
.
shape
[
0
]
>
5
:
print
"img"
,
npy_img
[
0
]
...
...
@@ -185,9 +205,19 @@ def exec_conv(version, shapes, verbose, random, mode, print_=None, rtol=1e-5, on
for
id
,(
ishape
,
kshape
,
subshape
,
istride
,
kstride
)
in
enumerate
(
shapes
):
ret
=
False
try
:
ret
=
_params_allgood
(
ishape
,
kshape
,
mode
,
subsample
=
subshape
,
img_stride
=
istride
,
kern_stride
=
kstride
,
version
=
ver
,
verbose
=
verbose
,
random
=
random
,
id
=
id
,
print_
=
print_
,
rtol
=
rtol
,
ones
=
ones
)
ret
=
_params_allgood
(
ishape
,
kshape
,
mode
,
subsample
=
subshape
,
img_stride
=
istride
,
kern_stride
=
kstride
,
version
=
ver
,
verbose
=
verbose
,
random
=
random
,
id
=
id
,
print_
=
print_
,
rtol
=
rtol
,
ones
=
ones
)
except
Exception
,
e
:
print
ver
,
id
,(
ishape
,
kshape
,
subshape
,
istride
,
kstride
)
print
e
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论