提交 10dcdb38 authored 作者: Frederic's avatar Frederic

Fix bad shape inference due to wrong broadcast in DownsampleFactorMax

fix gh-3452
上级 d53c3a9a
...@@ -256,7 +256,10 @@ class DownsampleFactorMax(Op): ...@@ -256,7 +256,10 @@ class DownsampleFactorMax(Op):
raise TypeError() raise TypeError()
# TODO: consider restricting the dtype? # TODO: consider restricting the dtype?
x = tensor.as_tensor_variable(x) x = tensor.as_tensor_variable(x)
return gof.Apply(self, [x], [x.type()]) # If the input shape are broadcastable we can have 0 in the output shape
broad = x.broadcastable[:2] + (False, False)
out = tensor.TensorType(x.dtype, broad)
return gof.Apply(self, [x], [out()])
def perform(self, node, inp, out): def perform(self, node, inp, out):
x, = inp x, = inp
......
...@@ -801,6 +801,16 @@ class TestDownsampleFactorMax(utt.InferShapeTester): ...@@ -801,6 +801,16 @@ class TestDownsampleFactorMax(utt.InferShapeTester):
[image_val, maxout_val, gz_val], [image_val, maxout_val, gz_val],
MaxPoolGrad, MaxPoolGrad,
warn=False) warn=False)
# checking with broadcastable input
image = tensor.tensor(dtype=theano.config.floatX,
broadcastable=(False, False, True, True))
image_val = rng.rand(4, 6, 1, 1)
self._compile_and_check(
[image],
[DownsampleFactorMax((2, 2),
ignore_border=True,
padding=(0, 0))(image)],
[image_val], DownsampleFactorMax)
def test_opt_max_to_average(self): def test_opt_max_to_average(self):
im = theano.tensor.tensor4() im = theano.tensor.tensor4()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论