提交 52c127c5 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #2948 from SinaHonari/issues2601

changing test_infer_shape in TestDownsampleFactorMax to include padding
...@@ -622,31 +622,45 @@ class TestDownsampleFactorMax(utt.InferShapeTester): ...@@ -622,31 +622,45 @@ class TestDownsampleFactorMax(utt.InferShapeTester):
maxpoolshps = ((1, 1), (2, 2), (3, 3), (2, 3), (3, 2)) maxpoolshps = ((1, 1), (2, 2), (3, 3), (2, 3), (3, 2))
image_val = rng.rand(4, 6, 7, 9) image_val = rng.rand(4, 6, 7, 9)
out_shapes = [[[4, 6, 7, 9], [4, 6, 7, 9]], out_shapes = [[[[4, 6, 7, 9], [4, 6, 7, 9]],
[[4, 6, 3, 4], [4, 6, 4, 5]], [[4, 6, 3, 4], [4, 6, 4, 5]],
[[4, 6, 2, 3], [4, 6, 3, 3]], [[4, 6, 2, 3], [4, 6, 3, 3]],
[[4, 6, 3, 3], [4, 6, 4, 3]], [[4, 6, 3, 3], [4, 6, 4, 3]],
[[4, 6, 2, 4], [4, 6, 3, 5]]] [[4, 6, 2, 4], [4, 6, 3, 5]]],
[[None, None],
[[4, 6, 4, 5], None],
[[4, 6, 3, 3], None],
[[4, 6, 4, 3], None],
[[4, 6, 3, 5], None]],
[[None, None],
[None, None],
[[4, 6, 3, 4], None],
[[4, 6, 4, 4], None],
[None, None]]]
for i, maxpoolshp in enumerate(maxpoolshps): for i, maxpoolshp in enumerate(maxpoolshps):
for j, ignore_border in enumerate([True, False]): for j, ignore_border in enumerate([True, False]):
for k, padding in enumerate([(0,0), (1,1), (1,2)]):
# checking shapes generated by DownsampleFactorMax if out_shapes[k][i][j] == None:
self._compile_and_check([image], continue
[DownsampleFactorMax(maxpoolshp, # checking shapes generated by DownsampleFactorMax
ignore_border=ignore_border)(image)], self._compile_and_check([image],
[image_val], DownsampleFactorMax) [DownsampleFactorMax(maxpoolshp,
ignore_border=ignore_border,
# checking shapes generated by DownsampleFactorMaxGrad padding=padding)(image)],
maxout_val = rng.rand(*out_shapes[i][j]) [image_val], DownsampleFactorMax)
gz_val = rng.rand(*out_shapes[i][j])
self._compile_and_check([image, maxout, gz], # checking shapes generated by DownsampleFactorMaxGrad
[DownsampleFactorMaxGrad(maxpoolshp, maxout_val = rng.rand(*out_shapes[k][i][j])
ignore_border=ignore_border) gz_val = rng.rand(*out_shapes[k][i][j])
(image, maxout, gz)], self._compile_and_check([image, maxout, gz],
[image_val, maxout_val, gz_val], [DownsampleFactorMaxGrad(maxpoolshp,
DownsampleFactorMaxGrad, ignore_border=ignore_border,
warn=False) padding=padding)
(image, maxout, gz)],
[image_val, maxout_val, gz_val],
DownsampleFactorMaxGrad,
warn=False)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论