提交 80a5c158 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Ricardo Vieira

Remove npy<2 compatibility for normalize_axis_{index,tuple}

上级 df999e20
......@@ -4,6 +4,7 @@ from textwrap import dedent, indent
import numba
import numpy as np
from numba.core.extending import overload
from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple
from numpy.lib.stride_tricks import as_strided
from pytensor.graph.op import Op
......@@ -19,7 +20,6 @@ from pytensor.link.numba.dispatch.vectorize_codegen import (
store_core_outputs,
)
from pytensor.link.utils import compile_function_src
from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple
from pytensor.scalar.basic import (
AND,
OR,
......
......@@ -3,15 +3,6 @@ from textwrap import dedent
import numpy as np
# Conditional numpy imports for numpy 1.26 and 2.x compatibility
try:
from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple
except ModuleNotFoundError:
# numpy < 2.0
from numpy.core.multiarray import normalize_axis_index # type: ignore[no-redef]
from numpy.core.numeric import normalize_axis_tuple # type: ignore[no-redef]
try:
from numpy._core.einsumfunc import ( # type: ignore[attr-defined]
_find_contraction,
......@@ -28,8 +19,6 @@ except ModuleNotFoundError:
__all__ = [
"_find_contraction",
"_parse_einsum_input",
"normalize_axis_index",
"normalize_axis_tuple",
]
......
......@@ -15,6 +15,7 @@ from typing import cast as type_cast
import numpy as np
from numpy.exceptions import AxisError
from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple
import pytensor
import pytensor.scalar.sharedvar
......@@ -31,7 +32,6 @@ from pytensor.graph.rewriting.db import EquilibriumDB
from pytensor.graph.type import HasShape, Type
from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType
from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple
from pytensor.printing import Printer, min_informative_str, pprint, set_precedence
from pytensor.raise_op import CheckAndRaise
from pytensor.scalar import int32
......
......@@ -83,10 +83,10 @@ import warnings
from pathlib import Path
import numpy as np
from numpy.lib.array_utils import normalize_axis_tuple
from scipy.linalg import get_blas_funcs
from pytensor.graph import Variable, vectorize_graph
from pytensor.npy_2_compat import normalize_axis_tuple
try:
......
......@@ -6,14 +6,10 @@ from itertools import pairwise
from typing import cast
import numpy as np
from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple
from pytensor.compile.builders import OpFromGraph
from pytensor.npy_2_compat import (
_find_contraction,
_parse_einsum_input,
normalize_axis_index,
normalize_axis_tuple,
)
from pytensor.npy_2_compat import _find_contraction, _parse_einsum_input
from pytensor.tensor import TensorLike
from pytensor.tensor.basic import (
arange,
......
......@@ -4,6 +4,7 @@ from textwrap import dedent
from typing import Literal
import numpy as np
from numpy.lib.array_utils import normalize_axis_tuple
import pytensor.tensor.basic
from pytensor.configdefaults import config
......@@ -16,7 +17,6 @@ from pytensor.link.c.basic import failure_code
from pytensor.link.c.op import COp, ExternalCOp, OpenMPOp
from pytensor.link.c.params_type import ParamsType
from pytensor.misc.frozendict import frozendict
from pytensor.npy_2_compat import normalize_axis_tuple
from pytensor.printing import Printer, pprint
from pytensor.scalar import get_scalar_type
from pytensor.scalar.basic import identity as scalar_identity
......
......@@ -2,6 +2,7 @@ import warnings
from collections.abc import Collection, Iterable
import numpy as np
from numpy.lib.array_utils import normalize_axis_index
import pytensor
import pytensor.scalar.basic as ps
......@@ -18,7 +19,6 @@ from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType
from pytensor.link.c.type import EnumList, Generic
from pytensor.npy_2_compat import (
normalize_axis_index,
npy_2_compat_header,
numpy_axis_is_none_flag,
old_np_unique,
......
......@@ -5,6 +5,7 @@ from textwrap import dedent
from typing import TYPE_CHECKING, Optional
import numpy as np
from numpy.lib.array_utils import normalize_axis_tuple
from pytensor import config, printing
from pytensor import scalar as ps
......@@ -13,11 +14,7 @@ from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node
from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType
from pytensor.npy_2_compat import (
normalize_axis_tuple,
npy_2_compat_header,
numpy_axis_is_none_flag,
)
from pytensor.npy_2_compat import npy_2_compat_header, numpy_axis_is_none_flag
from pytensor.printing import pprint
from pytensor.raise_op import Assert
from pytensor.scalar.basic import BinaryScalarOp
......
......@@ -4,13 +4,13 @@ from functools import partial
from typing import Literal, cast
import numpy as np
from numpy.lib.array_utils import normalize_axis_tuple
from pytensor import scalar as ps
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.npy_2_compat import normalize_axis_tuple
from pytensor.tensor import TensorLike
from pytensor.tensor import basic as ptb
from pytensor.tensor import math as ptm
......
......@@ -25,6 +25,7 @@ Many stabilize and stabilization rewrites refuse to be applied when a variable h
import logging
import numpy as np
from numpy.lib.array_utils import normalize_axis_index
from pytensor import compile, config
from pytensor.compile.ops import ViewOp
......@@ -41,7 +42,6 @@ from pytensor.graph.rewriting.basic import (
)
from pytensor.graph.rewriting.db import RewriteDatabase
from pytensor.graph.rewriting.unify import OpPattern, OpPatternOpTypeType
from pytensor.npy_2_compat import normalize_axis_index
from pytensor.raise_op import Assert, CheckAndRaise, assert_op
from pytensor.scalar import (
AND,
......
......@@ -2,12 +2,12 @@ from collections.abc import Iterable, Sequence
from typing import cast
import numpy as np
from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple
from pytensor import Variable
from pytensor.compile import optdb
from pytensor.graph import Constant, FunctionGraph, node_rewriter, vectorize_graph
from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace
from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple
from pytensor.scalar import basic as ps
from pytensor.tensor.basic import (
Alloc,
......
......@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Union, cast
from typing import cast as typing_cast
import numpy as np
from numpy.lib.array_utils import normalize_axis_tuple
import pytensor
from pytensor.gradient import DisconnectedType
......@@ -15,7 +16,6 @@ from pytensor.graph.replace import _vectorize_node
from pytensor.graph.type import HasShape
from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType
from pytensor.npy_2_compat import normalize_axis_tuple
from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length
from pytensor.tensor import basic as ptb
from pytensor.tensor.exceptions import NotScalarConstantError
......
......@@ -5,10 +5,10 @@ from typing import cast
import numpy as np
from numpy import nditer
from numpy.lib.array_utils import normalize_axis_tuple
import pytensor
from pytensor.graph import FunctionGraph, Variable
from pytensor.npy_2_compat import normalize_axis_tuple
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.utils import hash_from_code
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论