提交 34d96121 authored 作者: Sina Honari's avatar Sina Honari

changing test_infer_shape in TestDownsampleFactorMax to include padding

上级 c6d2ab39
......@@ -622,31 +622,45 @@ class TestDownsampleFactorMax(utt.InferShapeTester):
maxpoolshps = ((1, 1), (2, 2), (3, 3), (2, 3), (3, 2))
image_val = rng.rand(4, 6, 7, 9)
out_shapes = [[[4, 6, 7, 9], [4, 6, 7, 9]],
[[4, 6, 3, 4], [4, 6, 4, 5]],
[[4, 6, 2, 3], [4, 6, 3, 3]],
[[4, 6, 3, 3], [4, 6, 4, 3]],
[[4, 6, 2, 4], [4, 6, 3, 5]]]
out_shapes = [[[[4, 6, 7, 9], [4, 6, 7, 9]],
[[4, 6, 3, 4], [4, 6, 4, 5]],
[[4, 6, 2, 3], [4, 6, 3, 3]],
[[4, 6, 3, 3], [4, 6, 4, 3]],
[[4, 6, 2, 4], [4, 6, 3, 5]]],
[[None, None],
[[4, 6, 4, 5], None],
[[4, 6, 3, 3], None],
[[4, 6, 4, 3], None],
[[4, 6, 3, 5], None]],
[[None, None],
[None, None],
[[4, 6, 3, 4], None],
[[4, 6, 4, 4], None],
[None, None]]]
for i, maxpoolshp in enumerate(maxpoolshps):
for j, ignore_border in enumerate([True, False]):
# checking shapes generated by DownsampleFactorMax
self._compile_and_check([image],
[DownsampleFactorMax(maxpoolshp,
ignore_border=ignore_border)(image)],
[image_val], DownsampleFactorMax)
# checking shapes generated by DownsampleFactorMaxGrad
maxout_val = rng.rand(*out_shapes[i][j])
gz_val = rng.rand(*out_shapes[i][j])
self._compile_and_check([image, maxout, gz],
[DownsampleFactorMaxGrad(maxpoolshp,
ignore_border=ignore_border)
(image, maxout, gz)],
[image_val, maxout_val, gz_val],
DownsampleFactorMaxGrad,
warn=False)
for k, padding in enumerate([(0,0), (1,1), (1,2)]):
if out_shapes[k][i][j] == None:
continue
# checking shapes generated by DownsampleFactorMax
self._compile_and_check([image],
[DownsampleFactorMax(maxpoolshp,
ignore_border=ignore_border,
padding=padding)(image)],
[image_val], DownsampleFactorMax)
# checking shapes generated by DownsampleFactorMaxGrad
maxout_val = rng.rand(*out_shapes[k][i][j])
gz_val = rng.rand(*out_shapes[k][i][j])
self._compile_and_check([image, maxout, gz],
[DownsampleFactorMaxGrad(maxpoolshp,
ignore_border=ignore_border,
padding=padding)
(image, maxout, gz)],
[image_val, maxout_val, gz_val],
DownsampleFactorMaxGrad,
warn=False)
if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论