提交 e2122bf4 authored 作者: Frederic Bastien's avatar Frederic Bastien

more correctly test if we should run the test.

上级 4f3e3ff3
......@@ -332,7 +332,7 @@ def test_valid_0_2():
for id,(ishape, kshape, subshape, istride, kstride) in enumerate(shapes):
oshape=[ishape[0]]+[kshape[0]]+list(numpy.asarray(ishape[2:])-numpy.asarray(kshape[2:])+numpy.asarray([1,1]))
if oshape[1]> 512:
if oshape[3]> 512:
continue
if ishape[1]>1:
continue
......@@ -358,7 +358,7 @@ def test_valid_1_3_11_12():
for id,(ishape, kshape, subshape, istride, kstride) in enumerate(shapes):
oshape=[ishape[0]]+[kshape[0]]+list(numpy.asarray(ishape[2:])-numpy.asarray(kshape[2:])+numpy.asarray([1,1]))
if oshape[1]> 512:
if oshape[3]> 512:
continue
if (numpy.prod(ishape[2:])+numpy.prod(kshape[2:]))*4>(16*1024-150):
continue
......@@ -382,7 +382,7 @@ def test_valid_4():
for id,(ishape, kshape, subshape, istride, kstride) in enumerate(shapes):
oshape=[ishape[0]]+[kshape[0]]+list(numpy.asarray(ishape[2:])-numpy.asarray(kshape[2:])+numpy.asarray([1,1]))
if oshape[1]> 512:
if oshape[3]> 512:
continue
if ishape[1]>1:
continue
......@@ -409,7 +409,7 @@ def test_valid_5():
print len(shapes)
for id,(ishape, kshape, subshape, istride, kstride) in enumerate(shapes):
oshape=[ishape[0]]+[kshape[0]]+list(numpy.asarray(ishape[2:])-numpy.asarray(kshape[2:])+numpy.asarray([1,1]))
if oshape[1]> 512:
if oshape[3]> 512:
continue
if (kshape[2]*ishape[3]*4+numpy.prod(kshape[2:])*4)>(16*1024-150):
continue
......@@ -435,7 +435,7 @@ def test_valid_7_8_13():
print len(shapes)
for id,(ishape, kshape, subshape, istride, kstride) in enumerate(shapes):
oshape=[ishape[0]]+[kshape[0]]+list(numpy.asarray(ishape[2:])-numpy.asarray(kshape[2:])+numpy.asarray([1,1]))
if oshape[1]> 512:
if oshape[2]*oshape[3]>512:
continue
if (numpy.prod(ishape[2:])*4+2*kshape[3]*4+oshape[2]*oshape[3]*4*2)>(16*1024-150):
continue
......@@ -461,7 +461,7 @@ def test_valid_9_10():
print len(shapes)
for id,(ishape, kshape, subshape, istride, kstride) in enumerate(shapes):
oshape=[ishape[0]]+[kshape[0]]+list(numpy.asarray(ishape[2:])-numpy.asarray(kshape[2:])+numpy.asarray([1,1]))
if oshape[1]> 512:
if oshape[3]> 512:
continue
if (kshape[3]*4+ishape[3])>(16*1024-150):
continue
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论