提交 71486d94 authored 作者: Frederic Bastien's avatar Frederic Bastien

refactored convolution test on the gpu to don't execution version of code that…

refactored convolution test on the gpu to don't execution version of code that will execute the refenrece version because the optimization can't be applyed.
上级 59204ff9
...@@ -224,7 +224,8 @@ def get_shapes2(scales_img=(1,1), scales_kern=(1,1), subsample=(1,1), img_stride ...@@ -224,7 +224,8 @@ def get_shapes2(scales_img=(1,1), scales_kern=(1,1), subsample=(1,1), img_stride
(2*scales_kern[0],3*scales_kern[1]),subsample, img_stride, kern_stride) (2*scales_kern[0],3*scales_kern[1]),subsample, img_stride, kern_stride)
return shapes return shapes
def test_valid(): def get_valid_shapes():
# img shape, kern shape, subsample shape # img shape, kern shape, subsample shape
shapes = get_basic_shapes() shapes = get_basic_shapes()
...@@ -271,11 +272,169 @@ def test_valid(): ...@@ -271,11 +272,169 @@ def test_valid():
, ((20,10,29,29),(30,10,23,23), (1,1), (1,1), (1,1))#test_lenet_64 bprop 1 , ((20,10,29,29),(30,10,23,23), (1,1), (1,1), (1,1))#test_lenet_64 bprop 1
, ((1,10,64,64),(20,10,58,58), (1,1), (1,1), (1,1))#test_lenet_64 bprop 2 , ((1,10,64,64),(20,10,58,58), (1,1), (1,1), (1,1))#test_lenet_64 bprop 2
] ]
return shapes
def test_valid_0_2():
shapes = get_valid_shapes()
version=[0,2]
verbose=0
random = True
print_ = False
ones = False
if ones:
random = False
shapes2=[]
for id,(ishape, kshape, subshape, istride, kstride) in enumerate(shapes):
oshape=[ishape[0]]+[kshape[0]]+list(numpy.asarray(ishape[2:])-numpy.asarray(kshape[2:])+numpy.asarray([1,1]))
if oshape[1]> 512:
continue
if ishape[1]>1:
continue
if (numpy.prod(ishape[2:])+numpy.prod(kshape[2:]))*4>(16*1024-150):
continue
if subshape==(1,1):
shapes2.append((ishape, kshape, subshape, istride, kstride))
shapes = shapes2
exec_conv(version, shapes, verbose, random, 'valid', print_=print_, ones=ones, rtol=1.1e-5)
def test_valid_1_3_11_12():
shapes = get_valid_shapes()
version=[1,3,11,12]
verbose=0
random = True
print_ = False
ones = False
if ones:
random = False
shapes2=[]
for id,(ishape, kshape, subshape, istride, kstride) in enumerate(shapes):
oshape=[ishape[0]]+[kshape[0]]+list(numpy.asarray(ishape[2:])-numpy.asarray(kshape[2:])+numpy.asarray([1,1]))
if oshape[1]> 512:
continue
if (numpy.prod(ishape[2:])+numpy.prod(kshape[2:]))*4>(16*1024-150):
continue
if subshape==(1,1):
shapes2.append((ishape, kshape, subshape, istride, kstride))
shapes = shapes2
exec_conv(version, shapes, verbose, random, 'valid', print_=print_, ones=ones, rtol=1.1e-5)
def test_valid_4():
shapes = get_valid_shapes()
version=[4]
verbose=0
random = True
print_ = False
ones = False
if ones:
random = False
shapes2=[]
for id,(ishape, kshape, subshape, istride, kstride) in enumerate(shapes):
oshape=[ishape[0]]+[kshape[0]]+list(numpy.asarray(ishape[2:])-numpy.asarray(kshape[2:])+numpy.asarray([1,1]))
if oshape[1]> 512:
continue
if ishape[1]>1:
continue
if (kshape[2]*ishape[3]*4+numpy.prod(kshape[2:])*4)>(16*1024-150):
continue
if subshape==(1,1):
shapes2.append((ishape, kshape, subshape, istride, kstride))
shapes = shapes2
exec_conv(version, shapes, verbose, random, 'valid', print_=print_, ones=ones, rtol=1.1e-5)
def test_valid_5():
shapes = get_valid_shapes()
version=[5]
verbose=0
random = True
print_ = False
ones = False
if ones:
random = False
shapes2=[]
print len(shapes)
for id,(ishape, kshape, subshape, istride, kstride) in enumerate(shapes):
oshape=[ishape[0]]+[kshape[0]]+list(numpy.asarray(ishape[2:])-numpy.asarray(kshape[2:])+numpy.asarray([1,1]))
if oshape[1]> 512:
continue
if (kshape[2]*ishape[3]*4+numpy.prod(kshape[2:])*4)>(16*1024-150):
continue
if subshape==(1,1):
shapes2.append((ishape, kshape, subshape, istride, kstride))
shapes = shapes2
print len(shapes2)
exec_conv(version, shapes, verbose, random, 'valid', print_=print_, ones=ones, rtol=1.1e-5)
def test_valid_7_8_13():
shapes = get_valid_shapes()
version=[7,8,13]
verbose=0
random = True
print_ = False
ones = False
if ones:
random = False
shapes2=[]
print len(shapes)
for id,(ishape, kshape, subshape, istride, kstride) in enumerate(shapes):
oshape=[ishape[0]]+[kshape[0]]+list(numpy.asarray(ishape[2:])-numpy.asarray(kshape[2:])+numpy.asarray([1,1]))
if oshape[1]> 512:
continue
if (numpy.prod(ishape[2:])*4+2*kshape[3]*4+oshape[2]*oshape[3]*4*2)>(16*1024-150):
continue
if subshape==(1,1):
shapes2.append((ishape, kshape, subshape, istride, kstride))
shapes = shapes2
print len(shapes2)
exec_conv(version, shapes, verbose, random, 'valid', print_=print_, ones=ones, rtol=1.1e-5)
def test_valid_9_10():
shapes = get_valid_shapes()
version=[9,10]
verbose=0
random = True
print_ = False
ones = False
if ones:
random = False
shapes2=[]
print len(shapes)
for id,(ishape, kshape, subshape, istride, kstride) in enumerate(shapes):
oshape=[ishape[0]]+[kshape[0]]+list(numpy.asarray(ishape[2:])-numpy.asarray(kshape[2:])+numpy.asarray([1,1]))
if oshape[1]> 512:
continue
if (kshape[3]*4+ishape[3])>(16*1024-150):
continue
if subshape==(1,1):
shapes2.append((ishape, kshape, subshape, istride, kstride))
shapes = shapes2
print len(shapes2)
exec_conv(version, shapes, verbose, random, 'valid', print_=print_, ones=ones, rtol=1.1e-5)
def test_valid():
shapes = get_valid_shapes()
#shapes=shapes[400:426] #shapes=shapes[400:426]
# I put -1 in case we forget to add version in the test to. # I put -1 in case we forget to add version in the test to.
# I put -2 to test the reference version. # I put -2 to test the reference version.
version=[-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13] version=[-2,-1,6]
verbose=0 verbose=0
# version=[1] # version=[1]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论