Unverified 提交 4312d8c6 authored 作者: Copilot's avatar Copilot 提交者: GitHub

Fix pt.flip to handle negative axis correctly using normalize_axis_tuple (#1628)

* Initial plan * Fix pt.flip to handle negative axis correctly using normalize_axis_tuple Co-authored-by: 's avatarricardoV94 <28983449+ricardoV94@users.noreply.github.com> * Expand existing test_flip function to include negative axis tests Co-authored-by: 's avatarricardoV94 <28983449+ricardoV94@users.noreply.github.com> * Fix mypy error by using separate variable for normalized axis Co-authored-by: 's avatarricardoV94 <28983449+ricardoV94@users.noreply.github.com> * Fix ruff formatting issues Co-authored-by: 's avatarjessegrabowski <48652735+jessegrabowski@users.noreply.github.com> --------- Co-authored-by: 's avatarcopilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: 's avatarricardoV94 <28983449+ricardoV94@users.noreply.github.com> Co-authored-by: 's avatarjessegrabowski <48652735+jessegrabowski@users.noreply.github.com>
上级 b05acfdd
...@@ -18,7 +18,7 @@ from pytensor.graph.type import Type ...@@ -18,7 +18,7 @@ from pytensor.graph.type import Type
from pytensor.graph.utils import MethodNotDefined from pytensor.graph.utils import MethodNotDefined
from pytensor.link.c.op import COp from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType from pytensor.link.c.params_type import ParamsType
from pytensor.npy_2_compat import numpy_version, using_numpy_2 from pytensor.npy_2_compat import normalize_axis_tuple, numpy_version, using_numpy_2
from pytensor.printing import Printer, pprint, set_precedence from pytensor.printing import Printer, pprint, set_precedence
from pytensor.scalar.basic import ScalarConstant, ScalarVariable from pytensor.scalar.basic import ScalarConstant, ScalarVariable
from pytensor.tensor import ( from pytensor.tensor import (
...@@ -3369,11 +3369,12 @@ def flip( ...@@ -3369,11 +3369,12 @@ def flip(
if axis is None: if axis is None:
index = ((slice(None, None, -1)),) * arr.ndim index = ((slice(None, None, -1)),) * arr.ndim
else: else:
if isinstance(axis, int): normalized_axis = normalize_axis_tuple(axis, arr.ndim)
axis = (axis,)
index = tuple( index = tuple(
[ [
slice(None, None, -1) if i in axis else slice(None, None, None) slice(None, None, -1)
if i in normalized_axis
else slice(None, None, None)
for i in range(arr.ndim) for i in range(arr.ndim)
] ]
) )
...@@ -3382,9 +3383,9 @@ def flip( ...@@ -3382,9 +3383,9 @@ def flip(
__all__ = [ __all__ = [
"take",
"flip", "flip",
"slice_at_axis",
"inc_subtensor", "inc_subtensor",
"set_subtensor", "set_subtensor",
"slice_at_axis",
"take",
] ]
...@@ -3147,6 +3147,27 @@ def test_flip(size: tuple[int]): ...@@ -3147,6 +3147,27 @@ def test_flip(size: tuple[int]):
f = pytensor.function([x_pt], z, mode="FAST_COMPILE") f = pytensor.function([x_pt], z, mode="FAST_COMPILE")
np.testing.assert_allclose(expected, f(x), atol=ATOL, rtol=RTOL) np.testing.assert_allclose(expected, f(x), atol=ATOL, rtol=RTOL)
# Test single negative axis
for axis in range(-x.ndim, 0):
expected = np.flip(x, axis=axis)
z = flip(x_pt, axis=axis)
f = pytensor.function([x_pt], z, mode="FAST_COMPILE")
np.testing.assert_allclose(expected, f(x), atol=ATOL, rtol=RTOL)
# Test tuple with negative axes
if x.ndim > 1:
expected = np.flip(x, axis=(-1, -2))
z = flip(x_pt, axis=(-1, -2))
f = pytensor.function([x_pt], z, mode="FAST_COMPILE")
np.testing.assert_allclose(expected, f(x), atol=ATOL, rtol=RTOL)
# Test mixed positive and negative axes
if x.ndim >= 2:
expected = np.flip(x, axis=(0, -1))
z = flip(x_pt, axis=(0, -1))
f = pytensor.function([x_pt], z, mode="FAST_COMPILE")
np.testing.assert_allclose(expected, f(x), atol=ATOL, rtol=RTOL)
class TestBenchmarks: class TestBenchmarks:
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论