提交 a29349fb authored 作者: Frederic's avatar Frederic

make dnn detect if the library can be loaded.

上级 9535b59f
...@@ -1845,7 +1845,7 @@ class GCC_compiler(object): ...@@ -1845,7 +1845,7 @@ class GCC_compiler(object):
return (compilation_ok, run_ok) return (compilation_ok, run_ok)
@staticmethod @staticmethod
def try_flags(flag_list): def try_flags(flag_list, try_run=False, preambule=""):
''' '''
Try to compile a dummy file with these flags. Try to compile a dummy file with these flags.
...@@ -1856,13 +1856,14 @@ class GCC_compiler(object): ...@@ -1856,13 +1856,14 @@ class GCC_compiler(object):
return False return False
code = b(""" code = b("""
%(preambule)s
int main(int argc, char** argv) int main(int argc, char** argv)
{ {
return 0; return 0;
} }
""") """ % locals())
return GCC_compiler.try_compile_tmp(code, tmp_prefix='try_flags_', return GCC_compiler.try_compile_tmp(code, tmp_prefix='try_flags_',
flags=flag_list, try_run=False) flags=flag_list, try_run=try_run)
@staticmethod @staticmethod
def compile_str(module_name, src_code, location=None, def compile_str(module_name, src_code, location=None,
......
import os import os
import theano import theano
from theano import Apply, tensor from theano import Apply, gof, tensor
from theano.gof import Optimizer from theano.gof import Optimizer
from theano.gof.type import CDataType from theano.gof.type import CDataType
from theano.compat import PY3 from theano.compat import PY3
...@@ -26,8 +26,8 @@ def dnn_available(): ...@@ -26,8 +26,8 @@ def dnn_available():
dnn_available.avail = False dnn_available.avail = False
else: else:
dnn_available.msg = "Can not find the cuDNN library" dnn_available.msg = "Can not find the cuDNN library"
dnn_available.avail = theano.gof.cmodule.GCC_compiler.try_flags( dnn_available.avail = all(gof.cmodule.GCC_compiler.try_flags(
["-l", "cudnn"]) ["-l", "cudnn"], try_run=True, preambule="#include <cudnn.h>"))
return dnn_available.avail return dnn_available.avail
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论