提交 340c577a authored 作者: notoraptor's avatar notoraptor

Add fallback implementation for BLAS [sd]gemv_.

上级 d66cefe7
......@@ -731,29 +731,28 @@ def cblas_header_text():
def blas_header_text():
"""C header for the fortran blas interface"""
gemm_code = ""
blas_code = ""
const = "const"
if not config.blas.ldflags:
# Include the Numpy version implementation of [sd]gemm_.
current_filedir = dirname(__file__)
gemm_common_filepath = os.path.join(current_filedir, 'c_code', 'alt_gemm_common.c')
gemm_template_filepath = os.path.join(current_filedir, 'c_code', 'alt_gemm_template.c')
blas_common_filepath = os.path.join(current_filedir, 'c_code', 'alt_blas_common.h')
blas_template_filepath = os.path.join(current_filedir, 'c_code', 'alt_blas_template.c')
common_code = ""
sgemm_code = ""
dgemm_code = ""
with open(gemm_common_filepath) as code:
sblas_code = ""
dblas_code = ""
with open(blas_common_filepath) as code:
common_code = code.read()
with open(gemm_template_filepath) as code:
with open(blas_template_filepath) as code:
template_code = code.read()
sgemm_code = template_code % {"float_type": "float", "float_size": 4, "npy_float": "NPY_FLOAT32", "name": "sgemm_"}
dgemm_code = template_code % {"float_type": "double", "float_size": 8, "npy_float": "NPY_FLOAT64", "name": "dgemm_"}
if not common_code or not sgemm_code:
raise IOError("Unable to load NumPy implementation of gemm code from C source files.")
else:
const = ""
gemm_code += common_code
gemm_code += sgemm_code
gemm_code += dgemm_code
sblas_code = template_code % {"float_type": "float", "float_size": 4, "npy_float": "NPY_FLOAT32", "precision": "s"}
dblas_code = template_code % {"float_type": "double", "float_size": 8, "npy_float": "NPY_FLOAT64", "precision": "d"}
if not common_code or not template_code:
raise IOError("Unable to load NumPy implementation of BLAS functions from C source files.")
const = ""
blas_code += common_code
blas_code += sblas_code
blas_code += dblas_code
header = """
extern "C"
......@@ -834,7 +833,7 @@ def blas_header_text():
/* Single Precision */
void sgemv_(char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void sgemv_(char*, const int*, const int*, const float *, %(const)s float *, const int*, %(const)s float *, const int*, const float *, float *, const int*);
void sgbmv_(char*, const int*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssymv_(char*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssbmv_(char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
......@@ -984,7 +983,7 @@ def blas_header_text():
}
""")
return (header % {'const': const}) + gemm_code
return (header % {'const': const}) + blas_code
def mkl_threads_text():
......@@ -1032,7 +1031,7 @@ def openblas_threads_text():
def blas_header_version():
# Version for the base header
version = (2,)
version = (3,)
if detect_macos_sdot_bug():
if detect_macos_sdot_bug.fix_works:
# Version with fix
......
/** C Implementation (with NumPy back-end) of BLAS functions used in Theano.
* Used instead of BLAS when Theano flag ``blas.ldflags`` is empty.
* This file contains some useful header code not templated.
* File alt_blas_template.c contains template code for [sd]gemm_ and [sd]gemv_. **/
#define alt_fatal_error(message) { if (PyErr_Occurred()) PyErr_Print(); if(message != NULL) fprintf(stderr, message); exit(-1); }
#define alt_trans_to_bool(trans) (*trans != 'N' && *trans != 'n')
/**Template code for BLAS functions follows in file alt_blas_template.c
* (as Python string to be used with old formatting).
* PARAMETERS:
* float_type: "float" or "double".
* float_size: 4 for float32 (sgemm_), 8 for float64 (dgemm_).
* npy_float: "NPY_FLOAT32" or "NPY_FLOAT64".
* precision: "s" for single, "d" for double.
* See blas_headers.py for current use.**/
/** C Implementation of [sd]gemm_ based on NumPy
* Used instead of blas when Theano config flag blas.ldflags is empty.
* This file contains the common code for [sd]gemm_.
* File alt_gemm_template.c contains template code for [sd]gemm_. **/
#define alt_fatal_error(message) { if(message != NULL) fprintf(stderr, message); exit(-1); }
#define alt_trans_to_bool(trans) (*trans != 'N' && *trans != 'n')
/**Template code for [sd]gemm_ follows in file alt_gemm_template.c
* (as Python string to be used with old formatting).
* PARAMETERS:
* float_type: "float" for sgemm_, "double" for dgemm_.
* float_size: 4 for float32 (sgemm_), 8 for float64 (dgemm_).
* npy_float: "NPY_FLOAT32" for sgemm_, "NPY_FLOAT64" for dgemm_.
* name: "sgemm_" for sgemm_, "dgemm_" for dgemm_.
* See blas_headers.py for current use.**/
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论