提交 e036caf9 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Update numpy deprecated imports

- replaced np.AxisError with np.exceptions.AxisError - the `numpy.core` submodule has been renamed to `numpy._core` - some parts of `numpy.core` have been moved to `numpy.lib.array_utils` Except for `AxisError`, the updated imports are conditional on the version of numpy, so the imports should work for numpy >= 1.26. The conditional imports have been added to `npy_2_compat.py`, so the imports elsewhere are unconditional.
上级 bbe663d9
...@@ -10,8 +10,6 @@ from copy import copy ...@@ -10,8 +10,6 @@ from copy import copy
from io import StringIO from io import StringIO
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
import numpy as np
from pytensor.compile.compilelock import lock_ctx from pytensor.compile.compilelock import lock_ctx
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.basic import ( from pytensor.graph.basic import (
...@@ -33,6 +31,7 @@ from pytensor.link.c.cmodule import ( ...@@ -33,6 +31,7 @@ from pytensor.link.c.cmodule import (
from pytensor.link.c.cmodule import get_module_cache as _get_module_cache from pytensor.link.c.cmodule import get_module_cache as _get_module_cache
from pytensor.link.c.interface import CLinkerObject, CLinkerOp, CLinkerType from pytensor.link.c.interface import CLinkerObject, CLinkerOp, CLinkerType
from pytensor.link.utils import gc_helper, map_storage, raise_with_op, streamline from pytensor.link.utils import gc_helper, map_storage, raise_with_op, streamline
from pytensor.npy_2_compat import ndarray_c_version
from pytensor.utils import difference, uniq from pytensor.utils import difference, uniq
...@@ -1367,10 +1366,6 @@ class CLinker(Linker): ...@@ -1367,10 +1366,6 @@ class CLinker(Linker):
# We must always add the numpy ABI version here as # We must always add the numpy ABI version here as
# DynamicModule always add the include <numpy/arrayobject.h> # DynamicModule always add the include <numpy/arrayobject.h>
if np.lib.NumpyVersion(np.__version__) < "1.16.0a":
ndarray_c_version = np.core.multiarray._get_ndarray_c_version()
else:
ndarray_c_version = np.core._multiarray_umath._get_ndarray_c_version()
sig.append(f"NPY_ABI_VERSION=0x{ndarray_c_version:X}") sig.append(f"NPY_ABI_VERSION=0x{ndarray_c_version:X}")
if c_compiler: if c_compiler:
sig.append("c_compiler_str=" + c_compiler.version_str()) sig.append("c_compiler_str=" + c_compiler.version_str())
......
...@@ -4,7 +4,6 @@ from textwrap import dedent, indent ...@@ -4,7 +4,6 @@ from textwrap import dedent, indent
import numba import numba
import numpy as np import numpy as np
from numba.core.extending import overload from numba.core.extending import overload
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch import basic as numba_basic
...@@ -19,6 +18,7 @@ from pytensor.link.numba.dispatch.vectorize_codegen import ( ...@@ -19,6 +18,7 @@ from pytensor.link.numba.dispatch.vectorize_codegen import (
store_core_outputs, store_core_outputs,
) )
from pytensor.link.utils import compile_function_src from pytensor.link.utils import compile_function_src
from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple
from pytensor.scalar.basic import ( from pytensor.scalar.basic import (
AND, AND,
OR, OR,
......
from textwrap import dedent

import numpy as np


# `normalize_axis_index`/`normalize_axis_tuple` live in `numpy.lib.array_utils`
# on numpy >= 2.0; fall back to the pre-2.0 locations when that submodule does
# not exist (numpy 1.26).
try:
    from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple
except ModuleNotFoundError:
    # numpy < 2.0
    from numpy.core.multiarray import normalize_axis_index  # type: ignore[no-redef]
    from numpy.core.numeric import normalize_axis_tuple  # type: ignore[no-redef]

# `numpy.core` was renamed to `numpy._core` in numpy 2.0; try the new name
# first and fall back for numpy 1.26.
try:
    from numpy._core.einsumfunc import (  # type: ignore[attr-defined]
        _find_contraction,
        _parse_einsum_input,
    )
except ModuleNotFoundError:
    from numpy.core.einsumfunc import (  # type: ignore[no-redef]
        _find_contraction,
        _parse_einsum_input,
    )

# Declare the re-exported names so linters treat the conditional imports as used.
__all__ = [
    "_find_contraction",
    "_parse_einsum_input",
    "normalize_axis_index",
    "normalize_axis_tuple",
]

# (major, minor) of the installed numpy, e.g. (2, 1).
numpy_version_tuple = tuple(int(part) for part in np.__version__.split(".")[:2])

# Rich version object that supports comparison against version strings,
# e.g. ``numpy_version < "1.16.0"``.
numpy_version = np.lib.NumpyVersion(np.__version__)

using_numpy_2 = numpy_version >= "2.0.0rc1"

# The private helper exposing the ndarray C-ABI version moved along with the
# `numpy.core` -> `numpy._core` rename.
if using_numpy_2:
    ndarray_c_version = np._core._multiarray_umath._get_ndarray_c_version()
else:
    ndarray_c_version = np.core._multiarray_umath._get_ndarray_c_version()  # type: ignore[attr-defined]

# Exception raised when a negative value is converted to an unsigned integer
# dtype: numpy 2.x raises OverflowError where numpy 1.x raised TypeError
# (see the numpy 2.0 release notes).
UintOverflowError = OverflowError if using_numpy_2 else TypeError
def npy_2_compat_header() -> str:
    """Return the C source of numpy's ``npy_2_compat.h`` compatibility header.

    Numpy suggests this header be vendored by projects whose generated C code
    must compile against both numpy 1.x and numpy 2.x; pytensor injects it
    into the C code it generates.

    Returns
    -------
    str
        The full header text, ready to be written to a ``.h`` file or
        prepended to generated C source.
    """
    # NOTE(review): the string below appears to be vendored verbatim from
    # numpy's own `npy_2_compat.h` (typos included) — do not edit by hand;
    # refresh it from upstream numpy instead.
    return dedent("""
    #ifndef NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPAT_H_
    #define NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPAT_H_
    /*
    * This header is meant to be included by downstream directly for 1.x compat.
    * In that case we need to ensure that users first included the full headers
    * and not just `ndarraytypes.h`.
    */
    #ifndef NPY_FEATURE_VERSION
    #error "The NumPy 2 compat header requires `import_array()` for which " \\
    "the `ndarraytypes.h` header include is not sufficient. Please " \\
    "include it after `numpy/ndarrayobject.h` or similar." \\
    "" \\
    "To simplify inclusion, you may use `PyArray_ImportNumPy()` " \\
    "which is defined in the compat header and is lightweight (can be)."
    #endif
    #if NPY_ABI_VERSION < 0x02000000
    /*
    * Define 2.0 feature version as it is needed below to decide whether we
    * compile for both 1.x and 2.x (defining it gaurantees 1.x only).
    */
    #define NPY_2_0_API_VERSION 0x00000012
    /*
    * If we are compiling with NumPy 1.x, PyArray_RUNTIME_VERSION so we
    * pretend the `PyArray_RUNTIME_VERSION` is `NPY_FEATURE_VERSION`.
    * This allows downstream to use `PyArray_RUNTIME_VERSION` if they need to.
    */
    #define PyArray_RUNTIME_VERSION NPY_FEATURE_VERSION
    /* Compiling on NumPy 1.x where these are the same: */
    #define PyArray_DescrProto PyArray_Descr
    #endif
    /*
    * Define a better way to call `_import_array()` to simplify backporting as
    * we now require imports more often (necessary to make ABI flexible).
    */
    #ifdef import_array1
    static inline int
    PyArray_ImportNumPyAPI()
    {
    if (NPY_UNLIKELY(PyArray_API == NULL)) {
    import_array1(-1);
    }
    return 0;
    }
    #endif /* import_array1 */
    /*
    * NPY_DEFAULT_INT
    *
    * The default integer has changed, `NPY_DEFAULT_INT` is available at runtime
    * for use as type number, e.g. `PyArray_DescrFromType(NPY_DEFAULT_INT)`.
    *
    * NPY_RAVEL_AXIS
    *
    * This was introduced in NumPy 2.0 to allow indicating that an axis should be
    * raveled in an operation. Before NumPy 2.0, NPY_MAXDIMS was used for this purpose.
    *
    * NPY_MAXDIMS
    *
    * A constant indicating the maximum number dimensions allowed when creating
    * an ndarray.
    *
    * NPY_NTYPES_LEGACY
    *
    * The number of built-in NumPy dtypes.
    */
    #if NPY_FEATURE_VERSION >= NPY_2_0_API_VERSION
    #define NPY_DEFAULT_INT NPY_INTP
    #define NPY_RAVEL_AXIS NPY_MIN_INT
    #define NPY_MAXARGS 64
    #elif NPY_ABI_VERSION < 0x02000000
    #define NPY_DEFAULT_INT NPY_LONG
    #define NPY_RAVEL_AXIS 32
    #define NPY_MAXARGS 32
    /* Aliases of 2.x names to 1.x only equivalent names */
    #define NPY_NTYPES NPY_NTYPES_LEGACY
    #define PyArray_DescrProto PyArray_Descr
    #define _PyArray_LegacyDescr PyArray_Descr
    /* NumPy 2 definition always works, but add it for 1.x only */
    #define PyDataType_ISLEGACY(dtype) (1)
    #else
    #define NPY_DEFAULT_INT \\
    (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION ? NPY_INTP : NPY_LONG)
    #define NPY_RAVEL_AXIS \\
    (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION ? -1 : 32)
    #define NPY_MAXARGS \\
    (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION ? 64 : 32)
    #endif
    /*
    * Access inline functions for descriptor fields. Except for the first
    * few fields, these needed to be moved (elsize, alignment) for
    * additional space. Or they are descriptor specific and are not generally
    * available anymore (metadata, c_metadata, subarray, names, fields).
    *
    * Most of these are defined via the `DESCR_ACCESSOR` macro helper.
    */
    #if NPY_FEATURE_VERSION >= NPY_2_0_API_VERSION || NPY_ABI_VERSION < 0x02000000
    /* Compiling for 1.x or 2.x only, direct field access is OK: */
    static inline void
    PyDataType_SET_ELSIZE(PyArray_Descr *dtype, npy_intp size)
    {
    dtype->elsize = size;
    }
    static inline npy_uint64
    PyDataType_FLAGS(const PyArray_Descr *dtype)
    {
    #if NPY_FEATURE_VERSION >= NPY_2_0_API_VERSION
    return dtype->flags;
    #else
    return (unsigned char)dtype->flags; /* Need unsigned cast on 1.x */
    #endif
    }
    #define DESCR_ACCESSOR(FIELD, field, type, legacy_only) \\
    static inline type \\
    PyDataType_##FIELD(const PyArray_Descr *dtype) { \\
    if (legacy_only && !PyDataType_ISLEGACY(dtype)) { \\
    return (type)0; \\
    } \\
    return ((_PyArray_LegacyDescr *)dtype)->field; \\
    }
    #else /* compiling for both 1.x and 2.x */
    static inline void
    PyDataType_SET_ELSIZE(PyArray_Descr *dtype, npy_intp size)
    {
    if (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION) {
    ((_PyArray_DescrNumPy2 *)dtype)->elsize = size;
    }
    else {
    ((PyArray_DescrProto *)dtype)->elsize = (int)size;
    }
    }
    static inline npy_uint64
    PyDataType_FLAGS(const PyArray_Descr *dtype)
    {
    if (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION) {
    return ((_PyArray_DescrNumPy2 *)dtype)->flags;
    }
    else {
    return (unsigned char)((PyArray_DescrProto *)dtype)->flags;
    }
    }
    /* Cast to LegacyDescr always fine but needed when `legacy_only` */
    #define DESCR_ACCESSOR(FIELD, field, type, legacy_only) \\
    static inline type \\
    PyDataType_##FIELD(const PyArray_Descr *dtype) { \\
    if (legacy_only && !PyDataType_ISLEGACY(dtype)) { \\
    return (type)0; \\
    } \\
    if (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION) { \\
    return ((_PyArray_LegacyDescr *)dtype)->field; \\
    } \\
    else { \\
    return ((PyArray_DescrProto *)dtype)->field; \\
    } \\
    }
    #endif
    DESCR_ACCESSOR(ELSIZE, elsize, npy_intp, 0)
    DESCR_ACCESSOR(ALIGNMENT, alignment, npy_intp, 0)
    DESCR_ACCESSOR(METADATA, metadata, PyObject *, 1)
    DESCR_ACCESSOR(SUBARRAY, subarray, PyArray_ArrayDescr *, 1)
    DESCR_ACCESSOR(NAMES, names, PyObject *, 1)
    DESCR_ACCESSOR(FIELDS, fields, PyObject *, 1)
    DESCR_ACCESSOR(C_METADATA, c_metadata, NpyAuxData *, 1)
    #undef DESCR_ACCESSOR
    #if !(defined(NPY_INTERNAL_BUILD) && NPY_INTERNAL_BUILD)
    #if NPY_FEATURE_VERSION >= NPY_2_0_API_VERSION
    static inline PyArray_ArrFuncs *
    PyDataType_GetArrFuncs(const PyArray_Descr *descr)
    {
    return _PyDataType_GetArrFuncs(descr);
    }
    #elif NPY_ABI_VERSION < 0x02000000
    static inline PyArray_ArrFuncs *
    PyDataType_GetArrFuncs(const PyArray_Descr *descr)
    {
    return descr->f;
    }
    #else
    static inline PyArray_ArrFuncs *
    PyDataType_GetArrFuncs(const PyArray_Descr *descr)
    {
    if (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION) {
    return _PyDataType_GetArrFuncs(descr);
    }
    else {
    return ((PyArray_DescrProto *)descr)->f;
    }
    }
    #endif
    #endif /* not internal build */
    #endif /* NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPAT_H_ */
    """)
...@@ -123,7 +123,7 @@ from pytensor.tensor import slinalg ...@@ -123,7 +123,7 @@ from pytensor.tensor import slinalg
# isort: on # isort: on
# Allow accessing numpy constants from pytensor.tensor # Allow accessing numpy constants from pytensor.tensor
from numpy import e, euler_gamma, inf, infty, nan, newaxis, pi from numpy import e, euler_gamma, inf, nan, newaxis, pi
from pytensor.tensor.basic import * from pytensor.tensor.basic import *
from pytensor.tensor.blas import batched_dot, batched_tensordot from pytensor.tensor.blas import batched_dot, batched_tensordot
......
...@@ -14,8 +14,7 @@ from typing import TYPE_CHECKING, Union ...@@ -14,8 +14,7 @@ from typing import TYPE_CHECKING, Union
from typing import cast as type_cast from typing import cast as type_cast
import numpy as np import numpy as np
from numpy.core.multiarray import normalize_axis_index from numpy.exceptions import AxisError
from numpy.core.numeric import normalize_axis_tuple
import pytensor import pytensor
import pytensor.scalar.sharedvar import pytensor.scalar.sharedvar
...@@ -32,6 +31,7 @@ from pytensor.graph.rewriting.db import EquilibriumDB ...@@ -32,6 +31,7 @@ from pytensor.graph.rewriting.db import EquilibriumDB
from pytensor.graph.type import HasShape, Type from pytensor.graph.type import HasShape, Type
from pytensor.link.c.op import COp from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType from pytensor.link.c.params_type import ParamsType
from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple
from pytensor.printing import Printer, min_informative_str, pprint, set_precedence from pytensor.printing import Printer, min_informative_str, pprint, set_precedence
from pytensor.raise_op import CheckAndRaise from pytensor.raise_op import CheckAndRaise
from pytensor.scalar import int32 from pytensor.scalar import int32
...@@ -228,7 +228,7 @@ def constant(x, name=None, ndim=None, dtype=None) -> TensorConstant: ...@@ -228,7 +228,7 @@ def constant(x, name=None, ndim=None, dtype=None) -> TensorConstant:
elif x_.ndim > ndim: elif x_.ndim > ndim:
try: try:
x_ = np.squeeze(x_, axis=tuple(range(x_.ndim - ndim))) x_ = np.squeeze(x_, axis=tuple(range(x_.ndim - ndim)))
except np.AxisError: except AxisError:
raise ValueError( raise ValueError(
f"ndarray could not be cast to constant with {int(ndim)} dimensions" f"ndarray could not be cast to constant with {int(ndim)} dimensions"
) )
...@@ -4405,7 +4405,7 @@ def expand_dims(a: np.ndarray | TensorVariable, axis: Sequence[int]) -> TensorVa ...@@ -4405,7 +4405,7 @@ def expand_dims(a: np.ndarray | TensorVariable, axis: Sequence[int]) -> TensorVa
axis = (axis,) axis = (axis,)
out_ndim = len(axis) + a.ndim out_ndim = len(axis) + a.ndim
axis = np.core.numeric.normalize_axis_tuple(axis, out_ndim) axis = normalize_axis_tuple(axis, out_ndim)
if not axis: if not axis:
return a return a
......
...@@ -8,6 +8,7 @@ import warnings ...@@ -8,6 +8,7 @@ import warnings
from math import gcd from math import gcd
import numpy as np import numpy as np
from numpy.exceptions import ComplexWarning
try: try:
...@@ -2338,7 +2339,7 @@ class BaseAbstractConv(Op): ...@@ -2338,7 +2339,7 @@ class BaseAbstractConv(Op):
bval = _bvalfromboundary("fill") bval = _bvalfromboundary("fill")
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore", np.ComplexWarning) warnings.simplefilter("ignore", ComplexWarning)
for b in range(img.shape[0]): for b in range(img.shape[0]):
for g in range(self.num_groups): for g in range(self.num_groups):
for n in range(output_channel_offset): for n in range(output_channel_offset):
......
...@@ -6,13 +6,14 @@ from itertools import pairwise ...@@ -6,13 +6,14 @@ from itertools import pairwise
from typing import cast from typing import cast
import numpy as np import numpy as np
from numpy.core.einsumfunc import _find_contraction, _parse_einsum_input # type: ignore
from numpy.core.numeric import ( # type: ignore from pytensor.compile.builders import OpFromGraph
from pytensor.npy_2_compat import (
_find_contraction,
_parse_einsum_input,
normalize_axis_index, normalize_axis_index,
normalize_axis_tuple, normalize_axis_tuple,
) )
from pytensor.compile.builders import OpFromGraph
from pytensor.tensor import TensorLike from pytensor.tensor import TensorLike
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
arange, arange,
......
...@@ -4,7 +4,6 @@ from textwrap import dedent ...@@ -4,7 +4,6 @@ from textwrap import dedent
from typing import Literal from typing import Literal
import numpy as np import numpy as np
from numpy.core.numeric import normalize_axis_tuple
import pytensor.tensor.basic import pytensor.tensor.basic
from pytensor.configdefaults import config from pytensor.configdefaults import config
...@@ -17,6 +16,7 @@ from pytensor.link.c.basic import failure_code ...@@ -17,6 +16,7 @@ from pytensor.link.c.basic import failure_code
from pytensor.link.c.op import COp, ExternalCOp, OpenMPOp from pytensor.link.c.op import COp, ExternalCOp, OpenMPOp
from pytensor.link.c.params_type import ParamsType from pytensor.link.c.params_type import ParamsType
from pytensor.misc.frozendict import frozendict from pytensor.misc.frozendict import frozendict
from pytensor.npy_2_compat import normalize_axis_tuple
from pytensor.printing import Printer, pprint from pytensor.printing import Printer, pprint
from pytensor.scalar import get_scalar_type from pytensor.scalar import get_scalar_type
from pytensor.scalar.basic import bool as scalar_bool from pytensor.scalar.basic import bool as scalar_bool
...@@ -41,9 +41,6 @@ from pytensor.tensor.variable import TensorVariable ...@@ -41,9 +41,6 @@ from pytensor.tensor.variable import TensorVariable
from pytensor.utils import uniq from pytensor.utils import uniq
_numpy_ver = [int(n) for n in np.__version__.split(".")[:2]]
class DimShuffle(ExternalCOp): class DimShuffle(ExternalCOp):
""" """
Allows to reorder the dimensions of a tensor or insert or remove Allows to reorder the dimensions of a tensor or insert or remove
......
...@@ -2,7 +2,7 @@ import warnings ...@@ -2,7 +2,7 @@ import warnings
from collections.abc import Collection, Iterable from collections.abc import Collection, Iterable
import numpy as np import numpy as np
from numpy.core.multiarray import normalize_axis_index from numpy.exceptions import AxisError
import pytensor import pytensor
import pytensor.scalar.basic as ps import pytensor.scalar.basic as ps
...@@ -17,6 +17,10 @@ from pytensor.graph.op import Op ...@@ -17,6 +17,10 @@ from pytensor.graph.op import Op
from pytensor.link.c.op import COp from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType from pytensor.link.c.params_type import ParamsType
from pytensor.link.c.type import EnumList, Generic from pytensor.link.c.type import EnumList, Generic
from pytensor.npy_2_compat import (
normalize_axis_index,
normalize_axis_tuple,
)
from pytensor.raise_op import Assert from pytensor.raise_op import Assert
from pytensor.scalar import int32 as int_t from pytensor.scalar import int32 as int_t
from pytensor.scalar import upcast from pytensor.scalar import upcast
...@@ -596,9 +600,9 @@ def squeeze(x, axis=None): ...@@ -596,9 +600,9 @@ def squeeze(x, axis=None):
# scalar inputs are treated as 1D regarding axis in this `Op` # scalar inputs are treated as 1D regarding axis in this `Op`
try: try:
axis = np.core.numeric.normalize_axis_tuple(axis, ndim=max(1, _x.ndim)) axis = normalize_axis_tuple(axis, ndim=max(1, _x.ndim))
except np.AxisError: except AxisError:
raise np.AxisError(axis, ndim=_x.ndim) raise AxisError(axis, ndim=_x.ndim)
if not axis: if not axis:
# Nothing to do # Nothing to do
......
...@@ -5,7 +5,6 @@ from textwrap import dedent ...@@ -5,7 +5,6 @@ from textwrap import dedent
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
import numpy as np import numpy as np
from numpy.core.numeric import normalize_axis_tuple
from pytensor import config, printing from pytensor import config, printing
from pytensor import scalar as ps from pytensor import scalar as ps
...@@ -14,6 +13,7 @@ from pytensor.graph.op import Op ...@@ -14,6 +13,7 @@ from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node from pytensor.graph.replace import _vectorize_node
from pytensor.link.c.op import COp from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType from pytensor.link.c.params_type import ParamsType
from pytensor.npy_2_compat import normalize_axis_tuple
from pytensor.printing import pprint from pytensor.printing import pprint
from pytensor.raise_op import Assert from pytensor.raise_op import Assert
from pytensor.scalar.basic import BinaryScalarOp from pytensor.scalar.basic import BinaryScalarOp
......
...@@ -4,13 +4,13 @@ from functools import partial ...@@ -4,13 +4,13 @@ from functools import partial
from typing import Literal, cast from typing import Literal, cast
import numpy as np import numpy as np
from numpy.core.numeric import normalize_axis_tuple # type: ignore
from pytensor import scalar as ps from pytensor import scalar as ps
from pytensor.compile.builders import OpFromGraph from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import DisconnectedType from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply from pytensor.graph.basic import Apply
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.npy_2_compat import normalize_axis_tuple
from pytensor.tensor import TensorLike from pytensor.tensor import TensorLike
from pytensor.tensor import basic as ptb from pytensor.tensor import basic as ptb
from pytensor.tensor import math as ptm from pytensor.tensor import math as ptm
......
...@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Union, cast ...@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Union, cast
from typing import cast as typing_cast from typing import cast as typing_cast
import numpy as np import numpy as np
from numpy.core.numeric import normalize_axis_tuple # type: ignore
import pytensor import pytensor
from pytensor.gradient import DisconnectedType from pytensor.gradient import DisconnectedType
...@@ -16,6 +15,7 @@ from pytensor.graph.replace import _vectorize_node ...@@ -16,6 +15,7 @@ from pytensor.graph.replace import _vectorize_node
from pytensor.graph.type import HasShape from pytensor.graph.type import HasShape
from pytensor.link.c.op import COp from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType from pytensor.link.c.params_type import ParamsType
from pytensor.npy_2_compat import normalize_axis_tuple
from pytensor.scalar import int32 from pytensor.scalar import int32
from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length
from pytensor.tensor import basic as ptb from pytensor.tensor import basic as ptb
......
...@@ -6,6 +6,7 @@ from typing import Literal, cast ...@@ -6,6 +6,7 @@ from typing import Literal, cast
import numpy as np import numpy as np
import scipy.linalg as scipy_linalg import scipy.linalg as scipy_linalg
from numpy.exceptions import ComplexWarning
import pytensor import pytensor
import pytensor.tensor as pt import pytensor.tensor as pt
...@@ -767,7 +768,7 @@ class ExpmGrad(Op): ...@@ -767,7 +768,7 @@ class ExpmGrad(Op):
Y = U.dot(V.T.dot(gA).dot(U) * X).dot(V.T) Y = U.dot(V.T.dot(gA).dot(U) * X).dot(V.T)
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore", np.ComplexWarning) warnings.simplefilter("ignore", ComplexWarning)
out[0] = Y.astype(A.dtype) out[0] = Y.astype(A.dtype)
......
...@@ -18,6 +18,7 @@ from pytensor.graph.type import Type ...@@ -18,6 +18,7 @@ from pytensor.graph.type import Type
from pytensor.graph.utils import MethodNotDefined from pytensor.graph.utils import MethodNotDefined
from pytensor.link.c.op import COp from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType from pytensor.link.c.params_type import ParamsType
from pytensor.npy_2_compat import numpy_version, using_numpy_2
from pytensor.printing import Printer, pprint, set_precedence from pytensor.printing import Printer, pprint, set_precedence
from pytensor.scalar.basic import ScalarConstant, ScalarVariable from pytensor.scalar.basic import ScalarConstant, ScalarVariable
from pytensor.tensor import ( from pytensor.tensor import (
...@@ -2522,6 +2523,7 @@ class AdvancedIncSubtensor1(COp): ...@@ -2522,6 +2523,7 @@ class AdvancedIncSubtensor1(COp):
numpy_ver = [int(n) for n in np.__version__.split(".")[:2]] numpy_ver = [int(n) for n in np.__version__.split(".")[:2]]
if bool(numpy_ver < [1, 8]): if bool(numpy_ver < [1, 8]):
raise NotImplementedError raise NotImplementedError
x, y, idx = input_names x, y, idx = input_names
out = output_names[0] out = output_names[0]
copy_of_x = self.copy_of_x(x) copy_of_x = self.copy_of_x(x)
......
...@@ -3,10 +3,10 @@ from collections.abc import Sequence ...@@ -3,10 +3,10 @@ from collections.abc import Sequence
from typing import cast from typing import cast
import numpy as np import numpy as np
from numpy.core.numeric import normalize_axis_tuple # type: ignore
import pytensor import pytensor
from pytensor.graph import FunctionGraph, Variable from pytensor.graph import FunctionGraph, Variable
from pytensor.npy_2_compat import normalize_axis_tuple
from pytensor.utils import hash_from_code from pytensor.utils import hash_from_code
...@@ -236,8 +236,8 @@ def normalize_reduce_axis(axis, ndim: int) -> tuple[int, ...] | None: ...@@ -236,8 +236,8 @@ def normalize_reduce_axis(axis, ndim: int) -> tuple[int, ...] | None:
if axis is not None: if axis is not None:
try: try:
axis = normalize_axis_tuple(axis, ndim=max(1, ndim)) axis = normalize_axis_tuple(axis, ndim=max(1, ndim))
except np.AxisError: except np.exceptions.AxisError:
raise np.AxisError(axis, ndim=ndim) raise np.exceptions.AxisError(axis, ndim=ndim)
# TODO: If axis tuple is equivalent to None, return None for more canonicalization? # TODO: If axis tuple is equivalent to None, return None for more canonicalization?
return cast(tuple, axis) return cast(tuple, axis)
...@@ -672,7 +672,7 @@ class TestCAReduce(unittest_tools.InferShapeTester): ...@@ -672,7 +672,7 @@ class TestCAReduce(unittest_tools.InferShapeTester):
assert self.op(ps.add, axis=(-1,))(x).eval({x: 5}) == 5 assert self.op(ps.add, axis=(-1,))(x).eval({x: 5}) == 5
with pytest.raises( with pytest.raises(
np.AxisError, np.exceptions.AxisError,
match=re.escape("axis (-2,) is out of bounds for array of dimension 0"), match=re.escape("axis (-2,) is out of bounds for array of dimension 0"),
): ):
self.op(ps.add, axis=(-2,))(x) self.op(ps.add, axis=(-2,))(x)
......
...@@ -469,7 +469,7 @@ class TestSqueeze(utt.InferShapeTester): ...@@ -469,7 +469,7 @@ class TestSqueeze(utt.InferShapeTester):
assert squeeze(x, axis=(0,)).eval({x: 5}) == 5 assert squeeze(x, axis=(0,)).eval({x: 5}) == 5
with pytest.raises( with pytest.raises(
np.AxisError, np.exceptions.AxisError,
match=re.escape("axis (1,) is out of bounds for array of dimension 0"), match=re.escape("axis (1,) is out of bounds for array of dimension 0"),
): ):
squeeze(x, axis=1) squeeze(x, axis=1)
......
...@@ -49,7 +49,7 @@ class TestLoadTensor: ...@@ -49,7 +49,7 @@ class TestLoadTensor:
path = Variable(Generic(), None) path = Variable(Generic(), None)
x = load(path, "int32", (None,), mmap_mode="c") x = load(path, "int32", (None,), mmap_mode="c")
fn = function([path], x) fn = function([path], x)
assert isinstance(fn(self.filename), np.core.memmap) assert isinstance(fn(self.filename), np.memmap)
def teardown_method(self): def teardown_method(self):
(pytensor.config.compiledir / "_test.npy").unlink() (pytensor.config.compiledir / "_test.npy").unlink()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论