Commit e90f78b9 authored by lamblin

Merge pull request #558 from delallea/argmax_broadcast_fix

Fixed bug with broadcastable flags of MaxAndArgmax
......@@ -69,6 +69,8 @@ Crash Fix
(Frédéric B., reported by Razvan P.)
* Fix crash under 64-bit Windows, when taking subtensors of the form a[n:]
(Pascal L., reported by Simon McGregor)
* Fixed issue with the MaxAndArgmax Op not properly preserving broadcastable
dimensions, which could typically result in optimization crashes (Olivier D.)
=============
Release Notes
......
......@@ -2187,14 +2187,19 @@ class MaxAndArgmax(Op):
# NOTE(review): this is the tail of MaxAndArgmax.make_node as it appears in a
# stripped diff hunk (the `def` line and the `+`/`-` diff markers are not
# visible here). One pre-fix line is interleaved with its replacement below.
# Convert the axis argument into a symbolic tensor variable.
axis = _as_tensor_variable(axis)
# Verify that the axis is valid.
all_axes = set()
# axis.data holds the concrete axis indices; each must address an existing
# dimension of the input x.
for ax in axis.data:
if ax < 0 or ax >= x.type.ndim:
raise ValueError(
'Invalid axis: %s (the number of dimensions of the '
'input is: %s)' % (axis, x.type.ndim))
all_axes.add(ax)
inputs = [x, axis]
# NOTE(review): the next line is the REMOVED (buggy) version from the diff:
# it forced every output dimension to be non-broadcastable, losing flags.
broadcastable = [False] * (x.type.ndim - len(axis.data))
# We keep the original broadcastable flags for dimensions on which
# we do not perform the max / argmax.
# (This reassignment is the ADDED fix; dimensions listed in all_axes are
# reduced away, the remaining ones keep x's broadcastable pattern.)
broadcastable = [b for i, b in enumerate(x.type.broadcastable)
if i not in all_axes]
# Two outputs sharing the same broadcastable pattern: the max values
# (same dtype as x) and the argmax indices (int64).
outputs = [tensor(x.type.dtype, broadcastable, name='max'),
tensor('int64', broadcastable, name='argmax')]
return Apply(self, inputs, outputs)
......
......@@ -1720,6 +1720,14 @@ class T_max_and_argmax(unittest.TestCase):
safe_verify_grad(lambda v: max_and_argmax(v, axis=[i])[0], [data])
safe_verify_grad(lambda v: max_and_argmax(v, axis=[i])[1], [data])
def test_preserve_broadcastable(self):
    """
    Verify that Max/Argmax keep the broadcastable flags of the
    dimensions they do not reduce over.
    """
    # Build a 5-d variable whose pattern is (True, False, True, False, True):
    # 'x' entries are broadcastable singleton dimensions.
    v = tensor.matrix().dimshuffle('x', 0, 'x', 1, 'x')
    # Reduce over axis 1 (a non-broadcastable dimension); the remaining
    # four dimensions must keep their original flags.
    reduced = v.max(axis=1)
    assert reduced.type.broadcastable == (True, True, False, True)
class T_argmin_argmax(unittest.TestCase):
def setUp(self):
......
Markdown is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment