提交 70896325 authored 作者: notoraptor's avatar notoraptor

New update.

* Removed other checkings of empty blas.ldflags in blas.py and test_blas.py. * Correction in alt_gemm_template.c to prevent some test failures in test_blas.py (occuring after above modification): in some tests, a zero-content matrix is passed as C (with M*N == 0). Now if we encounter this case, we just skip it (gemm not calculated, C not modified). * Correction in blas_headers.py to prevent some old string- formating errors when strings are C code containing "%" symbols. Now test_blas.py runs perfectly. I have to re-run all other tests tonight.
上级 24f96fa8
......@@ -73,7 +73,12 @@ void %(name)s(
const %(float_type)s* ALPHA, %(float_type)s* A, const int* LDA,
%(float_type)s* B, const int* LDB, const %(float_type)s* BETA,
%(float_type)s* C, const int* LDC) {
if(*M < 0 || *N < 0 || *K < 0 || *LDA < 0 || *LDB < 0 || *LDC < 0)
/* NB: it seems that matrix+matrix and scalar*matrix functions
* defined above do not allocate iterator for a matrix with 0
* content, that is a matrix whose nrow*ncol == 0. As these
* functions actually work with M*N matrices (op(A)*op(B) and/or C),
* I think that we could just return if M or N is null. */
if(*M < 1 || *N < 1 || *K < 0 || *LDA < 0 || *LDB < 0 || *LDC < 0)
return;
int nrowa, ncola, nrowb, ncolb;
int is_A_transposable = alt_trans_to_bool(TRANSA);
......
......@@ -1037,10 +1037,7 @@ class Gemm(GemmRelated):
if node.inputs[0].type.dtype.startswith('complex'):
raise utils.MethodNotDefined('%s.c_code'
% self.__class__.__name__)
if not config.blas.ldflags:
return super(Gemm, self).c_code(node, name,
(_z, _a, _x, _y, _b), (_zout, ),
sub)
# if not config.blas.ldflags: # return super(Gemm, self).c_code(node, name, (_z, _a, _x, _y, _b), (_zout, ), sub)
full_code = self.build_gemm_call() % dict(locals(), **sub)
return full_code
......@@ -2154,9 +2151,7 @@ class BatchedDot(Op):
_z, = out
fail = sub["fail"]
if not config.blas.ldflags:
return super(BatchedDot, self).c_code(node, name,
inp, out, sub)
# if not config.blas.ldflags: # return super(BatchedDot, self).c_code(node, name, inp, out, sub)
# generate contiguity condition
def contiguous(var, ndim):
......
......@@ -985,8 +985,7 @@ def blas_header_text():
}
""")
header += gemm_code
return header % {'const': const}
return (header % {'const': const}) + gemm_code
def mkl_threads_text():
......
......@@ -95,7 +95,8 @@ class t_gemm(TestCase):
cmp_linker(copy(z), a, x, y, b, 'c|py')
cmp_linker(copy(z), a, x, y, b, 'py')
if (config.blas.ldflags and not dtype.startswith("complex")
# if (config.blas.ldflags and not dtype.startswith("complex")
if (not dtype.startswith("complex")
and theano.config.cxx):
# If blas.ldflags is equal to '', the C code will not
# be generated
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论