Unverified 提交 3efb27eb authored 作者: Jesse Grabowski's avatar Jesse Grabowski 提交者: GitHub

Fix local_subtensor_of_squeeze when squeezing multiple dimensions (#1819)

* local_subtensor_of_squeeze bugfix * Test graph correctness * Follow template for test
上级 44e20d4a
...@@ -488,7 +488,11 @@ def local_subtensor_of_squeeze(fgraph, node): ...@@ -488,7 +488,11 @@ def local_subtensor_of_squeeze(fgraph, node):
# Apply indices directly on x # Apply indices directly on x
# Add empty slices on the axis that squeeze would have removed # Add empty slices on the axis that squeeze would have removed
new_idxs = np.insert(np.array(idxs, dtype=object), dropped_dims, slice(None)) new_idxs = list(idxs)
for d in sorted(dropped_dims):
new_idxs.insert(d, slice(None))
new_idxs = np.array(new_idxs, dtype=object)
x_indexed = x_before_squeeze[tuple(new_idxs)] x_indexed = x_before_squeeze[tuple(new_idxs)]
# Reapply squeeze # Reapply squeeze
......
...@@ -53,6 +53,7 @@ from pytensor.tensor.rewriting.subtensor_lift import ( ...@@ -53,6 +53,7 @@ from pytensor.tensor.rewriting.subtensor_lift import (
from pytensor.tensor.shape import SpecifyShape, _shape from pytensor.tensor.shape import SpecifyShape, _shape
from pytensor.tensor.special import softmax from pytensor.tensor.special import softmax
from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor
from tests.unittest_tools import assert_equal_computations
mode_opt = config.mode mode_opt = config.mode
...@@ -824,3 +825,35 @@ def test_local_subtensor_of_adv_subtensor(original_fn, supported): ...@@ -824,3 +825,35 @@ def test_local_subtensor_of_adv_subtensor(original_fn, supported):
opt_out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE), opt_out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE), out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
) )
@pytest.mark.parametrize(
"original_fn, expected_fn, x_shape",
[
(
lambda x: x.squeeze(0)[0],
lambda x: x[:, 0].squeeze(0),
(1, 5, 2, 1),
),
# Regression test for https://github.com/pymc-devs/pytensor/issues/1818
# Squeeze multiple axes then index
(
lambda x: x.squeeze((0, 1, -2))[:, 0],
lambda x: x[:, :, :, :, 0].squeeze((0, 1, 3)),
(1, 1, 2, 1, 3),
),
],
)
def test_local_subtensor_of_squeeze(original_fn, expected_fn, x_shape):
rng = np.random.default_rng()
x = pt.tensor("x", shape=x_shape)
x_test = rng.normal(size=x.type.shape).astype(x.dtype)
out = original_fn(x)
expected_opt_out = expected_fn(x)
opt_out = rewrite_graph(out)
assert_equal_computations([opt_out], [expected_opt_out])
np.testing.assert_allclose(
opt_out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论