提交 b77015eb authored 作者: Frederic's avatar Frederic

Made DownsampleFactorMax.out_shape work correctly with tensor variable.

上级 1de55252
......@@ -65,7 +65,8 @@ class DownsampleFactorMax(Op):
:param imgshape: the shape of a tensor of images. The last two elements are interpreted
as the number of rows, and the number of cols.
:type imgshape: tuple, list, or similar.
:type imgshape: tuple, list, or similar of integer or
scalar Theano variable.
:param ds: downsample factor over rows and columns
:type ds: list or tuple of two ints
......@@ -83,10 +84,15 @@ class DownsampleFactorMax(Op):
raise TypeError('imgshape must have at least two elements (rows, cols)')
r, c = imgshape[-2:]
rval = list(imgshape[:-2])+[ r/ds[0], c/ds[1]]
if not ignore_border:
if r % ds[0]:
if isinstance(r, theano.Variable):
rval[-2] = tensor.switch(r % ds[0], rval[-2] + 1, rval[-2])
elif r % ds[0]:
rval[-2] += 1
if c % ds[1]:
if isinstance(c, theano.Variable):
rval[-1] = tensor.switch(c % ds[1], rval[-1] + 1, rval[-1])
elif c % ds[1]:
rval[-1] += 1
return rval
......
......@@ -163,13 +163,12 @@ class TestDownsampleFactorMax(utt.InferShapeTester):
images = tensor.dtensor4()
rng = numpy.random.RandomState(utt.fetch_seed())
maxpoolshps = ((1, 1), (2, 2), (3, 3), (2, 3))
# maxpoolshps = ((1, 1), (2, 3), (3, 2))
maxpoolshps = ((1, 1), (2, 2), (3, 3), (2, 3), (3, 2))
imval = rng.rand(4, 10, 64, 64)
# imval = rng.rand(2, 3, 3, 4)
for maxpoolshp in maxpoolshps:
for ignore_border in [True]:
# for ignore_border in [True, False]:
for ignore_border in [True, False]:
print maxpoolshp, ignore_border
self._compile_and_check([images],
[DownsampleFactorMax(maxpoolshp,
ignore_border=ignore_border)(images)],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论