提交 198b22ea authored 作者: Frederic Bastien's avatar Frederic Bastien

enable the modified gpu conv kernel that use less shared memory.

上级 e2122bf4
...@@ -363,7 +363,7 @@ class GpuConv(Op): ...@@ -363,7 +363,7 @@ class GpuConv(Op):
return ['cuda_ndarray.cuh','<stdio.h>'] return ['cuda_ndarray.cuh','<stdio.h>']
def c_code_cache_version(self): def c_code_cache_version(self):
return (0,12) # raise this whenever modifying any of the support_code_files return (0,13) # raise this whenever modifying any of the support_code_files
def c_support_code_apply(self, node, nodename): def c_support_code_apply(self, node, nodename):
# REMEMBER TO RAISE c_code_cache_version when changing any of these files # REMEMBER TO RAISE c_code_cache_version when changing any of these files
......
...@@ -116,7 +116,7 @@ CudaNdarray_conv_valid(const CudaNdarray *img, const CudaNdarray * kern, ...@@ -116,7 +116,7 @@ CudaNdarray_conv_valid(const CudaNdarray *img, const CudaNdarray * kern,
if(!subsample && if(!subsample &&
out_contiguous && out_contiguous &&
out_size<512 &&//Maximum of 512 theads by block out_size<512 &&//Maximum of 512 theads by block
(img_size_byte+2*kern_wid*sizeof(float)+out_size_byte*2)<shared_avail && //their is only 16k of shared memory and if we can't have the output at least twice in shared mem, we won't have any reduce! std::max(int(img_size_byte+2*kern_wid*sizeof(float)), out_size_byte*2)<shared_avail && //their is only 16k of shared memory and if we can't have the output at least twice in shared mem, we won't have any reduce!
!work_complete) !work_complete)
version = 7; //conv_patch_stack_reduce, switch to version 8/13 automatically if needed. version = 7; //conv_patch_stack_reduce, switch to version 8/13 automatically if needed.
} }
......
...@@ -422,6 +422,8 @@ def test_valid_5(): ...@@ -422,6 +422,8 @@ def test_valid_5():
def test_valid_7_8_13(): def test_valid_7_8_13():
shapes = get_valid_shapes() shapes = get_valid_shapes()
# This is to test the "new" lower shared memory usage.
shapes.append(((10,30,60,60),(20,30,40,40), (1,1), (1,1), (1,1))
version=[7,8,13] version=[7,8,13]
verbose=0 verbose=0
...@@ -437,7 +439,7 @@ def test_valid_7_8_13(): ...@@ -437,7 +439,7 @@ def test_valid_7_8_13():
oshape=[ishape[0]]+[kshape[0]]+list(numpy.asarray(ishape[2:])-numpy.asarray(kshape[2:])+numpy.asarray([1,1])) oshape=[ishape[0]]+[kshape[0]]+list(numpy.asarray(ishape[2:])-numpy.asarray(kshape[2:])+numpy.asarray([1,1]))
if oshape[2]*oshape[3]>512: if oshape[2]*oshape[3]>512:
continue continue
if (numpy.prod(ishape[2:])*4+2*kshape[3]*4+oshape[2]*oshape[3]*4*2)>(16*1024-150): if max(numpy.prod(ishape[2:])*4+2*kshape[3]*4, oshape[2]*oshape[3]*4*2)>(16*1024-150):
continue continue
if subshape==(1,1): if subshape==(1,1):
shapes2.append((ishape, kshape, subshape, istride, kstride)) shapes2.append((ishape, kshape, subshape, istride, kstride))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论