提交 a6a9f6fb authored 作者: Eric Larsen's avatar Eric Larsen

testing infer_shape: op DownSampleFactorMaxGrad

上级 b77015eb
......@@ -2,7 +2,8 @@ import unittest, sys, time
import numpy
import theano.tensor as tensor
from theano.tests import unittest_tools as utt
from theano.tensor.signal.downsample import DownsampleFactorMax, max_pool_2d
from theano.tensor.signal.downsample import (DownsampleFactorMax, max_pool_2d,
DownsampleFactorMaxGrad)
from theano import function, Mode
......@@ -158,21 +159,51 @@ class TestDownsampleFactorMax(utt.InferShapeTester):
def test_infer_shape(self):
## TODO: maxpoolshp != (1, 1) fails with ignore_border == False
# see function out_shape in class DownsampleFactorMax
images = tensor.dtensor4()
image = tensor.dtensor4()
maxout = tensor.dtensor4()
gz = tensor.dtensor4()
rng = numpy.random.RandomState(utt.fetch_seed())
maxpoolshps = ((1, 1), (2, 2), (3, 3), (2, 3), (3, 2))
imval = rng.rand(4, 10, 64, 64)
# imval = rng.rand(2, 3, 3, 4)
for maxpoolshp in maxpoolshps:
for ignore_border in [True, False]:
print maxpoolshp, ignore_border
self._compile_and_check([images],
image_val = rng.rand(2, 3, 3, 4)
out_shapes = [[[2, 3, 3, 4], [2, 3, 3, 4]],
[[2, 3, 1, 2], [2, 3, 2, 2]],
[[2, 3, 1, 1], [2, 3, 1, 2]],
[[2, 3, 1, 1], [2, 3, 2, 2]],
[[2, 3, 1, 2], [2, 3, 1, 2]]]
"""
image_val = rng.rand(4, 6, 7, 9)
out_shapes = [[[4, 6, 7, 9], [4, 6, 7, 9]],
[[4, 6, 3, 4], [4, 6, 4, 5]],
[[4, 6, 2, 3], [4, 6, 3, 3]],
[[4, 6, 3, 3], [4, 6, 4, 3]],
[[4, 6, 2, 4], [4, 6, 3, 5]]]
image_val = rng.rand(4, 10, 64, 64)
out_shapes = [[[4, 10, 64, 64], [4, 10, 64, 64]],
[[4, 10, 32, 32], [4, 10, 32, 32]],
[[4, 10, 21, 21], [4, 10, 22, 22]],
[[4, 10, 32, 21], [4, 10, 32, 22]],
[[4, 10, 21, 32], [4, 10, 22, 32]]]
"""
for i, maxpoolshp in enumerate(maxpoolshps):
for j, ignore_border in enumerate([True, False]):
# checking shapes generated by DownsampleFactorMax
self._compile_and_check([image],
[DownsampleFactorMax(maxpoolshp,
ignore_border=ignore_border)(images)],
[imval], DownsampleFactorMax)
ignore_border=ignore_border)(image)],
[image_val], DownsampleFactorMax)
# checking shapes generated by DownsampleFactorMaxGrad
maxout_val = rng.rand(*out_shapes[i][j])
gz_val = rng.rand(*out_shapes[i][j])
self._compile_and_check([image, maxout, gz],
[DownsampleFactorMaxGrad(maxpoolshp,
ignore_border=ignore_border)(image, maxout, gz)],
[image_val, maxout_val, gz_val],
DownsampleFactorMaxGrad)
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论