提交 340c577a authored 作者: notoraptor's avatar notoraptor

Add fallback implementation for BLAS [sd]gemv_.

上级 d66cefe7
...@@ -731,29 +731,28 @@ def cblas_header_text(): ...@@ -731,29 +731,28 @@ def cblas_header_text():
def blas_header_text(): def blas_header_text():
"""C header for the fortran blas interface""" """C header for the fortran blas interface"""
gemm_code = "" blas_code = ""
const = "const" const = "const"
if not config.blas.ldflags: if not config.blas.ldflags:
# Include the Numpy version implementation of [sd]gemm_. # Include the Numpy version implementation of [sd]gemm_.
current_filedir = dirname(__file__) current_filedir = dirname(__file__)
gemm_common_filepath = os.path.join(current_filedir, 'c_code', 'alt_gemm_common.c') blas_common_filepath = os.path.join(current_filedir, 'c_code', 'alt_blas_common.h')
gemm_template_filepath = os.path.join(current_filedir, 'c_code', 'alt_gemm_template.c') blas_template_filepath = os.path.join(current_filedir, 'c_code', 'alt_blas_template.c')
common_code = "" common_code = ""
sgemm_code = "" sblas_code = ""
dgemm_code = "" dblas_code = ""
with open(gemm_common_filepath) as code: with open(blas_common_filepath) as code:
common_code = code.read() common_code = code.read()
with open(gemm_template_filepath) as code: with open(blas_template_filepath) as code:
template_code = code.read() template_code = code.read()
sgemm_code = template_code % {"float_type": "float", "float_size": 4, "npy_float": "NPY_FLOAT32", "name": "sgemm_"} sblas_code = template_code % {"float_type": "float", "float_size": 4, "npy_float": "NPY_FLOAT32", "precision": "s"}
dgemm_code = template_code % {"float_type": "double", "float_size": 8, "npy_float": "NPY_FLOAT64", "name": "dgemm_"} dblas_code = template_code % {"float_type": "double", "float_size": 8, "npy_float": "NPY_FLOAT64", "precision": "d"}
if not common_code or not sgemm_code: if not common_code or not template_code:
raise IOError("Unable to load NumPy implementation of gemm code from C source files.") raise IOError("Unable to load NumPy implementation of BLAS functions from C source files.")
else: const = ""
const = "" blas_code += common_code
gemm_code += common_code blas_code += sblas_code
gemm_code += sgemm_code blas_code += dblas_code
gemm_code += dgemm_code
header = """ header = """
extern "C" extern "C"
...@@ -834,7 +833,7 @@ def blas_header_text(): ...@@ -834,7 +833,7 @@ def blas_header_text():
/* Single Precision */ /* Single Precision */
void sgemv_(char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); void sgemv_(char*, const int*, const int*, const float *, %(const)s float *, const int*, %(const)s float *, const int*, const float *, float *, const int*);
void sgbmv_(char*, const int*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); void sgbmv_(char*, const int*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssymv_(char*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); void ssymv_(char*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssbmv_(char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); void ssbmv_(char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
...@@ -984,7 +983,7 @@ def blas_header_text(): ...@@ -984,7 +983,7 @@ def blas_header_text():
} }
""") """)
return (header % {'const': const}) + gemm_code return (header % {'const': const}) + blas_code
def mkl_threads_text(): def mkl_threads_text():
...@@ -1032,7 +1031,7 @@ def openblas_threads_text(): ...@@ -1032,7 +1031,7 @@ def openblas_threads_text():
def blas_header_version(): def blas_header_version():
# Version for the base header # Version for the base header
version = (2,) version = (3,)
if detect_macos_sdot_bug(): if detect_macos_sdot_bug():
if detect_macos_sdot_bug.fix_works: if detect_macos_sdot_bug.fix_works:
# Version with fix # Version with fix
......
/** C Implementation (with NumPy back-end) of BLAS functions used in Theano.
* Used instead of BLAS when Theano flag ``blas.ldflags`` is empty.
* This file contains some useful header code not templated.
* File alt_blas_template.c contains template code for [sd]gemm_ and [sd]gemv_. **/
#define alt_fatal_error(message) { if (PyErr_Occurred()) PyErr_Print(); if(message != NULL) fprintf(stderr, message); exit(-1); }
#define alt_trans_to_bool(trans) (*trans != 'N' && *trans != 'n')
/**Template code for BLAS functions follows in file alt_blas_template.c
* (as Python string to be used with old formatting).
* PARAMETERS:
* float_type: "float" or "double".
* float_size: 4 for float32 (sgemm_), 8 for float64 (dgemm_).
* npy_float: "NPY_FLOAT32" or "NPY_FLOAT64".
* precision: "s" for single, "d" for double.
* See blas_headers.py for current use.**/
/** C Implementation of [sd]gemm_ based on NumPy
* Used instead of blas when Theano config flag blas.ldflags is empty.
* This file contains the common code for [sd]gemm_.
* File alt_gemm_template.c contains template code for [sd]gemm_. **/
#define alt_fatal_error(message) { if(message != NULL) fprintf(stderr, message); exit(-1); }
#define alt_trans_to_bool(trans) (*trans != 'N' && *trans != 'n')
/**Template code for [sd]gemm_ follows in file alt_gemm_template.c
* (as Python string to be used with old formatting).
* PARAMETERS:
* float_type: "float" for sgemm_, "double" for dgemm_.
* float_size: 4 for float32 (sgemm_), 8 for float64 (dgemm_).
* npy_float: "NPY_FLOAT32" for sgemm_, "NPY_FLOAT64" for dgemm_.
* name: "sgemm_" for sgemm_, "dgemm_" for dgemm_.
* See blas_headers.py for current use.**/
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论