提交 b77015eb authored 作者: Frederic's avatar Frederic

Made DownsampleFactorMax.out_shape work correctly with tensor variable.

上级 1de55252
...@@ -65,7 +65,8 @@ class DownsampleFactorMax(Op): ...@@ -65,7 +65,8 @@ class DownsampleFactorMax(Op):
:param imgshape: the shape of a tensor of images. The last two elements are interpreted :param imgshape: the shape of a tensor of images. The last two elements are interpreted
as the number of rows, and the number of cols. as the number of rows, and the number of cols.
:type imgshape: tuple, list, or similar. :type imgshape: tuple, list, or similar of integer or
scalar Theano variable.
:param ds: downsample factor over rows and columns :param ds: downsample factor over rows and columns
:type ds: list or tuple of two ints :type ds: list or tuple of two ints
...@@ -83,10 +84,15 @@ class DownsampleFactorMax(Op): ...@@ -83,10 +84,15 @@ class DownsampleFactorMax(Op):
raise TypeError('imgshape must have at least two elements (rows, cols)') raise TypeError('imgshape must have at least two elements (rows, cols)')
r, c = imgshape[-2:] r, c = imgshape[-2:]
rval = list(imgshape[:-2])+[ r/ds[0], c/ds[1]] rval = list(imgshape[:-2])+[ r/ds[0], c/ds[1]]
if not ignore_border: if not ignore_border:
if r % ds[0]: if isinstance(r, theano.Variable):
rval[-2] = tensor.switch(r % ds[0], rval[-2] + 1, rval[-2])
elif r % ds[0]:
rval[-2] += 1 rval[-2] += 1
if c % ds[1]: if isinstance(c, theano.Variable):
rval[-1] = tensor.switch(c % ds[1], rval[-1] + 1, rval[-1])
elif c % ds[1]:
rval[-1] += 1 rval[-1] += 1
return rval return rval
......
...@@ -163,13 +163,12 @@ class TestDownsampleFactorMax(utt.InferShapeTester): ...@@ -163,13 +163,12 @@ class TestDownsampleFactorMax(utt.InferShapeTester):
images = tensor.dtensor4() images = tensor.dtensor4()
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
maxpoolshps = ((1, 1), (2, 2), (3, 3), (2, 3)) maxpoolshps = ((1, 1), (2, 2), (3, 3), (2, 3), (3, 2))
# maxpoolshps = ((1, 1), (2, 3), (3, 2))
imval = rng.rand(4, 10, 64, 64) imval = rng.rand(4, 10, 64, 64)
# imval = rng.rand(2, 3, 3, 4) # imval = rng.rand(2, 3, 3, 4)
for maxpoolshp in maxpoolshps: for maxpoolshp in maxpoolshps:
for ignore_border in [True]: for ignore_border in [True, False]:
# for ignore_border in [True, False]: print maxpoolshp, ignore_border
self._compile_and_check([images], self._compile_and_check([images],
[DownsampleFactorMax(maxpoolshp, [DownsampleFactorMax(maxpoolshp,
ignore_border=ignore_border)(images)], ignore_border=ignore_border)(images)],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论