提交 a1d07eb8 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Ricardo Vieira

Remove npy<2 compatibility for NPY_RAVEL_AXIS value

上级 05f69852
...@@ -16,14 +16,6 @@ else: ...@@ -16,14 +16,6 @@ else:
ndarray_c_version = np.core._multiarray_umath._get_ndarray_c_version() # type: ignore[attr-defined] ndarray_c_version = np.core._multiarray_umath._get_ndarray_c_version() # type: ignore[attr-defined]
# to patch up some of the C code, we need to use these special values...
if using_numpy_2:
numpy_axis_is_none_flag = np.iinfo(np.int32).min # the value of "NPY_RAVEL_AXIS"
else:
# 32 is the value used to mark axis = None in Numpy C-API prior to version 2.0
numpy_axis_is_none_flag = 32
# function that replicates np.unique from numpy < 2.0 # function that replicates np.unique from numpy < 2.0
def old_np_unique( def old_np_unique(
arr, return_index=False, return_inverse=False, return_counts=False, axis=None arr, return_index=False, return_inverse=False, return_counts=False, axis=None
......
...@@ -18,11 +18,7 @@ from pytensor.graph.replace import _vectorize_node ...@@ -18,11 +18,7 @@ from pytensor.graph.replace import _vectorize_node
from pytensor.link.c.op import COp from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType from pytensor.link.c.params_type import ParamsType
from pytensor.link.c.type import EnumList, Generic from pytensor.link.c.type import EnumList, Generic
from pytensor.npy_2_compat import ( from pytensor.npy_2_compat import npy_2_compat_header, old_np_unique
npy_2_compat_header,
numpy_axis_is_none_flag,
old_np_unique,
)
from pytensor.raise_op import Assert from pytensor.raise_op import Assert
from pytensor.scalar import int64 as int_t from pytensor.scalar import int64 as int_t
from pytensor.scalar import upcast from pytensor.scalar import upcast
...@@ -51,7 +47,7 @@ from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor ...@@ -51,7 +47,7 @@ from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
from pytensor.tensor.utils import normalize_reduce_axis from pytensor.tensor.utils import normalize_reduce_axis
from pytensor.tensor.variable import TensorVariable from pytensor.tensor.variable import TensorVariable
from pytensor.utils import LOCAL_BITWIDTH, PYTHON_INT_BITWIDTH from pytensor.utils import LOCAL_BITWIDTH, NPY_RAVEL_AXIS, PYTHON_INT_BITWIDTH
class CpuContiguous(COp): class CpuContiguous(COp):
...@@ -308,7 +304,7 @@ class CumOp(COp): ...@@ -308,7 +304,7 @@ class CumOp(COp):
@property @property
def c_axis(self) -> int: def c_axis(self) -> int:
if self.axis is None: if self.axis is None:
return numpy_axis_is_none_flag return NPY_RAVEL_AXIS
return self.axis return self.axis
def make_node(self, x): def make_node(self, x):
......
...@@ -14,7 +14,7 @@ from pytensor.graph.op import Op ...@@ -14,7 +14,7 @@ from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node from pytensor.graph.replace import _vectorize_node
from pytensor.link.c.op import COp from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType from pytensor.link.c.params_type import ParamsType
from pytensor.npy_2_compat import npy_2_compat_header, numpy_axis_is_none_flag from pytensor.npy_2_compat import npy_2_compat_header
from pytensor.printing import pprint from pytensor.printing import pprint
from pytensor.raise_op import Assert from pytensor.raise_op import Assert
from pytensor.scalar.basic import BinaryScalarOp from pytensor.scalar.basic import BinaryScalarOp
...@@ -162,7 +162,7 @@ class Argmax(COp): ...@@ -162,7 +162,7 @@ class Argmax(COp):
c_axis = np.int64(self.axis[0]) c_axis = np.int64(self.axis[0])
else: else:
# The value here doesn't matter, it won't be used # The value here doesn't matter, it won't be used
c_axis = numpy_axis_is_none_flag c_axis = 0
return self.params_type.get_params(c_axis=c_axis) return self.params_type.get_params(c_axis=c_axis)
def make_node(self, x): def make_node(self, x):
......
...@@ -10,6 +10,8 @@ from collections.abc import Iterable, Sequence ...@@ -10,6 +10,8 @@ from collections.abc import Iterable, Sequence
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
import numpy as np
__all__ = [ __all__ = [
"get_unbound_function", "get_unbound_function",
...@@ -19,6 +21,7 @@ __all__ = [ ...@@ -19,6 +21,7 @@ __all__ = [
"output_subprocess_Popen", "output_subprocess_Popen",
"LOCAL_BITWIDTH", "LOCAL_BITWIDTH",
"PYTHON_INT_BITWIDTH", "PYTHON_INT_BITWIDTH",
"NPY_RAVEL_AXIS",
"NoDuplicateOptWarningFilter", "NoDuplicateOptWarningFilter",
] ]
...@@ -46,6 +49,11 @@ Note that it can be different from the size of a memory pointer. ...@@ -46,6 +49,11 @@ Note that it can be different from the size of a memory pointer.
'l' denotes a C long int, and the size is expressed in bytes. 'l' denotes a C long int, and the size is expressed in bytes.
""" """
NPY_RAVEL_AXIS = np.iinfo(np.int32).min
"""
The value of the numpy C API NPY_RAVEL_AXIS.
"""
def __call_excepthooks(type, value, trace): def __call_excepthooks(type, value, trace):
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论