提交 dc2e2eee authored 作者: Frederic's avatar Frederic

Update the test to cover the new case for each version of gpu convolution.

上级 9d3a2736
...@@ -31,6 +31,16 @@ else: ...@@ -31,6 +31,16 @@ else:
cuda_tensor4 = cuda_ndarray.CudaNdarrayType([False] * 4) cuda_tensor4 = cuda_ndarray.CudaNdarrayType([False] * 4)
device_id = theano.sandbox.cuda.use.device_number
if device_id is None:
cuda_ndarray.shared_constructor(numpy.zeros(2, dtype='float32'))
device_id = theano.sandbox.cuda.use.device_number
device_id = device_id[3:]
if device_id == '':
device_id = 0
cuda_ndarray = theano.sandbox.cuda.cuda_ndarray.cuda_ndarray
device_prop = cuda_ndarray.device_properties(device_id)
def py_conv_valid_numpy(img, kern): def py_conv_valid_numpy(img, kern):
assert img.shape[1] == kern.shape[1] assert img.shape[1] == kern.shape[1]
...@@ -386,7 +396,7 @@ def test_valid_0_2(): ...@@ -386,7 +396,7 @@ def test_valid_0_2():
oshape = [ishape[0]] + [kshape[0]] + list(numpy.asarray(ishape[2:]) - oshape = [ishape[0]] + [kshape[0]] + list(numpy.asarray(ishape[2:]) -
numpy.asarray(kshape[2:]) + numpy.asarray(kshape[2:]) +
numpy.asarray([1, 1])) numpy.asarray([1, 1]))
if oshape[3] > 512: if oshape[3] > device_prop['maxThreadsDim0']:
continue continue
if ishape[1] > 1: if ishape[1] > 1:
continue continue
...@@ -417,7 +427,7 @@ def test_valid_1_3_11_12(): ...@@ -417,7 +427,7 @@ def test_valid_1_3_11_12():
oshape = [ishape[0]] + [kshape[0]] + list(numpy.asarray(ishape[2:]) - oshape = [ishape[0]] + [kshape[0]] + list(numpy.asarray(ishape[2:]) -
numpy.asarray(kshape[2:]) + numpy.asarray(kshape[2:]) +
numpy.asarray([1, 1])) numpy.asarray([1, 1]))
if oshape[3] > 512: if oshape[3] > device_prop['maxThreadsDim0']:
continue continue
if ((numpy.prod(ishape[2:]) + numpy.prod(kshape[2:])) * 4 > if ((numpy.prod(ishape[2:]) + numpy.prod(kshape[2:])) * 4 >
(16 * 1024 - 150)): (16 * 1024 - 150)):
...@@ -446,7 +456,7 @@ def test_valid_4(): ...@@ -446,7 +456,7 @@ def test_valid_4():
oshape = [ishape[0]] + [kshape[0]] + list(numpy.asarray(ishape[2:]) - oshape = [ishape[0]] + [kshape[0]] + list(numpy.asarray(ishape[2:]) -
numpy.asarray(kshape[2:]) + numpy.asarray(kshape[2:]) +
numpy.asarray([1, 1])) numpy.asarray([1, 1]))
if oshape[3] > 512: if oshape[3] > device_prop['maxThreadsDim0']:
continue continue
if ishape[1] > 1: if ishape[1] > 1:
continue continue
...@@ -478,7 +488,7 @@ def test_valid_5(): ...@@ -478,7 +488,7 @@ def test_valid_5():
oshape = [ishape[0]] + [kshape[0]] + list(numpy.asarray(ishape[2:]) - oshape = [ishape[0]] + [kshape[0]] + list(numpy.asarray(ishape[2:]) -
numpy.asarray(kshape[2:]) + numpy.asarray(kshape[2:]) +
numpy.asarray([1, 1])) numpy.asarray([1, 1]))
if oshape[3] > 512: if oshape[3] > device_prop['maxThreadsDim0']:
continue continue
if ((kshape[2] * ishape[3] * 4 + numpy.prod(kshape[2:]) * 4) > if ((kshape[2] * ishape[3] * 4 + numpy.prod(kshape[2:]) * 4) >
(16 * 1024 - 150)): (16 * 1024 - 150)):
...@@ -512,7 +522,7 @@ def test_valid_7_8_13(): ...@@ -512,7 +522,7 @@ def test_valid_7_8_13():
oshape = [ishape[0]] + [kshape[0]] + list(numpy.asarray(ishape[2:]) - oshape = [ishape[0]] + [kshape[0]] + list(numpy.asarray(ishape[2:]) -
numpy.asarray(kshape[2:]) + numpy.asarray(kshape[2:]) +
numpy.asarray([1, 1])) numpy.asarray([1, 1]))
if oshape[2] * oshape[3] > 512: if oshape[2] * oshape[3] > device_prop['maxThreadsDim0']:
continue continue
if max(numpy.prod(ishape[2:]) * 4 + 2 * kshape[3] * 4, if max(numpy.prod(ishape[2:]) * 4 + 2 * kshape[3] * 4,
oshape[2] * oshape[3] * 4 * 2) > (16 * 1024 - 150): oshape[2] * oshape[3] * 4 * 2) > (16 * 1024 - 150):
...@@ -543,7 +553,7 @@ def test_valid_9_10(): ...@@ -543,7 +553,7 @@ def test_valid_9_10():
oshape = [ishape[0]] + [kshape[0]] + list(numpy.asarray(ishape[2:]) - oshape = [ishape[0]] + [kshape[0]] + list(numpy.asarray(ishape[2:]) -
numpy.asarray(kshape[2:]) + numpy.asarray(kshape[2:]) +
numpy.asarray([1, 1])) numpy.asarray([1, 1]))
if oshape[3] > 512: if oshape[3] > device_prop['maxThreadsDim0']:
continue continue
if (kshape[3] * 4 + ishape[3]) > (16 * 1024 - 150): if (kshape[3] * 4 + ishape[3]) > (16 * 1024 - 150):
continue continue
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论