提交 0cd7aa7b authored 作者: notoraptor's avatar notoraptor

I added an implementation of C-functions "sgemm_" and "dgemm_" that call Numpy…

I added an implementation of the C functions "sgemm_" and "dgemm_" that call NumPy C-API functions to perform the matrix product when BLAS is explicitly disabled (with the Theano flag "blas.ldflags" set to empty). This can be tested with: THEANO_FLAGS=blas.ldflags= nosetests theano/tensor/nnet/tests/test_abstract_conv.py:TestCorrConv2d
上级 96be471e
/**
C Implementation of dgemm_ based on NumPy
Used instead of blas when Theano config flag blas.ldflags is empty.
**/
/* Scale the first `size_to_compute` doubles of `matrix` by `scalar`, in place. */
void alt_double_scalar_matrix_product_in_place(double scalar, double* matrix, int size_to_compute) {
    double* cursor = matrix;
    double* const end = matrix + size_to_compute;
    while (cursor < end) {
        *cursor++ *= scalar;
    }
}
/* Element-wise sum of the first `size_to_compute` doubles: out[i] = A[i] + B[i].
 * `out` may alias A or B (each index is read before it is written). */
void alt_double_matrix_sum_in_place(double* A, double* B, double* out, int size_to_compute) {
    double* const end = out + size_to_compute;
    while (out < end) {
        *out++ = *A++ + *B++;
    }
}
/* dgemm
* NB: See sgemm_ for same assumptions.
* */
void dgemm_(char* TRANSA, char* TRANSB,
const int* M, const int* N, const int* K,
const double* ALPHA, double* A, const int* LDA,
double* B, const int* LDB, const double* BETA,
double* C, const int* LDC) {
if(*M < 0 || *N < 0 || *K < 0 || *LDA < 0 || *LDB < 0 || *LDC < 0)
return;
if(C == NULL)
return;
int ka, kb;
if(*TRANSA == 'N' || *TRANSA == 'n')
ka = *K;
else
ka = *M;
if(*TRANSB == 'N' || *TRANSB == 'n')
kb = *N;
else
kb = *K;
npy_intp dims_A[2] = {*LDA, ka};
npy_intp dims_B[2] = {*LDB, kb};
PyObject* matrix_A = PyArray_SimpleNewFromData(2, dims_A, NPY_FLOAT64, A);
PyObject* matrix_B = PyArray_SimpleNewFromData(2, dims_B, NPY_FLOAT64, B);
PyObject* op_A = alt_op(TRANSA, (PyArrayObject*)matrix_A);
PyObject* op_B = alt_op(TRANSB, (PyArrayObject*)matrix_B);
if(*BETA == 0) {
npy_intp dims_C[2] = {*LDC, *N};
PyObject* matrix_C = PyArray_SimpleNewFromData(2, dims_C, NPY_FLOAT64, C);
alt_matrix_matrix_product2(op_A, op_B, matrix_C);
alt_double_scalar_matrix_product_in_place(*ALPHA, C, (*M) * (*N));
Py_XDECREF(matrix_C);
} else {
PyArrayObject* op_A_times_op_B = (PyArrayObject*)alt_matrix_matrix_product(op_A, op_B);
alt_double_scalar_matrix_product_in_place(*ALPHA, (double*)PyArray_DATA(op_A_times_op_B), (*M) * (*N));
alt_double_scalar_matrix_product_in_place(*BETA, C, (*M) * (*N));
alt_double_matrix_sum_in_place((double*)PyArray_DATA(op_A_times_op_B), C, C, (*M) * (*N));
Py_XDECREF(op_A_times_op_B);
}
if(op_B != matrix_B) Py_XDECREF(op_B);
if(op_A != matrix_A) Py_XDECREF(op_A);
Py_XDECREF(matrix_B);
Py_XDECREF(matrix_A);
}
/**
C Implementation of sgemm_ based on NumPy
Used instead of blas when Theano config flag blas.ldflags is empty.
**/
/* Return a transposed view of `o` via PyArray_Transpose (NULL permutation
 * reverses the axes). Returns a new reference, or NULL on failure. */
inline PyObject* alt_transpose(PyArrayObject* o) {
return PyArray_Transpose(o, NULL);
}
/* Matrix product o1 . o2 into a freshly allocated array.
 * Returns a new reference, or NULL on failure. */
inline PyObject* alt_matrix_matrix_product(PyObject* o1, PyObject* o2) {
return PyArray_MatrixProduct(o1, o2);
}
/* Matrix product o1 . o2 written into the preallocated array `out`.
 * NOTE: PyArray_MatrixProduct2 returns a NEW reference to `out` (or NULL on
 * failure); callers must DECREF the return value to avoid leaking it. */
inline PyObject* alt_matrix_matrix_product2(PyObject* o1, PyObject* o2, PyObject* out) {
return PyArray_MatrixProduct2(o1, o2, (PyArrayObject*)out);
}
/* Scale the first `size_to_compute` floats of `matrix` by `scalar`, in place. */
void alt_scalar_matrix_product_in_place(float scalar, float* matrix, int size_to_compute) {
    float* cursor = matrix;
    float* const end = matrix + size_to_compute;
    while (cursor < end) {
        *cursor++ *= scalar;
    }
}
/* Element-wise sum of the first `size_to_compute` floats: out[i] = A[i] + B[i].
 * `out` may alias A or B (each index is read before it is written). */
void alt_matrix_sum_in_place(float* A, float* B, float* out, int size_to_compute) {
    float* const end = out + size_to_compute;
    while (out < end) {
        *out++ = *A++ + *B++;
    }
}
/* Select the operand for the product: the matrix itself for 'N'/'n'
 * (borrowed reference), otherwise a transposed view (NEW reference).
 * Callers must DECREF the result only when it differs from `matrix`. */
inline PyObject* alt_op(char* trans, PyArrayObject* matrix) {
    if (*trans == 'N' || *trans == 'n') {
        return (PyObject*)matrix;
    }
    return alt_transpose(matrix);
}
/* sgemm
* We assume that none of these 13 pointers passed as arguments are null.
* NB: We can optimize this function again (for example, when alpha == 0 and/or beta == 0).
* */
void sgemm_(char* TRANSA, char* TRANSB,
const int* M, const int* N, const int* K,
const float* ALPHA, float* A, const int* LDA,
float* B, const int* LDB, const float* BETA,
float* C, const int* LDC) {
if(*M < 0 || *N < 0 || *K < 0 || *LDA < 0 || *LDB < 0 || *LDC < 0)
return;
if(C == NULL)
return;
/* Recall:
A is a *LDA by *ka matrix.
B is a *LDB by *kb matrix.
C is a *LDC By *N matrix.
*/
int ka, kb;
if(*TRANSA == 'N' || *TRANSA == 'n')
ka = *K;
else
ka = *M;
if(*TRANSB == 'N' || *TRANSB == 'n')
kb = *N;
else
kb = *K;
npy_intp dims_A[2] = {*LDA, ka};
npy_intp dims_B[2] = {*LDB, kb};
PyObject* matrix_A = PyArray_SimpleNewFromData(2, dims_A, NPY_FLOAT32, A);
PyObject* matrix_B = PyArray_SimpleNewFromData(2, dims_B, NPY_FLOAT32, B);
PyObject* op_A = alt_op(TRANSA, (PyArrayObject*)matrix_A);
PyObject* op_B = alt_op(TRANSB, (PyArrayObject*)matrix_B);
if(*BETA == 0) {
npy_intp dims_C[2] = {*LDC, *N};
PyObject* matrix_C = PyArray_SimpleNewFromData(2, dims_C, NPY_FLOAT32, C);
alt_matrix_matrix_product2(op_A, op_B, matrix_C);
alt_scalar_matrix_product_in_place(*ALPHA, C, (*M) * (*N));
Py_XDECREF(matrix_C);
} else {
PyArrayObject* op_A_times_op_B = (PyArrayObject*)alt_matrix_matrix_product(op_A, op_B);
alt_scalar_matrix_product_in_place(*ALPHA, (float*)PyArray_DATA(op_A_times_op_B), (*M) * (*N));
alt_scalar_matrix_product_in_place(*BETA, C, (*M) * (*N));
alt_matrix_sum_in_place((float*)PyArray_DATA(op_A_times_op_B), C, C, (*M) * (*N));
Py_XDECREF(op_A_times_op_B);
}
if(op_B != matrix_B) Py_XDECREF(op_B);
if(op_A != matrix_A) Py_XDECREF(op_A);
Py_XDECREF(matrix_B);
Py_XDECREF(matrix_A);
}
...@@ -729,6 +729,29 @@ def cblas_header_text(): ...@@ -729,6 +729,29 @@ def cblas_header_text():
def blas_header_text(): def blas_header_text():
"""C header for the fortran blas interface""" """C header for the fortran blas interface"""
gemm_code = ""
const = "const"
if not config.blas.ldflags:
# Include the Numpy version implementation of sgemm_ and dgemm_ from alt_sgemm.c and alt_dgemm.c
from os.path import dirname, normpath
current_filedir = dirname(__file__)
sgemm_filepath = normpath(current_filedir + "/alt_sgemm.c")
dgemm_filepath = normpath(current_filedir + "/alt_dgemm.c")
sgemm_code = ""
dgemm_code = ""
with open(sgemm_filepath) as code:
sgemm_code = code.read()
with open(dgemm_filepath) as code:
dgemm_code = code.read()
if not sgemm_code or not dgemm_code:
_logger.warning("Unable to load Numpy implementation of gemm code from C source files.")
else:
const = ""
# _logger.warning("Numpy implementation of gemm code loaded (config.blas.ldflags is empty)")
gemm_code += sgemm_code
gemm_code += dgemm_code
header = """ header = """
extern "C" extern "C"
{ {
...@@ -890,7 +913,7 @@ def blas_header_text(): ...@@ -890,7 +913,7 @@ def blas_header_text():
/* Single Precision */ /* Single Precision */
void sgemm_(char*, char*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); void sgemm_(char*, char*, const int*, const int*, const int*, const float *, %(const)s float *, const int*, %(const)s float *, const int*, const float *, float *, const int*);
void ssymm_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); void ssymm_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssyrk_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, float *, const int*); void ssyrk_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, float *, const int*);
void ssyr2k_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); void ssyr2k_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
...@@ -899,7 +922,7 @@ def blas_header_text(): ...@@ -899,7 +922,7 @@ def blas_header_text():
/* Double Precision */ /* Double Precision */
void dgemm_(char*, char*, const int*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*); void dgemm_(char*, char*, const int*, const int*, const int*, const double *, %(const)s double *, const int*, %(const)s double *, const int*, const double *, double *, const int*);
void dsymm_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*); void dsymm_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dsyrk_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, double *, const int*); void dsyrk_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, double *, const int*);
void dsyr2k_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*); void dsyr2k_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
...@@ -958,7 +981,9 @@ def blas_header_text(): ...@@ -958,7 +981,9 @@ def blas_header_text():
} }
""") """)
return header
header += gemm_code
return header % {'const':const}
def mkl_threads_text(): def mkl_threads_text():
......
...@@ -63,7 +63,8 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -63,7 +63,8 @@ class BaseCorrMM(gof.OpenMPOp):
self.filter_dilation = tuple(filter_dilation) self.filter_dilation = tuple(filter_dilation)
if not theano.config.blas.ldflags: if not theano.config.blas.ldflags:
raise NotImplementedError("C code for corrMM* classes need a blas library.") # raise NotImplementedError("C code for corrMM* classes need a blas library.")
self.blas_type = ''
else: else:
if 'openblas' in theano.config.blas.ldflags: if 'openblas' in theano.config.blas.ldflags:
self.blas_type = 'openblas' self.blas_type = 'openblas'
......
...@@ -72,7 +72,7 @@ compile.optdb.register('local_inplace_sparse_block_outer', ...@@ -72,7 +72,7 @@ compile.optdb.register('local_inplace_sparse_block_outer',
# Conv opts # Conv opts
@local_optimizer([AbstractConv2d]) @local_optimizer([AbstractConv2d])
def local_abstractconv_gemm(node): def local_abstractconv_gemm(node):
if theano.config.cxx == "" or not theano.config.blas.ldflags: if theano.config.cxx == "": # or not theano.config.blas.ldflags:
return return
if not isinstance(node.op, AbstractConv2d): if not isinstance(node.op, AbstractConv2d):
return None return None
...@@ -116,7 +116,7 @@ def local_abstractconv3d_gemm(node): ...@@ -116,7 +116,7 @@ def local_abstractconv3d_gemm(node):
@local_optimizer([AbstractConv2d_gradWeights]) @local_optimizer([AbstractConv2d_gradWeights])
def local_abstractconv_gradweight_gemm(node): def local_abstractconv_gradweight_gemm(node):
if theano.config.cxx == "" or not theano.config.blas.ldflags: if theano.config.cxx == "" : # or not theano.config.blas.ldflags:
return return
if not isinstance(node.op, AbstractConv2d_gradWeights): if not isinstance(node.op, AbstractConv2d_gradWeights):
return None return None
...@@ -166,7 +166,7 @@ def local_abstractconv3d_gradweight_gemm(node): ...@@ -166,7 +166,7 @@ def local_abstractconv3d_gradweight_gemm(node):
@local_optimizer([AbstractConv2d_gradInputs]) @local_optimizer([AbstractConv2d_gradInputs])
def local_abstractconv_gradinputs_gemm(node): def local_abstractconv_gradinputs_gemm(node):
if theano.config.cxx == "" or not theano.config.blas.ldflags: if theano.config.cxx == "" : # or not theano.config.blas.ldflags:
return return
if not isinstance(node.op, AbstractConv2d_gradInputs): if not isinstance(node.op, AbstractConv2d_gradInputs):
return None return None
...@@ -603,6 +603,5 @@ def local_abstractconv_check(node): ...@@ -603,6 +603,5 @@ def local_abstractconv_check(node):
node.op.__class__.__name__) node.op.__class__.__name__)
optdb.register('AbstractConvCheck', optdb.register('AbstractConvCheck',
opt.in2out(local_abstractconv_check, opt.in2out(local_abstractconv_check, name="AbstractConvCheck"),
name="AbstractConvCheck"),
48.7, 'fast_compile', 'fast_run') 48.7, 'fast_compile', 'fast_run')
...@@ -414,14 +414,13 @@ class BaseTestConv2d(BaseTestConv): ...@@ -414,14 +414,13 @@ class BaseTestConv2d(BaseTestConv):
class TestCorrConv2d(BaseTestConv2d): class TestCorrConv2d(BaseTestConv2d):
@classmethod @classmethod
def setup_class(cls): def setup_class(cls):
if theano.config.blas.ldflags == "": # if theano.config.blas.ldflags == "": raise SkipTest()
raise SkipTest()
BaseTestConv2d.setup_class() BaseTestConv2d.setup_class()
def tcase(self, i, f, s, b, flip, provide_shape, fd=(1, 1)): def tcase(self, i, f, s, b, flip, provide_shape, fd=(1, 1)):
o = self.get_output_shape(i, f, s, b, fd) o = self.get_output_shape(i, f, s, b, fd)
if (not theano.config.blas.ldflags or #if (not theano.config.blas.ldflags or
not theano.config.cxx or if (not theano.config.cxx or
theano.config.mode == "FAST_COMPILE"): theano.config.mode == "FAST_COMPILE"):
raise SkipTest("Need blas to test conv2d") raise SkipTest("Need blas to test conv2d")
self.run_fwd(inputs_shape=i, filters_shape=f, subsample=s, self.run_fwd(inputs_shape=i, filters_shape=f, subsample=s,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论