提交 ef256e52 authored 作者: Ricardo's avatar Ricardo 提交者: Ricardo Vieira

Handle negative axis in `squeeze` and do not ignore invalid axis

* Previous implementation would silently ignore explicit axis for dimensions that were not broadcastable. * Added note about difference in behavior between Aesara and Numpy functions Closes #830 Co-authored-by: 's avatarBrandon T. Willard <971601+brandonwillard@users.noreply.github.com>
上级 8f96d930
......@@ -603,7 +603,7 @@ def bincount(x, weights=None, minlength=None, assert_nonneg=False):
def squeeze(x, axis=None):
"""
Remove broadcastable dimensions from the shape of an array.
Remove broadcastable (length 1) dimensions from the shape of an array.
It returns the input array, but with the broadcastable dimensions
removed. This is always `x` itself or a view into `x`.
......@@ -615,24 +615,34 @@ def squeeze(x, axis=None):
x :
Input data, tensor variable.
axis : None or int or tuple of ints, optional
Selects a subset of the single-dimensional entries in the
shape. If an axis is selected with shape entry greater than
one, an error is raised.
Selects a subset of broadcastable dimensions to be removed.
If a non broadcastable dimension is selected, an error is raised.
If `axis` is ``None``, all broadcastable dimensions will be removed.
Notes
-----
The behavior can differ from that of NumPy in two ways:
1. If an axis is chosen for a dimension that is not known to be broadcastable
an error is raised, even if this dimension would be broadcastable when the
variable is evaluated.
2. Similarly, if `axis` is ``None``, only dimensions known to be broadcastable will be
removed, even if there are more dimensions that happen to be broadcastable when
the variable is evaluated.
Returns
-------
`x` without its broadcastable dimensions.
`x` without `axis` dimensions.
"""
if axis is None:
axis = range(x.ndim)
# By default exclude all broadcastable (length=1) axes
axis = (i for i in range(x.ndim) if x.broadcastable[i])
elif not isinstance(axis, Collection):
axis = (axis,)
view = x.dimshuffle(
[i for i in range(x.ndim) if not x.broadcastable[i] or i not in axis]
)
return view
axis = np.core.numeric.normalize_axis_tuple(axis, ndim=x.ndim)
return x.dimshuffle([i for i in range(x.ndim) if i not in axis])
def compress(condition, x, axis=None):
......
......@@ -435,6 +435,19 @@ class TestSqueeze(utt.InferShapeTester):
assert res.broadcastable == (False, False)
variable = TensorType(config.floatX, [True, False, True, False, True])()
res = squeeze(variable, axis=(0, -1))
assert res.broadcastable == (False, True, False)
def test_invalid_axis(self):
# Test that trying to squeeze a non broadcastable dimension raises error
variable = TensorType(config.floatX, [True, False])()
with pytest.raises(
ValueError, match="Cannot drop a non-broadcastable dimension"
):
squeeze(variable, axis=1)
class TestCompress(utt.InferShapeTester):
axis_list = [None, -1, 0, 0, 0, 1]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论