提交 0145d609 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Changes for deprecations in numpy 2.0 C-API

- replace `->elsize` by `PyArray_ITEMSIZE` - don't use deprecated PyArray_MoveInto
上级 8f4d7b1f
......@@ -3610,7 +3610,7 @@ class StructuredDotGradCSC(COp):
out[0] = g_a_data
def c_code_cache_version(self):
return (1,)
return (2,)
def c_code(self, node, name, inputs, outputs, sub):
(_indices, _indptr, _d, _g) = inputs
......@@ -3647,11 +3647,11 @@ class StructuredDotGradCSC(COp):
npy_intp nnz = PyArray_DIMS({_indices})[0];
npy_intp N = PyArray_DIMS({_indptr})[0]-1; //TODO: error checking with this
npy_intp Sindices = PyArray_STRIDES({_indices})[0]/PyArray_DESCR({_indices})->elsize;
npy_intp Sindptr = PyArray_STRIDES({_indptr})[0]/PyArray_DESCR({_indptr})->elsize;
npy_intp Sindices = PyArray_STRIDES({_indices})[0]/PyArray_ITEMSIZE({_indices});
npy_intp Sindptr = PyArray_STRIDES({_indptr})[0]/PyArray_ITEMSIZE({_indptr});
const npy_intp Sd1 = PyArray_STRIDES({_d})[1]/PyArray_DESCR({_d})->elsize;
const npy_intp Sg1 = PyArray_STRIDES({_g})[1]/PyArray_DESCR({_g})->elsize;
const npy_intp Sd1 = PyArray_STRIDES({_d})[1]/PyArray_ITEMSIZE({_d});
const npy_intp Sg1 = PyArray_STRIDES({_g})[1]/PyArray_ITEMSIZE({_g});
const npy_intp K = PyArray_DIMS({_d})[1];
......@@ -3744,7 +3744,7 @@ class StructuredDotGradCSR(COp):
out[0] = g_a_data
def c_code_cache_version(self):
return (1,)
return (2,)
def c_code(self, node, name, inputs, outputs, sub):
(_indices, _indptr, _d, _g) = inputs
......@@ -3782,11 +3782,11 @@ class StructuredDotGradCSR(COp):
// extract number of rows
npy_intp N = PyArray_DIMS({_indptr})[0]-1; //TODO: error checking with this
npy_intp Sindices = PyArray_STRIDES({_indices})[0]/PyArray_DESCR({_indices})->elsize;
npy_intp Sindptr = PyArray_STRIDES({_indptr})[0]/PyArray_DESCR({_indptr})->elsize;
npy_intp Sindices = PyArray_STRIDES({_indices})[0]/PyArray_ITEMSIZE({_indices});
npy_intp Sindptr = PyArray_STRIDES({_indptr})[0]/PyArray_ITEMSIZE({_indptr});
const npy_intp Sd1 = PyArray_STRIDES({_d})[1]/PyArray_DESCR({_d})->elsize;
const npy_intp Sg1 = PyArray_STRIDES({_g})[1]/PyArray_DESCR({_g})->elsize;
const npy_intp Sd1 = PyArray_STRIDES({_d})[1]/PyArray_ITEMSIZE({_d});
const npy_intp Sg1 = PyArray_STRIDES({_g})[1]/PyArray_ITEMSIZE({_g});
const npy_intp K = PyArray_DIMS({_d})[1];
......
......@@ -498,7 +498,7 @@ class GemmRelated(COp):
int unit = 0;
int type_num = PyArray_DESCR(%(_x)s)->type_num;
int type_size = PyArray_DESCR(%(_x)s)->elsize; // in bytes
int type_size = PyArray_ITEMSIZE(%(_x)s); // in bytes
npy_intp* Nx = PyArray_DIMS(%(_x)s);
npy_intp* Ny = PyArray_DIMS(%(_y)s);
......@@ -789,7 +789,7 @@ class GemmRelated(COp):
)
def build_gemm_version(self):
return (13, blas_header_version())
return (14, blas_header_version())
class Gemm(GemmRelated):
......@@ -1030,7 +1030,7 @@ class Gemm(GemmRelated):
%(fail)s
}
if(PyArray_MoveInto(x_new, %(_x)s) == -1)
if(PyArray_CopyInto(x_new, %(_x)s) == -1)
{
%(fail)s
}
......@@ -1056,7 +1056,7 @@ class Gemm(GemmRelated):
%(fail)s
}
if(PyArray_MoveInto(y_new, %(_y)s) == -1)
if(PyArray_CopyInto(y_new, %(_y)s) == -1)
{
%(fail)s
}
......@@ -1102,7 +1102,7 @@ class Gemm(GemmRelated):
def c_code_cache_version(self):
gv = self.build_gemm_version()
if gv:
return (7, *gv)
return (8, *gv)
else:
return gv
......@@ -1538,7 +1538,7 @@ class BatchedDot(COp):
return f"""
int type_num = PyArray_DESCR({_x})->type_num;
int type_size = PyArray_DESCR({_x})->elsize; // in bytes
int type_size = PyArray_ITEMSIZE({_x}); // in bytes
if (PyArray_NDIM({_x}) != 3) {{
PyErr_Format(PyExc_NotImplementedError,
......@@ -1598,7 +1598,7 @@ class BatchedDot(COp):
def c_code_cache_version(self):
from pytensor.tensor.blas_headers import blas_header_version
return (5, blas_header_version())
return (6, blas_header_version())
def grad(self, inp, grads):
x, y = inp
......
......@@ -1053,7 +1053,7 @@ def openblas_threads_text():
def blas_header_version():
# Version for the base header
version = (9,)
version = (10,)
if detect_macos_sdot_bug():
if detect_macos_sdot_bug.fix_works:
# Version with fix
......@@ -1071,7 +1071,7 @@ def ____gemm_code(check_ab, a_init, b_init):
const char * error_string = NULL;
int type_num = PyArray_DESCR(_x)->type_num;
int type_size = PyArray_DESCR(_x)->elsize; // in bytes
int type_size = PyArray_ITEMSIZE(_x); // in bytes
npy_intp* Nx = PyArray_DIMS(_x);
npy_intp* Ny = PyArray_DIMS(_y);
......
......@@ -146,7 +146,7 @@ class WeirdBrokenOp(COp):
raise ValueError(self.behaviour)
def c_code_cache_version(self):
return (1,)
return (2,)
def c_code(self, node, name, inp, out, sub):
(a,) = inp
......@@ -165,8 +165,8 @@ class WeirdBrokenOp(COp):
prep_vars = f"""
//the output array has size M x N
npy_intp M = PyArray_DIMS({a})[0];
npy_intp Sa = PyArray_STRIDES({a})[0] / PyArray_DESCR({a})->elsize;
npy_intp Sz = PyArray_STRIDES({z})[0] / PyArray_DESCR({z})->elsize;
npy_intp Sa = PyArray_STRIDES({a})[0] / PyArray_ITEMSIZE({a});
npy_intp Sz = PyArray_STRIDES({z})[0] / PyArray_ITEMSIZE({z});
npy_double * Da = (npy_double*)PyArray_BYTES({a});
npy_double * Dz = (npy_double*)PyArray_BYTES({z});
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论