提交 340c577a authored 作者: notoraptor's avatar notoraptor

Add fallback implementation for BLAS [sd]gemv_.

上级 d66cefe7
...@@ -731,29 +731,28 @@ def cblas_header_text(): ...@@ -731,29 +731,28 @@ def cblas_header_text():
def blas_header_text(): def blas_header_text():
"""C header for the fortran blas interface""" """C header for the fortran blas interface"""
gemm_code = "" blas_code = ""
const = "const" const = "const"
if not config.blas.ldflags: if not config.blas.ldflags:
# Include the Numpy version implementation of [sd]gemm_. # Include the Numpy version implementation of [sd]gemm_.
current_filedir = dirname(__file__) current_filedir = dirname(__file__)
gemm_common_filepath = os.path.join(current_filedir, 'c_code', 'alt_gemm_common.c') blas_common_filepath = os.path.join(current_filedir, 'c_code', 'alt_blas_common.h')
gemm_template_filepath = os.path.join(current_filedir, 'c_code', 'alt_gemm_template.c') blas_template_filepath = os.path.join(current_filedir, 'c_code', 'alt_blas_template.c')
common_code = "" common_code = ""
sgemm_code = "" sblas_code = ""
dgemm_code = "" dblas_code = ""
with open(gemm_common_filepath) as code: with open(blas_common_filepath) as code:
common_code = code.read() common_code = code.read()
with open(gemm_template_filepath) as code: with open(blas_template_filepath) as code:
template_code = code.read() template_code = code.read()
sgemm_code = template_code % {"float_type": "float", "float_size": 4, "npy_float": "NPY_FLOAT32", "name": "sgemm_"} sblas_code = template_code % {"float_type": "float", "float_size": 4, "npy_float": "NPY_FLOAT32", "precision": "s"}
dgemm_code = template_code % {"float_type": "double", "float_size": 8, "npy_float": "NPY_FLOAT64", "name": "dgemm_"} dblas_code = template_code % {"float_type": "double", "float_size": 8, "npy_float": "NPY_FLOAT64", "precision": "d"}
if not common_code or not sgemm_code: if not common_code or not template_code:
raise IOError("Unable to load NumPy implementation of gemm code from C source files.") raise IOError("Unable to load NumPy implementation of BLAS functions from C source files.")
else:
const = "" const = ""
gemm_code += common_code blas_code += common_code
gemm_code += sgemm_code blas_code += sblas_code
gemm_code += dgemm_code blas_code += dblas_code
header = """ header = """
extern "C" extern "C"
...@@ -834,7 +833,7 @@ def blas_header_text(): ...@@ -834,7 +833,7 @@ def blas_header_text():
/* Single Precision */ /* Single Precision */
void sgemv_(char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); void sgemv_(char*, const int*, const int*, const float *, %(const)s float *, const int*, %(const)s float *, const int*, const float *, float *, const int*);
void sgbmv_(char*, const int*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); void sgbmv_(char*, const int*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssymv_(char*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); void ssymv_(char*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssbmv_(char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); void ssbmv_(char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
...@@ -984,7 +983,7 @@ def blas_header_text(): ...@@ -984,7 +983,7 @@ def blas_header_text():
} }
""") """)
return (header % {'const': const}) + gemm_code return (header % {'const': const}) + blas_code
def mkl_threads_text(): def mkl_threads_text():
...@@ -1032,7 +1031,7 @@ def openblas_threads_text(): ...@@ -1032,7 +1031,7 @@ def openblas_threads_text():
def blas_header_version(): def blas_header_version():
# Version for the base header # Version for the base header
version = (2,) version = (3,)
if detect_macos_sdot_bug(): if detect_macos_sdot_bug():
if detect_macos_sdot_bug.fix_works: if detect_macos_sdot_bug.fix_works:
# Version with fix # Version with fix
......
/** C Implementation (with NumPy back-end) of BLAS functions used in Theano.
* Used instead of BLAS when Theano flag ``blas.ldflags`` is empty.
* This file contains some useful header code not templated.
* File alt_blas_template.c contains template code for [sd]gemm_ and [sd]gemv_. **/
/* Print the pending Python error (if any) and the given message, then abort
 * the process. Used because these BLAS replacement routines have no way to
 * report an error through a return value. */
#define alt_fatal_error(message) { if (PyErr_Occurred()) PyErr_Print(); if(message != NULL) fprintf(stderr, message); exit(-1); }
/* Map a BLAS TRANS character pointer to a boolean: anything other than
 * 'N'/'n' (e.g. 'T', 't', 'C', 'c') requests the transposed operand. */
#define alt_trans_to_bool(trans) (*trans != 'N' && *trans != 'n')
/**Template code for BLAS functions follows in file alt_blas_template.c
* (as Python string to be used with old formatting).
* PARAMETERS:
* float_type: "float" or "double".
* float_size: 4 for float32 (sgemm_), 8 for float64 (dgemm_).
* npy_float: "NPY_FLOAT32" or "NPY_FLOAT64".
* precision: "s" for single, "d" for double.
* See blas_headers.py for current use.**/
/** %(name)s **/ /** Alternative template NumPy-based implementation of BLAS functions used in Theano. **/
/* Scalar*Matrix function. /* Scalar * Matrix function.
* Computes: matrix = scalar*matrix. */ * Computes: matrix = scalar * matrix. */
void alt_numpy_scale_matrix_inplace_%(float_type)s(const %(float_type)s* scalar, PyArrayObject* matrix) { void alt_numpy_scale_matrix_inplace_%(float_type)s(const %(float_type)s* scalar, PyArrayObject* matrix) {
NpyIter* iterator = NpyIter_New(matrix, NpyIter* iterator = NpyIter_New(matrix,
NPY_ITER_READWRITE | NPY_ITER_EXTERNAL_LOOP | NPY_ITER_REFS_OK, NPY_ITER_READWRITE | NPY_ITER_EXTERNAL_LOOP | NPY_ITER_REFS_OK,
...@@ -25,7 +25,8 @@ void alt_numpy_scale_matrix_inplace_%(float_type)s(const %(float_type)s* scalar, ...@@ -25,7 +25,8 @@ void alt_numpy_scale_matrix_inplace_%(float_type)s(const %(float_type)s* scalar,
} while(get_next(iterator)); } while(get_next(iterator));
NpyIter_Deallocate(iterator); NpyIter_Deallocate(iterator);
} }
/* Matrix+Matrix function.
/* Matrix + Matrix function.
* Computes: matrix2 = (scalar1 * matrix1) + (scalar2 * matrix2) */ * Computes: matrix2 = (scalar1 * matrix1) + (scalar2 * matrix2) */
void alt_numpy_matrix_extended_sum_inplace_%(float_type)s( void alt_numpy_matrix_extended_sum_inplace_%(float_type)s(
const %(float_type)s* scalar1, PyArrayObject* matrix1, const %(float_type)s* scalar1, PyArrayObject* matrix1,
...@@ -48,11 +49,12 @@ void alt_numpy_matrix_extended_sum_inplace_%(float_type)s( ...@@ -48,11 +49,12 @@ void alt_numpy_matrix_extended_sum_inplace_%(float_type)s(
} while(get_next(iterators)); } while(get_next(iterators));
NpyIter_Deallocate(iterators); NpyIter_Deallocate(iterators);
} }
/* NumPy Wrapping function. Wraps a data into a NumPy's PyArrayObject. /* NumPy Wrapping function. Wraps a data into a NumPy's PyArrayObject.
* By default, data is considered as Fortran-style array (column by column). * By default, data is considered as Fortran-style array (column by column).
* If to_transpose, data will be considered as C-style array (row by row) * If to_transpose, data will be considered as C-style array (row by row)
* with dimensions reversed. */ * with dimensions reversed. */
PyObject* alt_op_%(float_type)s(int to_transpose, %(float_type)s* M, int nrow, int ncol, int LDM) { PyObject* alt_op_%(float_type)s(int to_transpose, %(float_type)s* M, int nrow, int ncol, int LDM, int numpyFlags) {
npy_intp dims[2]; npy_intp dims[2];
npy_intp strides[2]; npy_intp strides[2];
if(to_transpose) { if(to_transpose) {
...@@ -66,9 +68,10 @@ PyObject* alt_op_%(float_type)s(int to_transpose, %(float_type)s* M, int nrow, i ...@@ -66,9 +68,10 @@ PyObject* alt_op_%(float_type)s(int to_transpose, %(float_type)s* M, int nrow, i
strides[0] = %(float_size)d; strides[0] = %(float_size)d;
strides[1] = LDM * %(float_size)d; strides[1] = LDM * %(float_size)d;
} }
return PyArray_New(&PyArray_Type, 2, dims, %(npy_float)s, strides, M, 0, 0, NULL); return PyArray_New(&PyArray_Type, 2, dims, %(npy_float)s, strides, M, 0, numpyFlags, NULL);
} }
/* Special wrapping case used for matrix C in gemm implementation. */
/* Special wrapping case used for matrix C in gemm_ implementation. */
inline PyObject* alt_wrap_fortran_writeable_matrix_%(float_type)s( inline PyObject* alt_wrap_fortran_writeable_matrix_%(float_type)s(
%(float_type)s* matrix, const int* nrow, const int* ncol, const int* LD %(float_type)s* matrix, const int* nrow, const int* ncol, const int* LD
) { ) {
...@@ -76,15 +79,16 @@ inline PyObject* alt_wrap_fortran_writeable_matrix_%(float_type)s( ...@@ -76,15 +79,16 @@ inline PyObject* alt_wrap_fortran_writeable_matrix_%(float_type)s(
npy_intp strides[2] = {%(float_size)d, (*LD) * %(float_size)d}; npy_intp strides[2] = {%(float_size)d, (*LD) * %(float_size)d};
return PyArray_New(&PyArray_Type, 2, dims, %(npy_float)s, strides, matrix, 0, NPY_ARRAY_WRITEABLE, NULL); return PyArray_New(&PyArray_Type, 2, dims, %(npy_float)s, strides, matrix, 0, NPY_ARRAY_WRITEABLE, NULL);
} }
/* %(name)s template code */
void %(name)s( /* gemm_ template code */
void %(precision)sgemm_(
char* TRANSA, char* TRANSB, const int* M, const int* N, const int* K, char* TRANSA, char* TRANSB, const int* M, const int* N, const int* K,
const %(float_type)s* ALPHA, %(float_type)s* A, const int* LDA, const %(float_type)s* ALPHA, %(float_type)s* A, const int* LDA,
%(float_type)s* B, const int* LDB, const %(float_type)s* BETA, %(float_type)s* B, const int* LDB, const %(float_type)s* BETA,
%(float_type)s* C, const int* LDC %(float_type)s* C, const int* LDC
) { ) {
if(*M < 0 || *N < 0 || *K < 0 || *LDA < 0 || *LDB < 0 || *LDC < 0) if(*M < 0 || *N < 0 || *K < 0 || *LDA < 0 || *LDB < 0 || *LDC < 0)
alt_fatal_error("The integer arguments passed to %(name)s must all be at least 0."); alt_fatal_error("The integer arguments passed to %(precision)sgemm_ must all be at least 0.");
/* If M or N is null, there is nothing to do with C, /* If M or N is null, there is nothing to do with C,
* as C should contain M*N == 0 items. */ * as C should contain M*N == 0 items. */
if(*M == 0 || *N == 0) if(*M == 0 || *N == 0)
...@@ -140,9 +144,11 @@ void %(name)s( ...@@ -140,9 +144,11 @@ void %(name)s(
* consider the buffer as a F-contiguous M*N matrix, so that * consider the buffer as a F-contiguous M*N matrix, so that
* it will get the transposed of op_B_transposed * op_A_transposed, * it will get the transposed of op_B_transposed * op_A_transposed,
* that is op_A * op_B (M*N matrix) as expected. */ * that is op_A * op_B (M*N matrix) as expected. */
PyObject* opA_transposed = alt_op_%(float_type)s(!to_transpose_A, A, nrowa, ncola, *LDA); PyObject* opA_transposed = alt_op_%(float_type)s(!to_transpose_A, A, nrowa, ncola, *LDA, 0);
PyObject* opB_transposed = alt_op_%(float_type)s(!to_transpose_B, B, nrowb, ncolb, *LDB); PyObject* opB_transposed = alt_op_%(float_type)s(!to_transpose_B, B, nrowb, ncolb, *LDB, 0);
PyObject* opB_trans_dot_opA_trans = PyArray_New(&PyArray_Type, 2, computation_dims, %(npy_float)s, computation_strides, computation_pointer, 0, computation_flags, NULL); PyObject* opB_trans_dot_opA_trans = PyArray_New(&PyArray_Type, 2, computation_dims, %(npy_float)s,
computation_strides, computation_pointer, 0,
computation_flags, NULL);
PyArray_MatrixProduct2(opB_transposed, opA_transposed, (PyArrayObject*)opB_trans_dot_opA_trans); PyArray_MatrixProduct2(opB_transposed, opA_transposed, (PyArrayObject*)opB_trans_dot_opA_trans);
/* PyArray_MatrixProduct2 adds a reference to the output array, /* PyArray_MatrixProduct2 adds a reference to the output array,
* which we need to remove to avoid a memory leak. */ * which we need to remove to avoid a memory leak. */
...@@ -156,7 +162,7 @@ void %(name)s( ...@@ -156,7 +162,7 @@ void %(name)s(
PyObject* matrix_C = alt_wrap_fortran_writeable_matrix_%(float_type)s(C, M, N, LDC); PyObject* matrix_C = alt_wrap_fortran_writeable_matrix_%(float_type)s(C, M, N, LDC);
PyObject* alpha_opA_dot_opB = PyArray_Transpose((PyArrayObject*)opB_trans_dot_opA_trans, NULL); PyObject* alpha_opA_dot_opB = PyArray_Transpose((PyArrayObject*)opB_trans_dot_opA_trans, NULL);
if(0 != PyArray_CopyInto((PyArrayObject*)matrix_C, (PyArrayObject*)alpha_opA_dot_opB)) if(0 != PyArray_CopyInto((PyArrayObject*)matrix_C, (PyArrayObject*)alpha_opA_dot_opB))
alt_fatal_error("NumPy %(name)s implementation: unable to copy ALPHA*op(A)*op(B) into C when BETA == 0."); alt_fatal_error("NumPy %(precision)sgemm_ implementation: unable to copy ALPHA*op(A)*op(B) into C when BETA == 0.");
Py_XDECREF(alpha_opA_dot_opB); Py_XDECREF(alpha_opA_dot_opB);
Py_XDECREF(matrix_C); Py_XDECREF(matrix_C);
} }
...@@ -164,7 +170,8 @@ void %(name)s( ...@@ -164,7 +170,8 @@ void %(name)s(
/* C is read, so we must consider it as Fortran-style matrix. */ /* C is read, so we must consider it as Fortran-style matrix. */
PyObject* matrix_C = alt_wrap_fortran_writeable_matrix_%(float_type)s(C, M, N, LDC); PyObject* matrix_C = alt_wrap_fortran_writeable_matrix_%(float_type)s(C, M, N, LDC);
PyObject* opA_dot_opB = PyArray_Transpose((PyArrayObject*)opB_trans_dot_opA_trans, NULL); PyObject* opA_dot_opB = PyArray_Transpose((PyArrayObject*)opB_trans_dot_opA_trans, NULL);
alt_numpy_matrix_extended_sum_inplace_%(float_type)s(ALPHA, (PyArrayObject*)opA_dot_opB, BETA, (PyArrayObject*)matrix_C); alt_numpy_matrix_extended_sum_inplace_%(float_type)s(ALPHA, (PyArrayObject*)opA_dot_opB,
BETA, (PyArrayObject*)matrix_C);
Py_XDECREF(opA_dot_opB); Py_XDECREF(opA_dot_opB);
Py_XDECREF(matrix_C); Py_XDECREF(matrix_C);
} }
...@@ -172,3 +179,81 @@ void %(name)s( ...@@ -172,3 +179,81 @@ void %(name)s(
Py_XDECREF(opB_transposed); Py_XDECREF(opB_transposed);
Py_XDECREF(opA_transposed); Py_XDECREF(opA_transposed);
} }
/* gemv_: NumPy-backed fallback implementation of BLAS [sd]gemv_.
 *
 * If TRANS is 'n' or 'N', computes:  y = ALPHA * A * x + BETA * y
 * otherwise:                         y = ALPHA * A.T * x + BETA * y
 *
 * A is an M*N matrix stored column-major with leading dimension LDA;
 * x and y are strided vectors (increments incx, incy);
 * ALPHA and BETA are scalars.
 * Aborts the process via alt_fatal_error on invalid arguments or NumPy
 * failure (this interface has no error return channel). */
void %(precision)sgemv_(
    char* TRANS,
    const int* M,
    const int* N,
    const %(float_type)s* ALPHA,
    %(float_type)s* A,
    const int* LDA,
    %(float_type)s* x,
    const int* incx,
    const %(float_type)s* BETA,
    %(float_type)s* y,
    const int* incy
) {
    /* Reject negative sizes/strides up front (mirrors the gemm_ check).
     * Fix: the original used bitwise `|` between two of the comparisons;
     * it happened to yield the same truth value (operands are 0/1) but
     * logical `||` is what is meant. */
    if (*M < 0 || *N < 0 || *LDA < 0 || *incx < 0 || *incy < 0)
        alt_fatal_error("The integer arguments passed to %(precision)sgemv_ must all be at least 0.");
    int transpose = alt_trans_to_bool(TRANS);
    if (*M == 0 || *N == 0) {
        /* A contains M * N == 0 values. y should be empty too, and we have
         * nothing to do.
         * NOTE(review): reference BLAS performs a quick return here even when
         * the output vector is non-empty; this implementation instead treats
         * a non-empty output as fatal — confirm this stricter contract is
         * intended. */
        if ((transpose && *N != 0) || (!transpose && *M != 0))
            alt_fatal_error("NumPy %(precision)sgemv_ implementation: the output vector should be empty.");
        return;
    }
    /* Wrap the raw BLAS operands as NumPy arrays (no copy). */
    PyObject* matrixA = alt_op_%(float_type)s(transpose, A, *M, *N, *LDA, 0);
    PyObject* matrixX = NULL;
    PyObject* matrixY = NULL;
    if (transpose) {
        /* op(A) is N*M: x has length M, y has length N. */
        matrixX = alt_op_%(float_type)s(1, x, 1, *M, *incx, 0);
        matrixY = alt_op_%(float_type)s(1, y, 1, *N, *incy, NPY_ARRAY_WRITEABLE);
    } else {
        /* op(A) is M*N: x has length N, y has length M. */
        matrixX = alt_op_%(float_type)s(1, x, 1, *N, *incx, 0);
        matrixY = alt_op_%(float_type)s(1, y, 1, *M, *incy, NPY_ARRAY_WRITEABLE);
    }
    if (*ALPHA == 0) {
        /* The product term vanishes: y = BETA * y. */
        alt_numpy_scale_matrix_inplace_%(float_type)s(BETA, (PyArrayObject*)matrixY);
    } else if (*BETA == 0) {
        if (PyArray_IS_C_CONTIGUOUS((PyArrayObject*)matrixY)) {
            /* y is contiguous: compute op(A) * x directly into y, then scale
             * in place by ALPHA. */
            PyArray_MatrixProduct2(matrixA, matrixX, (PyArrayObject*)matrixY);
            /* PyArray_MatrixProduct2 adds an extra reference to the output
             * array; drop it now (matrixY still holds its creation
             * reference, released at the end of this function). */
            Py_XDECREF(matrixY);
            alt_numpy_scale_matrix_inplace_%(float_type)s(ALPHA, (PyArrayObject*)matrixY);
        } else {
            /* y is strided: compute into a temporary workspace, scale it,
             * then copy the result back into y. */
            PyObject* tempAX = PyArray_MatrixProduct(matrixA, matrixX);
            if (tempAX == NULL)
                alt_fatal_error("NumPy %(precision)sgemv_ implementation: Unable to get matrix product.");
            alt_numpy_scale_matrix_inplace_%(float_type)s(ALPHA, (PyArrayObject*)tempAX);
            if(0 != PyArray_CopyInto((PyArrayObject*)matrixY, (PyArrayObject*)tempAX)) {
                alt_fatal_error("NumPy %(precision)sgemv_ implementation: unable to update output.");
            }
            Py_XDECREF(tempAX);
        }
    } else {
        /* General case: y = ALPHA * (op(A) * x) + BETA * y. */
        PyObject* tempAX = PyArray_MatrixProduct(matrixA, matrixX);
        if (tempAX == NULL)
            alt_fatal_error("NumPy %(precision)sgemv_ implementation: unable to get matrix product.");
        alt_numpy_matrix_extended_sum_inplace_%(float_type)s(ALPHA, (PyArrayObject*)tempAX,
                                                             BETA, (PyArrayObject*)matrixY);
        Py_XDECREF(tempAX);
    }
    Py_XDECREF(matrixY);
    Py_XDECREF(matrixX);
    Py_XDECREF(matrixA);
}
\ No newline at end of file
/** C Implementation of [sd]gemm_ based on NumPy
* Used instead of blas when Theano config flag blas.ldflags is empty.
* This file contains the common code for [sd]gemm_.
* File alt_gemm_template.c contains template code for [sd]gemm_. **/
/* Print the given message to stderr (if non-NULL) and abort the process.
 * Used because the fallback gemm_ routines have no error return channel. */
#define alt_fatal_error(message) { if(message != NULL) fprintf(stderr, message); exit(-1); }
/* True when the BLAS TRANS argument requests the transposed operand
 * (any character other than 'N'/'n'). */
#define alt_trans_to_bool(trans) (*trans != 'N' && *trans != 'n')
/**Template code for [sd]gemm_ follows in file alt_gemm_template.c
* (as Python string to be used with old formatting).
* PARAMETERS:
* float_type: "float" for sgemm_, "double" for dgemm_.
* float_size: 4 for float32 (sgemm_), 8 for float64 (dgemm_).
* npy_float: "NPY_FLOAT32" for sgemm_, "NPY_FLOAT64" for dgemm_.
* name: "sgemm_" for sgemm_, "dgemm_" for dgemm_.
* See blas_headers.py for current use.**/
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论