提交 08703e60 authored 作者: Frederic's avatar Frederic

Make tests faster in DebugMode (and more resistent to rounding accumulation and wrong error)

上级 51c02671
...@@ -272,9 +272,13 @@ class TestConv3dFFT(unittest.TestCase): ...@@ -272,9 +272,13 @@ class TestConv3dFFT(unittest.TestCase):
utt.assert_allclose(res_ref, res_fft) utt.assert_allclose(res_ref, res_fft)
def test_opt_convgrad3d_fft(self): def test_opt_convgrad3d_fft(self):
inputs_shape = (16, 20, 32, 16, 1) inputs_shape = (2, 17, 15, 16, 1)
filters_shape = (10, 6, 12, 4, 1) filters_shape = (10, 3, 7, 4, 1)
dCdH_shape = (16, 15, 21, 13, 10) dCdH_shape = (inputs_shape[0],
inputs_shape[1] - filters_shape[1] + 1,
inputs_shape[2] - filters_shape[2] + 1,
inputs_shape[3] - filters_shape[3] + 1,
filters_shape[0])
inputs_val = numpy.random.random(inputs_shape).astype('float32') inputs_val = numpy.random.random(inputs_shape).astype('float32')
dCdH_val = numpy.random.random(dCdH_shape).astype('float32') dCdH_val = numpy.random.random(dCdH_shape).astype('float32')
...@@ -302,8 +306,8 @@ class TestConv3dFFT(unittest.TestCase): ...@@ -302,8 +306,8 @@ class TestConv3dFFT(unittest.TestCase):
utt.assert_allclose(res_ref, res_fft, rtol=1e-04, atol=1e-04) utt.assert_allclose(res_ref, res_fft, rtol=1e-04, atol=1e-04)
def test_opt_convtransp3d_fft(self): def test_opt_convtransp3d_fft(self):
inputs_shape = (16, 15, 21, 12, 10) inputs_shape = (2, 9, 16, 12, 10)
filters_shape = (10, 6, 12, 4, 1) filters_shape = (10, 3, 8, 4, 1)
inputs_val = numpy.random.random(inputs_shape).astype('float32') inputs_val = numpy.random.random(inputs_shape).astype('float32')
filters_val = numpy.random.random(filters_shape).astype('float32') filters_val = numpy.random.random(filters_shape).astype('float32')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论