提交 f0db12e7 authored 作者: Nicolas Ballas's avatar Nicolas Ballas

Fix test precision

上级 356e260c
...@@ -41,27 +41,27 @@ class TestCorr3DMM(unittest.TestCase): ...@@ -41,27 +41,27 @@ class TestCorr3DMM(unittest.TestCase):
res_ref = f_ref() res_ref = f_ref()
res = f() res = f()
utt.assert_allclose(res_ref, res, rtol=1e-05, atol=1e-05) utt.assert_allclose(res_ref, res)
def test_valid(self): def test_valid(self):
self.run_conv_valid(inputs_shape=(16, 20, 32, 16, 1), self.run_conv_valid(inputs_shape=(16, 20, 12, 16, 1),
filters_shape=(10, 6, 12, 4, 1)) filters_shape=(10, 6, 12, 4, 1))
self.run_conv_valid(inputs_shape=(16, 20, 32, 15, 1), self.run_conv_valid(inputs_shape=(16, 20, 12, 15, 1),
filters_shape=(10, 6, 12, 4, 1), filters_shape=(10, 6, 12, 4, 1),
subsample=(2, 2, 2)) subsample=(2, 2, 2))
self.run_conv_valid(inputs_shape=(16, 20, 32, 15, 1), self.run_conv_valid(inputs_shape=(16, 20, 12, 15, 1),
filters_shape=(10, 6, 12, 4, 1), filters_shape=(10, 6, 12, 4, 1),
subsample=(2, 2, 2)) subsample=(2, 2, 2))
self.run_conv_valid(inputs_shape=(16, 20, 32, 15, 1), self.run_conv_valid(inputs_shape=(16, 20, 12, 15, 1),
filters_shape=(10, 6, 12, 4, 1), filters_shape=(10, 6, 12, 4, 1),
subsample=(3, 3, 3)) subsample=(3, 3, 3))
self.run_conv_valid(inputs_shape=(16, 20, 32, 15, 1), self.run_conv_valid(inputs_shape=(16, 20, 12, 15, 1),
filters_shape=(10, 6, 12, 4, 1), filters_shape=(10, 6, 12, 4, 1),
subsample=(3, 3, 3)) subsample=(3, 3, 3))
self.run_conv_valid(inputs_shape=(16, 20, 32, 15, 1), self.run_conv_valid(inputs_shape=(16, 20, 12, 15, 1),
filters_shape=(10, 6, 12, 4, 1), filters_shape=(10, 6, 12, 4, 1),
subsample=(3, 2, 1)) subsample=(3, 2, 1))
self.run_conv_valid(inputs_shape=(16, 20, 32, 15, 1), self.run_conv_valid(inputs_shape=(16, 20, 12, 15, 1),
filters_shape=(10, 6, 12, 4, 1), filters_shape=(10, 6, 12, 4, 1),
subsample=(1, 2, 3)) subsample=(1, 2, 3))
...@@ -90,24 +90,24 @@ class TestCorr3DMM(unittest.TestCase): ...@@ -90,24 +90,24 @@ class TestCorr3DMM(unittest.TestCase):
res_ref = f_ref() res_ref = f_ref()
res = f() res = f()
utt.assert_allclose(res_ref, res, rtol=1e-04, atol=1e-04) utt.assert_allclose(res_ref, res)
def test_gradweight(self): def test_gradweight(self):
self.run_gradweight(inputs_shape=(16, 20, 32, 16, 1), self.run_gradweight(inputs_shape=(16, 10, 12, 16, 1),
filters_shape=(10, 6, 12, 4, 1), filters_shape=(10, 6, 12, 4, 1),
dCdH_shape=(16, 15, 21, 13, 10), dCdH_shape=(16, 5, 1, 13, 10),
subsample=(1, 1, 1)) subsample=(1, 1, 1))
self.run_gradweight(inputs_shape=(16, 20, 32, 16, 1), self.run_gradweight(inputs_shape=(16, 20, 10, 16, 1),
filters_shape=(10, 6, 12, 4, 1), filters_shape=(10, 6, 4, 4, 1),
dCdH_shape=(16, 8, 11, 7, 10), dCdH_shape=(16, 8, 4, 7, 10),
subsample=(2, 2, 2)) subsample=(2, 2, 2))
self.run_gradweight(inputs_shape=(16, 20, 32, 16, 1), self.run_gradweight(inputs_shape=(16, 20, 10, 16, 1),
filters_shape=(10, 6, 12, 4, 1), filters_shape=(10, 6, 3, 4, 1),
dCdH_shape=(16, 5, 7, 5, 10), dCdH_shape=(16, 5, 3, 5, 10),
subsample=(3, 3, 3)) subsample=(3, 3, 3))
self.run_gradweight(inputs_shape=(16, 20, 32, 16, 1), self.run_gradweight(inputs_shape=(16, 20, 12, 16, 1),
filters_shape=(10, 6, 12, 4, 1), filters_shape=(10, 6, 12, 4, 1),
dCdH_shape=(16, 8, 21, 5, 10), dCdH_shape=(16, 8, 1, 5, 10),
subsample=(2, 1, 3)) subsample=(2, 1, 3))
def run_gradinput(self, inputs_shape, filters_shape, def run_gradinput(self, inputs_shape, filters_shape,
...@@ -140,18 +140,18 @@ class TestCorr3DMM(unittest.TestCase): ...@@ -140,18 +140,18 @@ class TestCorr3DMM(unittest.TestCase):
f = theano.function([], conv_gemm) f = theano.function([], conv_gemm)
res = f() res = f()
utt.assert_allclose(res_ref, res, rtol=1e-04, atol=1e-04) utt.assert_allclose(res_ref, res)
def test_gradinput(self): def test_gradinput(self):
self.run_gradinput(inputs_shape=(16, 15, 21, 12, 10), self.run_gradinput(inputs_shape=(16, 15, 12, 12, 10),
filters_shape=(10, 6, 12, 4, 1)) filters_shape=(10, 6, 12, 4, 1))
self.run_gradinput(inputs_shape=(16, 15, 21, 12, 10), self.run_gradinput(inputs_shape=(16, 15, 12, 12, 10),
filters_shape=(10, 6, 12, 4, 1), filters_shape=(10, 6, 12, 4, 1),
subsample=(2,2,2)) subsample=(2,2,2))
self.run_gradinput(inputs_shape=(16, 15, 21, 12, 10), self.run_gradinput(inputs_shape=(16, 15, 12, 12, 10),
filters_shape=(10, 6, 12, 4, 1), filters_shape=(10, 6, 12, 4, 1),
subsample=(3,3,3)) subsample=(3,3,3))
self.run_gradinput(inputs_shape=(16, 15, 21, 12, 10), self.run_gradinput(inputs_shape=(16, 15, 12, 12, 10),
filters_shape=(10, 6, 12, 4, 1), filters_shape=(10, 6, 12, 4, 1),
subsample=(3,1,2)) subsample=(3,1,2))
...@@ -183,9 +183,9 @@ class TestCorr3DMM(unittest.TestCase): ...@@ -183,9 +183,9 @@ class TestCorr3DMM(unittest.TestCase):
utt.assert_allclose(res_ref, res_gemm) utt.assert_allclose(res_ref, res_gemm)
def test_opt_convgrad3d_gemm(self): def test_opt_convgrad3d_gemm(self):
inputs_shape = (16, 20, 32, 16, 1) inputs_shape = (16, 10, 12, 16, 1)
filters_shape = (10, 6, 12, 4, 1) filters_shape = (10, 6, 12, 4, 1)
dCdH_shape = (16, 15, 21, 13, 10) dCdH_shape = (16, 5, 1, 13, 10)
inputs_val = numpy.random.random(inputs_shape).astype('float32') inputs_val = numpy.random.random(inputs_shape).astype('float32')
dCdH_val = numpy.random.random(dCdH_shape).astype('float32') dCdH_val = numpy.random.random(dCdH_shape).astype('float32')
...@@ -207,11 +207,11 @@ class TestCorr3DMM(unittest.TestCase): ...@@ -207,11 +207,11 @@ class TestCorr3DMM(unittest.TestCase):
res_ref = f_ref() res_ref = f_ref()
res_gemm = f_gemm() res_gemm = f_gemm()
utt.assert_allclose(res_ref, res_gemm, rtol=1e-04, atol=1e-04) utt.assert_allclose(res_ref, res_gemm)
def test_opt_convtransp3d_gemm(self): def test_opt_convtransp3d_gemm(self):
inputs_shape = (16, 15, 21, 12, 10) inputs_shape = (16, 15, 12, 12, 10)
filters_shape = (10, 6, 12, 4, 1) filters_shape = (10, 6, 12, 4, 1)
inputs_val = numpy.random.random(inputs_shape).astype('float32') inputs_val = numpy.random.random(inputs_shape).astype('float32')
...@@ -234,5 +234,5 @@ class TestCorr3DMM(unittest.TestCase): ...@@ -234,5 +234,5 @@ class TestCorr3DMM(unittest.TestCase):
res_ref = f_ref() res_ref = f_ref()
res_gemm = f_gemm() res_gemm = f_gemm()
utt.assert_allclose(res_ref, res_gemm, rtol=1e-04, atol=1e-04) utt.assert_allclose(res_ref, res_gemm)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论