提交 1e74be55 authored 作者: Frederic Bastien's avatar Frederic Bastien

put the default value of THEANO_BLAS_LDFLAGS into config.py and changed so that…

put the default value of THEANO_BLAS_LDFLAGS into config.py and changed so that if empty, we will use numpy.dot(no c_code)
上级 f378868a
......@@ -22,7 +22,8 @@ HOME = os.getenv('HOME')
THEANO_CMP_SLOPPY = os.getenv('THEANO_CMP_SLOPPY', 0)
#flag for compiling with an optimized blas library. Used for gemm operation
THEANO_BLAS_LDFLAGS = os.getenv('THEANO_BLAS_LDFLAGS')
#if THEANO_BLAS_LDFLAGS exist but empty, we will use numpy.dot()
THEANO_BLAS_LDFLAGS = os.getenv('THEANO_BLAS_LDFLAGS','-lblas')
#for gpu
CUDA_ROOT = os.getenv('CUDA_ROOT')
......
......@@ -33,22 +33,18 @@ def ldflags(libs=True, flags=False):
Default: ['blas'], but environment variable THEANO_BLAS_LDFLAGS overrides this.
"""
rval = []
if config.THEANO_BLAS_LDFLAGS:
tokens = config.THEANO_BLAS_LDFLAGS.split()
for t in tokens:
try:
t0, t1, t2 = t[0:3]
assert t0 == '-'
except:
raise ValueError('invalid token in THEANO_BLAS_LDFLAGS', t)
if t1 == 'L':
raise ValueError('library dir not allowed in THEANO_BLAS_LDFLAGS', t)
elif libs and t1=='l': # example -lmkl
rval.append(t[2:])
elif flags and t1!='l': # example -openmp
rval.append(t)
elif libs:
rval = ['blas']
for t in config.THEANO_BLAS_LDFLAGS.split():
try:
t0, t1, t2 = t[0:3]
assert t0 == '-'
except:
raise ValueError('invalid token in THEANO_BLAS_LDFLAGS', t)
if t1 == 'L':
raise ValueError('library dir not allowed in THEANO_BLAS_LDFLAGS', t)
elif libs and t1=='l': # example -lmkl
rval.append(t[2:])
elif flags and t1!='l': # example -openmp
rval.append(t)
#print "blas linking against", rval
return rval
......@@ -371,6 +367,8 @@ class Gemm(GemmRelated):
"""
def c_code(self, node, name, (_z, _a, _x, _y, _b), (_zout, ), sub): #DEBUG
if len(self.c_libraries())<=0:
return super(Gemm, self).c_code(node, name, (_z, _a, _x, _y, _b), (_zout, ), sub)
full_code = self.build_gemm_call() % dict(locals(), **sub)
return full_code
......@@ -656,6 +654,8 @@ class Dot22(GemmRelated):
double b = 0.0;
"""
def c_code(self, node, name, (_x, _y), (_z, ), sub): #DEBUG
if len(self.c_libraries())<=0:
return super(Dot22, self).c_code(node, name, (_x, _y), (_z, ), sub)
full_code = self.build_gemm_call() % dict(locals(), **sub)
return full_code
def c_code_cache_version(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论