提交 f0db12e7 authored 作者: Nicolas Ballas's avatar Nicolas Ballas

Fix test precision

上级 356e260c
......@@ -41,27 +41,27 @@ class TestCorr3DMM(unittest.TestCase):
res_ref = f_ref()
res = f()
utt.assert_allclose(res_ref, res, rtol=1e-05, atol=1e-05)
utt.assert_allclose(res_ref, res)
def test_valid(self):
self.run_conv_valid(inputs_shape=(16, 20, 32, 16, 1),
self.run_conv_valid(inputs_shape=(16, 20, 12, 16, 1),
filters_shape=(10, 6, 12, 4, 1))
self.run_conv_valid(inputs_shape=(16, 20, 32, 15, 1),
self.run_conv_valid(inputs_shape=(16, 20, 12, 15, 1),
filters_shape=(10, 6, 12, 4, 1),
subsample=(2, 2, 2))
self.run_conv_valid(inputs_shape=(16, 20, 32, 15, 1),
self.run_conv_valid(inputs_shape=(16, 20, 12, 15, 1),
filters_shape=(10, 6, 12, 4, 1),
subsample=(2, 2, 2))
self.run_conv_valid(inputs_shape=(16, 20, 32, 15, 1),
self.run_conv_valid(inputs_shape=(16, 20, 12, 15, 1),
filters_shape=(10, 6, 12, 4, 1),
subsample=(3, 3, 3))
self.run_conv_valid(inputs_shape=(16, 20, 32, 15, 1),
self.run_conv_valid(inputs_shape=(16, 20, 12, 15, 1),
filters_shape=(10, 6, 12, 4, 1),
subsample=(3, 3, 3))
self.run_conv_valid(inputs_shape=(16, 20, 32, 15, 1),
self.run_conv_valid(inputs_shape=(16, 20, 12, 15, 1),
filters_shape=(10, 6, 12, 4, 1),
subsample=(3, 2, 1))
self.run_conv_valid(inputs_shape=(16, 20, 32, 15, 1),
self.run_conv_valid(inputs_shape=(16, 20, 12, 15, 1),
filters_shape=(10, 6, 12, 4, 1),
subsample=(1, 2, 3))
......@@ -90,24 +90,24 @@ class TestCorr3DMM(unittest.TestCase):
res_ref = f_ref()
res = f()
utt.assert_allclose(res_ref, res, rtol=1e-04, atol=1e-04)
utt.assert_allclose(res_ref, res)
def test_gradweight(self):
self.run_gradweight(inputs_shape=(16, 20, 32, 16, 1),
self.run_gradweight(inputs_shape=(16, 10, 12, 16, 1),
filters_shape=(10, 6, 12, 4, 1),
dCdH_shape=(16, 15, 21, 13, 10),
dCdH_shape=(16, 5, 1, 13, 10),
subsample=(1, 1, 1))
self.run_gradweight(inputs_shape=(16, 20, 32, 16, 1),
filters_shape=(10, 6, 12, 4, 1),
dCdH_shape=(16, 8, 11, 7, 10),
self.run_gradweight(inputs_shape=(16, 20, 10, 16, 1),
filters_shape=(10, 6, 4, 4, 1),
dCdH_shape=(16, 8, 4, 7, 10),
subsample=(2, 2, 2))
self.run_gradweight(inputs_shape=(16, 20, 32, 16, 1),
filters_shape=(10, 6, 12, 4, 1),
dCdH_shape=(16, 5, 7, 5, 10),
self.run_gradweight(inputs_shape=(16, 20, 10, 16, 1),
filters_shape=(10, 6, 3, 4, 1),
dCdH_shape=(16, 5, 3, 5, 10),
subsample=(3, 3, 3))
self.run_gradweight(inputs_shape=(16, 20, 32, 16, 1),
self.run_gradweight(inputs_shape=(16, 20, 12, 16, 1),
filters_shape=(10, 6, 12, 4, 1),
dCdH_shape=(16, 8, 21, 5, 10),
dCdH_shape=(16, 8, 1, 5, 10),
subsample=(2, 1, 3))
def run_gradinput(self, inputs_shape, filters_shape,
......@@ -140,18 +140,18 @@ class TestCorr3DMM(unittest.TestCase):
f = theano.function([], conv_gemm)
res = f()
utt.assert_allclose(res_ref, res, rtol=1e-04, atol=1e-04)
utt.assert_allclose(res_ref, res)
def test_gradinput(self):
self.run_gradinput(inputs_shape=(16, 15, 21, 12, 10),
self.run_gradinput(inputs_shape=(16, 15, 12, 12, 10),
filters_shape=(10, 6, 12, 4, 1))
self.run_gradinput(inputs_shape=(16, 15, 21, 12, 10),
self.run_gradinput(inputs_shape=(16, 15, 12, 12, 10),
filters_shape=(10, 6, 12, 4, 1),
subsample=(2,2,2))
self.run_gradinput(inputs_shape=(16, 15, 21, 12, 10),
self.run_gradinput(inputs_shape=(16, 15, 12, 12, 10),
filters_shape=(10, 6, 12, 4, 1),
subsample=(3,3,3))
self.run_gradinput(inputs_shape=(16, 15, 21, 12, 10),
self.run_gradinput(inputs_shape=(16, 15, 12, 12, 10),
filters_shape=(10, 6, 12, 4, 1),
subsample=(3,1,2))
......@@ -183,9 +183,9 @@ class TestCorr3DMM(unittest.TestCase):
utt.assert_allclose(res_ref, res_gemm)
def test_opt_convgrad3d_gemm(self):
inputs_shape = (16, 20, 32, 16, 1)
inputs_shape = (16, 10, 12, 16, 1)
filters_shape = (10, 6, 12, 4, 1)
dCdH_shape = (16, 15, 21, 13, 10)
dCdH_shape = (16, 5, 1, 13, 10)
inputs_val = numpy.random.random(inputs_shape).astype('float32')
dCdH_val = numpy.random.random(dCdH_shape).astype('float32')
......@@ -207,11 +207,11 @@ class TestCorr3DMM(unittest.TestCase):
res_ref = f_ref()
res_gemm = f_gemm()
utt.assert_allclose(res_ref, res_gemm, rtol=1e-04, atol=1e-04)
utt.assert_allclose(res_ref, res_gemm)
def test_opt_convtransp3d_gemm(self):
inputs_shape = (16, 15, 21, 12, 10)
inputs_shape = (16, 15, 12, 12, 10)
filters_shape = (10, 6, 12, 4, 1)
inputs_val = numpy.random.random(inputs_shape).astype('float32')
......@@ -234,5 +234,5 @@ class TestCorr3DMM(unittest.TestCase):
res_ref = f_ref()
res_gemm = f_gemm()
utt.assert_allclose(res_ref, res_gemm, rtol=1e-04, atol=1e-04)
utt.assert_allclose(res_ref, res_gemm)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论