Add a check to a list of subdirectories to include libwarpctc in c_lib_dirs

上级 664ec289
...@@ -1773,8 +1773,8 @@ AddConfigVar( ...@@ -1773,8 +1773,8 @@ AddConfigVar(
AddConfigVar( AddConfigVar(
'ctc.root', 'ctc.root',
'Directory which contains the root of Baidu CTC library. It is assumed \ 'Directory which contains the root of Baidu CTC library. It is assumed \
that the compiled library is inside the build directory, and the header \ that the compiled library is either inside the build, lib or lib64 \
inside the include directory.', subdirectory, and the header inside the include directory.',
StrParam('', allow_override=False)) StrParam('', allow_override=False))
# Check if there are remaining flags provided by the user through THEANO_FLAGS. # Check if there are remaining flags provided by the user through THEANO_FLAGS.
......
...@@ -13,6 +13,7 @@ from theano.tensor.opt import register_canonicalize ...@@ -13,6 +13,7 @@ from theano.tensor.opt import register_canonicalize
from theano.tensor.opt import register_stabilize from theano.tensor.opt import register_stabilize
import os import os
import os.path
from . import pygpu from . import pygpu
ctc_enabled = config.ctc.enabled ctc_enabled = config.ctc.enabled
...@@ -53,9 +54,19 @@ class GpuConnectionistTemporalClassification(gof.COp): ...@@ -53,9 +54,19 @@ class GpuConnectionistTemporalClassification(gof.COp):
def c_lib_dirs(self): def c_lib_dirs(self):
dirs = [] dirs = []
if ctc_enabled: if ctc_enabled:
# We assume here that the compiled library (libwarpctc.so) is available # Find the directory that contains libwarpctc.so
# at the build directory of the CTC root directory. lib_found = False
dirs.append(os.path.join(config.ctc.root, "build")) for lib_dir in ["build", "lib", "lib64"]:
lib_path = os.path.join(config.ctc.root, lib_dir)
if os.path.isdir(lib_path) and os.path.exists(lib_path):
lib_found = os.path.exists(os.path.join(lib_path, "libwarpctc.so"))
if lib_found:
dirs.append(lib_path)
break
if not lib_found:
raise RuntimeError('libwarpctc.so could not be found. ',
'Please check the config.ctc.root variable.')
return dirs return dirs
def c_libraries(self): def c_libraries(self):
......
from __future__ import (division, absolute_import, print_function) from __future__ import (division, absolute_import, print_function)
import os import os
import os.path
import theano.tensor as T import theano.tensor as T
from theano import config from theano import config
from theano import gof from theano import gof
...@@ -53,9 +54,19 @@ class ConnectionistTemporalClassification(gof.COp, gof.OpenMPOp): ...@@ -53,9 +54,19 @@ class ConnectionistTemporalClassification(gof.COp, gof.OpenMPOp):
def c_lib_dirs(self): def c_lib_dirs(self):
dirs = [] dirs = []
if ctc_enabled: if ctc_enabled:
# We assume here that the compiled library (libwarpctc.so) is available # Find the directory that contains libwarpctc.so
# at the build directory of the CTC root directory. lib_found = False
dirs.append(os.path.join(config.ctc.root, "build")) for lib_dir in ["build", "lib", "lib64"]:
lib_path = os.path.join(config.ctc.root, lib_dir)
if os.path.isdir(lib_path) and os.path.exists(lib_path):
lib_found = os.path.exists(os.path.join(lib_path, "libwarpctc.so"))
if lib_found:
dirs.append(lib_path)
break
if not lib_found:
raise RuntimeError('libwarpctc.so could not be found. ',
'Please check the config.ctc.root variable.')
return dirs return dirs
def c_libraries(self): def c_libraries(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论