提交 e48ff560 authored 作者: Dhruvanshu-Joshi's avatar Dhruvanshu-Joshi 提交者: Ricardo Vieira

Support squeezing of unit dimension broadcastable axis

上级 044910bf
...@@ -30,6 +30,7 @@ from pytensor.tensor.math import eq as pt_eq ...@@ -30,6 +30,7 @@ from pytensor.tensor.math import eq as pt_eq
from pytensor.tensor.math import ge, lt, maximum, minimum, prod, switch from pytensor.tensor.math import ge, lt, maximum, minimum, prod, switch
from pytensor.tensor.math import max as pt_max from pytensor.tensor.math import max as pt_max
from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.shape import specify_broadcastable
from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
from pytensor.tensor.variable import TensorVariable from pytensor.tensor.variable import TensorVariable
...@@ -592,6 +593,15 @@ def squeeze(x, axis=None): ...@@ -592,6 +593,15 @@ def squeeze(x, axis=None):
# Nothing to do # Nothing to do
return _x return _x
if _x.ndim == 0:
# Nothing could be squeezed
return _x
# `Dimshuffle` raises when we try to drop an axis that is not statically broadcastable.
# We add a `specify_broadcastable` instead of raising.
non_broadcastable_axis = [i for i in axis if not _x.broadcastable[i]]
_x = specify_broadcastable(_x, *non_broadcastable_axis)
return _x.dimshuffle([i for i in range(_x.ndim) if i not in axis]) return _x.dimshuffle([i for i in range(_x.ndim) if i not in axis])
......
...@@ -463,14 +463,6 @@ class TestSqueeze(utt.InferShapeTester): ...@@ -463,14 +463,6 @@ class TestSqueeze(utt.InferShapeTester):
assert res.broadcastable == (False, True, False) assert res.broadcastable == (False, True, False)
def test_invalid_axis(self):
# Test that trying to squeeze a non broadcastable dimension raises error
variable = TensorType(config.floatX, shape=(1, None))()
with pytest.raises(
ValueError, match="Cannot drop a non-broadcastable dimension"
):
squeeze(variable, axis=1)
def test_scalar_input(self): def test_scalar_input(self):
x = pt.scalar("x") x = pt.scalar("x")
...@@ -482,6 +474,25 @@ class TestSqueeze(utt.InferShapeTester): ...@@ -482,6 +474,25 @@ class TestSqueeze(utt.InferShapeTester):
): ):
squeeze(x, axis=1) squeeze(x, axis=1)
def test_invalid_input(self):
x = pt.vector("x")
axis = 0
f = pytensor.function([x], pt.squeeze(x, axis))
# Test that we allow squeezing of valid non-broadcastable dimension
assert f([0]) == 0
# Test that we cannot squeeze dimensions whose length is greater than 1
error_txt_1 = re.escape("SpecifyShape: Got shape (3,), expected (1,).")
error_txt_2 = re.escape("SpecifyShape: dim 0 of input has shape 3, expected 1")
match = error_txt_1 if pytensor.config.mode == "FAST_COMPILE" else error_txt_2
with pytest.raises(
AssertionError,
match=match,
):
f([0, 1, 2])
class TestCompress(utt.InferShapeTester): class TestCompress(utt.InferShapeTester):
def setup_method(self): def setup_method(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论