提交 af8e9166 authored 作者: Ramana.S's avatar Ramana.S

Testing the output shape

上级 c54d3d47
...@@ -195,10 +195,19 @@ class TestDownsampleFactorMax(utt.InferShapeTester): ...@@ -195,10 +195,19 @@ class TestDownsampleFactorMax(utt.InferShapeTester):
maxpool_op = DownsampleFactorMax(maxpoolshp, maxpool_op = DownsampleFactorMax(maxpoolshp,
ignore_border=ignore_border, ignore_border=ignore_border,
mode=mode)(images) mode=mode)(images)
output_shape = DownsampleFactorMax.out_shape(imval, maxpoolshp,
ignore_border=ignore_border)
assert numpy.asarray(output_shape).shape == numpy_output_val.shape, (
"outshape is %s, calculated shape is %s"
% (numpy.asarray(output_shape).shape, numpy_output_val.shape))
f = function([images], maxpool_op) f = function([images], maxpool_op)
output_val = f(imval) output_val = f(imval)
utt.assert_allclose(output_val, numpy_output_val) utt.assert_allclose(output_val, numpy_output_val)
def test_DownsampleFactorMaxStride(self): def test_DownsampleFactorMaxStride(self):
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
maxpoolshps = ((1, 1), (3, 3), (5, 3)) maxpoolshps = ((1, 1), (3, 3), (5, 3))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论