提交 0145d609 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Changes for deprecations in numpy 2.0 C-API

- replace `->elsize` by `PyArray_ITEMSIZE` - don't use deprecated PyArray_MoveInto
上级 8f4d7b1f
...@@ -3610,7 +3610,7 @@ class StructuredDotGradCSC(COp): ...@@ -3610,7 +3610,7 @@ class StructuredDotGradCSC(COp):
out[0] = g_a_data out[0] = g_a_data
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (2,)
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
(_indices, _indptr, _d, _g) = inputs (_indices, _indptr, _d, _g) = inputs
...@@ -3647,11 +3647,11 @@ class StructuredDotGradCSC(COp): ...@@ -3647,11 +3647,11 @@ class StructuredDotGradCSC(COp):
npy_intp nnz = PyArray_DIMS({_indices})[0]; npy_intp nnz = PyArray_DIMS({_indices})[0];
npy_intp N = PyArray_DIMS({_indptr})[0]-1; //TODO: error checking with this npy_intp N = PyArray_DIMS({_indptr})[0]-1; //TODO: error checking with this
npy_intp Sindices = PyArray_STRIDES({_indices})[0]/PyArray_DESCR({_indices})->elsize; npy_intp Sindices = PyArray_STRIDES({_indices})[0]/PyArray_ITEMSIZE({_indices});
npy_intp Sindptr = PyArray_STRIDES({_indptr})[0]/PyArray_DESCR({_indptr})->elsize; npy_intp Sindptr = PyArray_STRIDES({_indptr})[0]/PyArray_ITEMSIZE({_indptr});
const npy_intp Sd1 = PyArray_STRIDES({_d})[1]/PyArray_DESCR({_d})->elsize; const npy_intp Sd1 = PyArray_STRIDES({_d})[1]/PyArray_ITEMSIZE({_d});
const npy_intp Sg1 = PyArray_STRIDES({_g})[1]/PyArray_DESCR({_g})->elsize; const npy_intp Sg1 = PyArray_STRIDES({_g})[1]/PyArray_ITEMSIZE({_g});
const npy_intp K = PyArray_DIMS({_d})[1]; const npy_intp K = PyArray_DIMS({_d})[1];
...@@ -3744,7 +3744,7 @@ class StructuredDotGradCSR(COp): ...@@ -3744,7 +3744,7 @@ class StructuredDotGradCSR(COp):
out[0] = g_a_data out[0] = g_a_data
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (2,)
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
(_indices, _indptr, _d, _g) = inputs (_indices, _indptr, _d, _g) = inputs
...@@ -3782,11 +3782,11 @@ class StructuredDotGradCSR(COp): ...@@ -3782,11 +3782,11 @@ class StructuredDotGradCSR(COp):
// extract number of rows // extract number of rows
npy_intp N = PyArray_DIMS({_indptr})[0]-1; //TODO: error checking with this npy_intp N = PyArray_DIMS({_indptr})[0]-1; //TODO: error checking with this
npy_intp Sindices = PyArray_STRIDES({_indices})[0]/PyArray_DESCR({_indices})->elsize; npy_intp Sindices = PyArray_STRIDES({_indices})[0]/PyArray_ITEMSIZE({_indices});
npy_intp Sindptr = PyArray_STRIDES({_indptr})[0]/PyArray_DESCR({_indptr})->elsize; npy_intp Sindptr = PyArray_STRIDES({_indptr})[0]/PyArray_ITEMSIZE({_indptr});
const npy_intp Sd1 = PyArray_STRIDES({_d})[1]/PyArray_DESCR({_d})->elsize; const npy_intp Sd1 = PyArray_STRIDES({_d})[1]/PyArray_ITEMSIZE({_d});
const npy_intp Sg1 = PyArray_STRIDES({_g})[1]/PyArray_DESCR({_g})->elsize; const npy_intp Sg1 = PyArray_STRIDES({_g})[1]/PyArray_ITEMSIZE({_g});
const npy_intp K = PyArray_DIMS({_d})[1]; const npy_intp K = PyArray_DIMS({_d})[1];
......
...@@ -498,7 +498,7 @@ class GemmRelated(COp): ...@@ -498,7 +498,7 @@ class GemmRelated(COp):
int unit = 0; int unit = 0;
int type_num = PyArray_DESCR(%(_x)s)->type_num; int type_num = PyArray_DESCR(%(_x)s)->type_num;
int type_size = PyArray_DESCR(%(_x)s)->elsize; // in bytes int type_size = PyArray_ITEMSIZE(%(_x)s); // in bytes
npy_intp* Nx = PyArray_DIMS(%(_x)s); npy_intp* Nx = PyArray_DIMS(%(_x)s);
npy_intp* Ny = PyArray_DIMS(%(_y)s); npy_intp* Ny = PyArray_DIMS(%(_y)s);
...@@ -789,7 +789,7 @@ class GemmRelated(COp): ...@@ -789,7 +789,7 @@ class GemmRelated(COp):
) )
def build_gemm_version(self): def build_gemm_version(self):
return (13, blas_header_version()) return (14, blas_header_version())
class Gemm(GemmRelated): class Gemm(GemmRelated):
...@@ -1030,7 +1030,7 @@ class Gemm(GemmRelated): ...@@ -1030,7 +1030,7 @@ class Gemm(GemmRelated):
%(fail)s %(fail)s
} }
if(PyArray_MoveInto(x_new, %(_x)s) == -1) if(PyArray_CopyInto(x_new, %(_x)s) == -1)
{ {
%(fail)s %(fail)s
} }
...@@ -1056,7 +1056,7 @@ class Gemm(GemmRelated): ...@@ -1056,7 +1056,7 @@ class Gemm(GemmRelated):
%(fail)s %(fail)s
} }
if(PyArray_MoveInto(y_new, %(_y)s) == -1) if(PyArray_CopyInto(y_new, %(_y)s) == -1)
{ {
%(fail)s %(fail)s
} }
...@@ -1102,7 +1102,7 @@ class Gemm(GemmRelated): ...@@ -1102,7 +1102,7 @@ class Gemm(GemmRelated):
def c_code_cache_version(self): def c_code_cache_version(self):
gv = self.build_gemm_version() gv = self.build_gemm_version()
if gv: if gv:
return (7, *gv) return (8, *gv)
else: else:
return gv return gv
...@@ -1538,7 +1538,7 @@ class BatchedDot(COp): ...@@ -1538,7 +1538,7 @@ class BatchedDot(COp):
return f""" return f"""
int type_num = PyArray_DESCR({_x})->type_num; int type_num = PyArray_DESCR({_x})->type_num;
int type_size = PyArray_DESCR({_x})->elsize; // in bytes int type_size = PyArray_ITEMSIZE({_x}); // in bytes
if (PyArray_NDIM({_x}) != 3) {{ if (PyArray_NDIM({_x}) != 3) {{
PyErr_Format(PyExc_NotImplementedError, PyErr_Format(PyExc_NotImplementedError,
...@@ -1598,7 +1598,7 @@ class BatchedDot(COp): ...@@ -1598,7 +1598,7 @@ class BatchedDot(COp):
def c_code_cache_version(self): def c_code_cache_version(self):
from pytensor.tensor.blas_headers import blas_header_version from pytensor.tensor.blas_headers import blas_header_version
return (5, blas_header_version()) return (6, blas_header_version())
def grad(self, inp, grads): def grad(self, inp, grads):
x, y = inp x, y = inp
......
...@@ -1053,7 +1053,7 @@ def openblas_threads_text(): ...@@ -1053,7 +1053,7 @@ def openblas_threads_text():
def blas_header_version(): def blas_header_version():
# Version for the base header # Version for the base header
version = (9,) version = (10,)
if detect_macos_sdot_bug(): if detect_macos_sdot_bug():
if detect_macos_sdot_bug.fix_works: if detect_macos_sdot_bug.fix_works:
# Version with fix # Version with fix
...@@ -1071,7 +1071,7 @@ def ____gemm_code(check_ab, a_init, b_init): ...@@ -1071,7 +1071,7 @@ def ____gemm_code(check_ab, a_init, b_init):
const char * error_string = NULL; const char * error_string = NULL;
int type_num = PyArray_DESCR(_x)->type_num; int type_num = PyArray_DESCR(_x)->type_num;
int type_size = PyArray_DESCR(_x)->elsize; // in bytes int type_size = PyArray_ITEMSIZE(_x); // in bytes
npy_intp* Nx = PyArray_DIMS(_x); npy_intp* Nx = PyArray_DIMS(_x);
npy_intp* Ny = PyArray_DIMS(_y); npy_intp* Ny = PyArray_DIMS(_y);
......
...@@ -146,7 +146,7 @@ class WeirdBrokenOp(COp): ...@@ -146,7 +146,7 @@ class WeirdBrokenOp(COp):
raise ValueError(self.behaviour) raise ValueError(self.behaviour)
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (2,)
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
(a,) = inp (a,) = inp
...@@ -165,8 +165,8 @@ class WeirdBrokenOp(COp): ...@@ -165,8 +165,8 @@ class WeirdBrokenOp(COp):
prep_vars = f""" prep_vars = f"""
//the output array has size M x N //the output array has size M x N
npy_intp M = PyArray_DIMS({a})[0]; npy_intp M = PyArray_DIMS({a})[0];
npy_intp Sa = PyArray_STRIDES({a})[0] / PyArray_DESCR({a})->elsize; npy_intp Sa = PyArray_STRIDES({a})[0] / PyArray_ITEMSIZE({a});
npy_intp Sz = PyArray_STRIDES({z})[0] / PyArray_DESCR({z})->elsize; npy_intp Sz = PyArray_STRIDES({z})[0] / PyArray_ITEMSIZE({z});
npy_double * Da = (npy_double*)PyArray_BYTES({a}); npy_double * Da = (npy_double*)PyArray_BYTES({a});
npy_double * Dz = (npy_double*)PyArray_BYTES({z}); npy_double * Dz = (npy_double*)PyArray_BYTES({z});
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论