提交 c5b67f2d authored 作者: notoraptor's avatar notoraptor

Remove parametrization of `const` into blas header.

上级 f00420e6
...@@ -732,7 +732,6 @@ def blas_header_text(): ...@@ -732,7 +732,6 @@ def blas_header_text():
"""C header for the fortran blas interface""" """C header for the fortran blas interface"""
blas_code = "" blas_code = ""
const = "const"
if not config.blas.ldflags: if not config.blas.ldflags:
# Include the Numpy version implementation of [sd]gemm_. # Include the Numpy version implementation of [sd]gemm_.
current_filedir = dirname(__file__) current_filedir = dirname(__file__)
...@@ -749,7 +748,6 @@ def blas_header_text(): ...@@ -749,7 +748,6 @@ def blas_header_text():
dblas_code = template_code % {"float_type": "double", "float_size": 8, "npy_float": "NPY_FLOAT64", "precision": "d"} dblas_code = template_code % {"float_type": "double", "float_size": 8, "npy_float": "NPY_FLOAT64", "precision": "d"}
if not common_code or not template_code: if not common_code or not template_code:
raise IOError("Unable to load NumPy implementation of BLAS functions from C source files.") raise IOError("Unable to load NumPy implementation of BLAS functions from C source files.")
const = ""
blas_code += common_code blas_code += common_code
blas_code += sblas_code blas_code += sblas_code
blas_code += dblas_code blas_code += dblas_code
...@@ -833,7 +831,7 @@ def blas_header_text(): ...@@ -833,7 +831,7 @@ def blas_header_text():
/* Single Precision */ /* Single Precision */
void sgemv_(char*, const int*, const int*, const float *, %(const)s float *, const int*, %(const)s float *, const int*, const float *, float *, const int*); void sgemv_(char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void sgbmv_(char*, const int*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); void sgbmv_(char*, const int*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssymv_(char*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); void ssymv_(char*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssbmv_(char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); void ssbmv_(char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
...@@ -915,7 +913,7 @@ def blas_header_text(): ...@@ -915,7 +913,7 @@ def blas_header_text():
/* Single Precision */ /* Single Precision */
void sgemm_(char*, char*, const int*, const int*, const int*, const float *, %(const)s float *, const int*, %(const)s float *, const int*, const float *, float *, const int*); void sgemm_(char*, char*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssymm_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); void ssymm_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssyrk_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, float *, const int*); void ssyrk_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, float *, const int*);
void ssyr2k_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); void ssyr2k_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
...@@ -924,7 +922,7 @@ def blas_header_text(): ...@@ -924,7 +922,7 @@ def blas_header_text():
/* Double Precision */ /* Double Precision */
void dgemm_(char*, char*, const int*, const int*, const int*, const double *, %(const)s double *, const int*, %(const)s double *, const int*, const double *, double *, const int*); void dgemm_(char*, char*, const int*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dsymm_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*); void dsymm_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dsyrk_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, double *, const int*); void dsyrk_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, double *, const int*);
void dsyr2k_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*); void dsyr2k_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
...@@ -983,7 +981,7 @@ def blas_header_text(): ...@@ -983,7 +981,7 @@ def blas_header_text():
} }
""") """)
return (header % {'const': const}) + blas_code return header + blas_code
if not config.blas.ldflags: if not config.blas.ldflags:
...@@ -1035,7 +1033,7 @@ def openblas_threads_text(): ...@@ -1035,7 +1033,7 @@ def openblas_threads_text():
def blas_header_version(): def blas_header_version():
# Version for the base header # Version for the base header
version = (3,) version = (4,)
if detect_macos_sdot_bug(): if detect_macos_sdot_bug():
if detect_macos_sdot_bug.fix_works: if detect_macos_sdot_bug.fix_works:
# Version with fix # Version with fix
......
...@@ -276,12 +276,12 @@ void %(precision)sgemv_( ...@@ -276,12 +276,12 @@ void %(precision)sgemv_(
PyObject* vector_y = alt_op_%(float_type)s(1, SY, 1, *N, *INCY, 0); PyObject* vector_y = alt_op_%(float_type)s(1, SY, 1, *N, *INCY, 0);
if (vector_x == NULL || vector_y == NULL) if (vector_x == NULL || vector_y == NULL)
alt_fatal_error("NumPy %(precision)sdot_: unables to wrap x and y arrays."); alt_fatal_error("NumPy %(precision)sdot_: unable to wrap x and y arrays.");
// Make matrix product: (1, N) * (N, 1) => (1, 1) // Make matrix product: (1, N) * (N, 1) => (1, 1)
PyObject* dot_product = PyArray_MatrixProduct(vector_x, vector_y); PyObject* dot_product = PyArray_MatrixProduct(vector_x, vector_y);
if (dot_product == NULL) if (dot_product == NULL)
alt_fatal_error("NumPy %(precision)sdot_: unables to compute dot."); alt_fatal_error("NumPy %(precision)sdot_: unable to compute dot.");
// Get result. // Get result.
%(float_type)s result = *(%(float_type)s*)PyArray_DATA((PyArrayObject*)dot_product); %(float_type)s result = *(%(float_type)s*)PyArray_DATA((PyArrayObject*)dot_product);
Py_XDECREF(dot_product); Py_XDECREF(dot_product);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论