提交 2be47437 authored 作者: Gijs van Tulder's avatar Gijs van Tulder

Add tests for DownsampleFactorMaxGrad average+sum

上级 a38a44a8
...@@ -316,6 +316,9 @@ class TestDownsampleFactorMax(utt.InferShapeTester): ...@@ -316,6 +316,9 @@ class TestDownsampleFactorMax(utt.InferShapeTester):
maxpoolsizes = ((5, 3), (3, 5), (3, 3)) maxpoolsizes = ((5, 3), (3, 5), (3, 3))
stridesizes = ((3, 2), (2, 3), (3, 3)) stridesizes = ((3, 2), (2, 3), (3, 3))
paddingsizes = ((2, 2), (2, 1), (2, 2)) paddingsizes = ((2, 2), (2, 1), (2, 2))
# average_inc_pad and average_exc_pad do not
# support grad with padding
for mode in ['max', 'sum']:
for i in range(len(imgsizes)): for i in range(len(imgsizes)):
imgsize = imgsizes[i] imgsize = imgsizes[i]
imval = rng.rand(1, 1, imgsize[0], imgsize[1]) * 10.0 imval = rng.rand(1, 1, imgsize[0], imgsize[1]) * 10.0
...@@ -328,6 +331,7 @@ class TestDownsampleFactorMax(utt.InferShapeTester): ...@@ -328,6 +331,7 @@ class TestDownsampleFactorMax(utt.InferShapeTester):
maxpoolsize, ignore_border=True, maxpoolsize, ignore_border=True,
st=stridesize, st=stridesize,
padding=paddingsize, padding=paddingsize,
mode=mode,
)(input) )(input)
utt.verify_grad(mp, [imval], rng=rng) utt.verify_grad(mp, [imval], rng=rng)
...@@ -337,13 +341,16 @@ class TestDownsampleFactorMax(utt.InferShapeTester): ...@@ -337,13 +341,16 @@ class TestDownsampleFactorMax(utt.InferShapeTester):
imval = rng.rand(2, 3, 3, 4) * 10.0 imval = rng.rand(2, 3, 3, 4) * 10.0
# more variance means numeric gradient will be more accurate # more variance means numeric gradient will be more accurate
for maxpoolshp in maxpoolshps: for maxpoolshp, ignore_border, mode in product(maxpoolshps,
for ignore_border in [True, False]: [True, False],
# print 'maxpoolshp =', maxpoolshp ['max',
# print 'ignore_border =', ignore_border 'sum',
'average_inc_pad',
'average_exc_pad']):
def mp(input): def mp(input):
return DownsampleFactorMax(maxpoolshp, return DownsampleFactorMax(maxpoolshp,
ignore_border=ignore_border)(input) ignore_border=ignore_border,
mode=mode)(input)
utt.verify_grad(mp, [imval], rng=rng) utt.verify_grad(mp, [imval], rng=rng)
def test_DownsampleFactorMax_grad_st(self): def test_DownsampleFactorMax_grad_st(self):
...@@ -353,13 +360,17 @@ class TestDownsampleFactorMax(utt.InferShapeTester): ...@@ -353,13 +360,17 @@ class TestDownsampleFactorMax(utt.InferShapeTester):
stridesizes = ((1, 1), (3, 3), (5, 7)) stridesizes = ((1, 1), (3, 3), (5, 7))
imval = rng.rand(1, 2, 16, 16) imval = rng.rand(1, 2, 16, 16)
for maxpoolshp in maxpoolshps: for maxpoolshp, ignore_border, mode, stride in product(maxpoolshps,
for ignore_border in [True, False]: [True, False],
for stride in stridesizes: ['max',
'sum',
'average_inc_pad',
'average_exc_pad'],
stridesizes):
def mp(input): def mp(input):
return DownsampleFactorMax(maxpoolshp, return DownsampleFactorMax(maxpoolshp,
ignore_border=ignore_border, ignore_border=ignore_border,
st=stride)(input) st=stride, mode=mode)(input)
utt.verify_grad(mp, [imval], rng=rng) utt.verify_grad(mp, [imval], rng=rng)
def test_DownsampleFactorMax_grad_st_extra(self): def test_DownsampleFactorMax_grad_st_extra(self):
...@@ -372,6 +383,7 @@ class TestDownsampleFactorMax(utt.InferShapeTester): ...@@ -372,6 +383,7 @@ class TestDownsampleFactorMax(utt.InferShapeTester):
imvsizs = ((16, 16), (16, 16), (16, 16), (8, 5), imvsizs = ((16, 16), (16, 16), (16, 16), (8, 5),
(8, 5), (8, 5), (8, 5)) (8, 5), (8, 5), (8, 5))
for mode in ['max', 'sum', 'average_inc_pad', 'average_exc_pad']:
for indx in numpy.arange(len(maxpoolshps)): for indx in numpy.arange(len(maxpoolshps)):
imvsize = imvsizs[indx] imvsize = imvsizs[indx]
imval = rng.rand(1, 2, imvsize[0], imvsize[1]) imval = rng.rand(1, 2, imvsize[0], imvsize[1])
...@@ -381,7 +393,8 @@ class TestDownsampleFactorMax(utt.InferShapeTester): ...@@ -381,7 +393,8 @@ class TestDownsampleFactorMax(utt.InferShapeTester):
def mp(input): def mp(input):
return DownsampleFactorMax(maxpoolshp, return DownsampleFactorMax(maxpoolshp,
ignore_border=ignore_border, ignore_border=ignore_border,
st=stride)(input) st=stride,
mode=mode)(input)
utt.verify_grad(mp, [imval], rng=rng) utt.verify_grad(mp, [imval], rng=rng)
def test_DownsampleFactorMaxGrad_grad(self): def test_DownsampleFactorMaxGrad_grad(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论