提交 80a5c158 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Ricardo Vieira

Remove npy<2 compatibility for normalize_axis_{index,tuple}

上级 df999e20
...@@ -4,6 +4,7 @@ from textwrap import dedent, indent ...@@ -4,6 +4,7 @@ from textwrap import dedent, indent
import numba import numba
import numpy as np import numpy as np
from numba.core.extending import overload from numba.core.extending import overload
from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple
from numpy.lib.stride_tricks import as_strided from numpy.lib.stride_tricks import as_strided
from pytensor.graph.op import Op from pytensor.graph.op import Op
...@@ -19,7 +20,6 @@ from pytensor.link.numba.dispatch.vectorize_codegen import ( ...@@ -19,7 +20,6 @@ from pytensor.link.numba.dispatch.vectorize_codegen import (
store_core_outputs, store_core_outputs,
) )
from pytensor.link.utils import compile_function_src from pytensor.link.utils import compile_function_src
from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple
from pytensor.scalar.basic import ( from pytensor.scalar.basic import (
AND, AND,
OR, OR,
......
...@@ -3,15 +3,6 @@ from textwrap import dedent ...@@ -3,15 +3,6 @@ from textwrap import dedent
import numpy as np import numpy as np
# Conditional numpy imports for numpy 1.26 and 2.x compatibility
try:
from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple
except ModuleNotFoundError:
# numpy < 2.0
from numpy.core.multiarray import normalize_axis_index # type: ignore[no-redef]
from numpy.core.numeric import normalize_axis_tuple # type: ignore[no-redef]
try: try:
from numpy._core.einsumfunc import ( # type: ignore[attr-defined] from numpy._core.einsumfunc import ( # type: ignore[attr-defined]
_find_contraction, _find_contraction,
...@@ -28,8 +19,6 @@ except ModuleNotFoundError: ...@@ -28,8 +19,6 @@ except ModuleNotFoundError:
__all__ = [ __all__ = [
"_find_contraction", "_find_contraction",
"_parse_einsum_input", "_parse_einsum_input",
"normalize_axis_index",
"normalize_axis_tuple",
] ]
......
...@@ -15,6 +15,7 @@ from typing import cast as type_cast ...@@ -15,6 +15,7 @@ from typing import cast as type_cast
import numpy as np import numpy as np
from numpy.exceptions import AxisError from numpy.exceptions import AxisError
from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple
import pytensor import pytensor
import pytensor.scalar.sharedvar import pytensor.scalar.sharedvar
...@@ -31,7 +32,6 @@ from pytensor.graph.rewriting.db import EquilibriumDB ...@@ -31,7 +32,6 @@ from pytensor.graph.rewriting.db import EquilibriumDB
from pytensor.graph.type import HasShape, Type from pytensor.graph.type import HasShape, Type
from pytensor.link.c.op import COp from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType from pytensor.link.c.params_type import ParamsType
from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple
from pytensor.printing import Printer, min_informative_str, pprint, set_precedence from pytensor.printing import Printer, min_informative_str, pprint, set_precedence
from pytensor.raise_op import CheckAndRaise from pytensor.raise_op import CheckAndRaise
from pytensor.scalar import int32 from pytensor.scalar import int32
......
...@@ -83,10 +83,10 @@ import warnings ...@@ -83,10 +83,10 @@ import warnings
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
from numpy.lib.array_utils import normalize_axis_tuple
from scipy.linalg import get_blas_funcs from scipy.linalg import get_blas_funcs
from pytensor.graph import Variable, vectorize_graph from pytensor.graph import Variable, vectorize_graph
from pytensor.npy_2_compat import normalize_axis_tuple
try: try:
......
...@@ -6,14 +6,10 @@ from itertools import pairwise ...@@ -6,14 +6,10 @@ from itertools import pairwise
from typing import cast from typing import cast
import numpy as np import numpy as np
from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple
from pytensor.compile.builders import OpFromGraph from pytensor.compile.builders import OpFromGraph
from pytensor.npy_2_compat import ( from pytensor.npy_2_compat import _find_contraction, _parse_einsum_input
_find_contraction,
_parse_einsum_input,
normalize_axis_index,
normalize_axis_tuple,
)
from pytensor.tensor import TensorLike from pytensor.tensor import TensorLike
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
arange, arange,
......
...@@ -4,6 +4,7 @@ from textwrap import dedent ...@@ -4,6 +4,7 @@ from textwrap import dedent
from typing import Literal from typing import Literal
import numpy as np import numpy as np
from numpy.lib.array_utils import normalize_axis_tuple
import pytensor.tensor.basic import pytensor.tensor.basic
from pytensor.configdefaults import config from pytensor.configdefaults import config
...@@ -16,7 +17,6 @@ from pytensor.link.c.basic import failure_code ...@@ -16,7 +17,6 @@ from pytensor.link.c.basic import failure_code
from pytensor.link.c.op import COp, ExternalCOp, OpenMPOp from pytensor.link.c.op import COp, ExternalCOp, OpenMPOp
from pytensor.link.c.params_type import ParamsType from pytensor.link.c.params_type import ParamsType
from pytensor.misc.frozendict import frozendict from pytensor.misc.frozendict import frozendict
from pytensor.npy_2_compat import normalize_axis_tuple
from pytensor.printing import Printer, pprint from pytensor.printing import Printer, pprint
from pytensor.scalar import get_scalar_type from pytensor.scalar import get_scalar_type
from pytensor.scalar.basic import identity as scalar_identity from pytensor.scalar.basic import identity as scalar_identity
......
...@@ -2,6 +2,7 @@ import warnings ...@@ -2,6 +2,7 @@ import warnings
from collections.abc import Collection, Iterable from collections.abc import Collection, Iterable
import numpy as np import numpy as np
from numpy.lib.array_utils import normalize_axis_index
import pytensor import pytensor
import pytensor.scalar.basic as ps import pytensor.scalar.basic as ps
...@@ -18,7 +19,6 @@ from pytensor.link.c.op import COp ...@@ -18,7 +19,6 @@ from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType from pytensor.link.c.params_type import ParamsType
from pytensor.link.c.type import EnumList, Generic from pytensor.link.c.type import EnumList, Generic
from pytensor.npy_2_compat import ( from pytensor.npy_2_compat import (
normalize_axis_index,
npy_2_compat_header, npy_2_compat_header,
numpy_axis_is_none_flag, numpy_axis_is_none_flag,
old_np_unique, old_np_unique,
......
...@@ -5,6 +5,7 @@ from textwrap import dedent ...@@ -5,6 +5,7 @@ from textwrap import dedent
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
import numpy as np import numpy as np
from numpy.lib.array_utils import normalize_axis_tuple
from pytensor import config, printing from pytensor import config, printing
from pytensor import scalar as ps from pytensor import scalar as ps
...@@ -13,11 +14,7 @@ from pytensor.graph.op import Op ...@@ -13,11 +14,7 @@ from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node from pytensor.graph.replace import _vectorize_node
from pytensor.link.c.op import COp from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType from pytensor.link.c.params_type import ParamsType
from pytensor.npy_2_compat import ( from pytensor.npy_2_compat import npy_2_compat_header, numpy_axis_is_none_flag
normalize_axis_tuple,
npy_2_compat_header,
numpy_axis_is_none_flag,
)
from pytensor.printing import pprint from pytensor.printing import pprint
from pytensor.raise_op import Assert from pytensor.raise_op import Assert
from pytensor.scalar.basic import BinaryScalarOp from pytensor.scalar.basic import BinaryScalarOp
......
...@@ -4,13 +4,13 @@ from functools import partial ...@@ -4,13 +4,13 @@ from functools import partial
from typing import Literal, cast from typing import Literal, cast
import numpy as np import numpy as np
from numpy.lib.array_utils import normalize_axis_tuple
from pytensor import scalar as ps from pytensor import scalar as ps
from pytensor.compile.builders import OpFromGraph from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import DisconnectedType from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply from pytensor.graph.basic import Apply
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.npy_2_compat import normalize_axis_tuple
from pytensor.tensor import TensorLike from pytensor.tensor import TensorLike
from pytensor.tensor import basic as ptb from pytensor.tensor import basic as ptb
from pytensor.tensor import math as ptm from pytensor.tensor import math as ptm
......
...@@ -25,6 +25,7 @@ Many stabilize and stabilization rewrites refuse to be applied when a variable h ...@@ -25,6 +25,7 @@ Many stabilize and stabilization rewrites refuse to be applied when a variable h
import logging import logging
import numpy as np import numpy as np
from numpy.lib.array_utils import normalize_axis_index
from pytensor import compile, config from pytensor import compile, config
from pytensor.compile.ops import ViewOp from pytensor.compile.ops import ViewOp
...@@ -41,7 +42,6 @@ from pytensor.graph.rewriting.basic import ( ...@@ -41,7 +42,6 @@ from pytensor.graph.rewriting.basic import (
) )
from pytensor.graph.rewriting.db import RewriteDatabase from pytensor.graph.rewriting.db import RewriteDatabase
from pytensor.graph.rewriting.unify import OpPattern, OpPatternOpTypeType from pytensor.graph.rewriting.unify import OpPattern, OpPatternOpTypeType
from pytensor.npy_2_compat import normalize_axis_index
from pytensor.raise_op import Assert, CheckAndRaise, assert_op from pytensor.raise_op import Assert, CheckAndRaise, assert_op
from pytensor.scalar import ( from pytensor.scalar import (
AND, AND,
......
...@@ -2,12 +2,12 @@ from collections.abc import Iterable, Sequence ...@@ -2,12 +2,12 @@ from collections.abc import Iterable, Sequence
from typing import cast from typing import cast
import numpy as np import numpy as np
from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple
from pytensor import Variable from pytensor import Variable
from pytensor.compile import optdb from pytensor.compile import optdb
from pytensor.graph import Constant, FunctionGraph, node_rewriter, vectorize_graph from pytensor.graph import Constant, FunctionGraph, node_rewriter, vectorize_graph
from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace
from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple
from pytensor.scalar import basic as ps from pytensor.scalar import basic as ps
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
Alloc, Alloc,
......
...@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Union, cast ...@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Union, cast
from typing import cast as typing_cast from typing import cast as typing_cast
import numpy as np import numpy as np
from numpy.lib.array_utils import normalize_axis_tuple
import pytensor import pytensor
from pytensor.gradient import DisconnectedType from pytensor.gradient import DisconnectedType
...@@ -15,7 +16,6 @@ from pytensor.graph.replace import _vectorize_node ...@@ -15,7 +16,6 @@ from pytensor.graph.replace import _vectorize_node
from pytensor.graph.type import HasShape from pytensor.graph.type import HasShape
from pytensor.link.c.op import COp from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType from pytensor.link.c.params_type import ParamsType
from pytensor.npy_2_compat import normalize_axis_tuple
from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length
from pytensor.tensor import basic as ptb from pytensor.tensor import basic as ptb
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
......
...@@ -5,10 +5,10 @@ from typing import cast ...@@ -5,10 +5,10 @@ from typing import cast
import numpy as np import numpy as np
from numpy import nditer from numpy import nditer
from numpy.lib.array_utils import normalize_axis_tuple
import pytensor import pytensor
from pytensor.graph import FunctionGraph, Variable from pytensor.graph import FunctionGraph, Variable
from pytensor.npy_2_compat import normalize_axis_tuple
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.utils import hash_from_code from pytensor.utils import hash_from_code
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论