Add a check to a list of subdirectories to include libwarpctc in c_lib_dirs

上级 664ec289
......@@ -1773,8 +1773,8 @@ AddConfigVar(
AddConfigVar(
'ctc.root',
'Directory which contains the root of Baidu CTC library. It is assumed \
that the compiled library is inside the build directory, and the header \
inside the include directory.',
that the compiled library is either inside the build, lib or lib64 \
subdirectory, and the header inside the include directory.',
StrParam('', allow_override=False))
# Check if there are remaining flags provided by the user through THEANO_FLAGS.
......
......@@ -13,6 +13,7 @@ from theano.tensor.opt import register_canonicalize
from theano.tensor.opt import register_stabilize
import os
import os.path
from . import pygpu
ctc_enabled = config.ctc.enabled
......@@ -53,9 +54,19 @@ class GpuConnectionistTemporalClassification(gof.COp):
def c_lib_dirs(self):
dirs = []
if ctc_enabled:
# We assume here that the compiled library (libwarpctc.so) is available
# at the build directory of the CTC root directory.
dirs.append(os.path.join(config.ctc.root, "build"))
# Find the directory that contains libwarpctc.so
lib_found = False
for lib_dir in ["build", "lib", "lib64"]:
lib_path = os.path.join(config.ctc.root, lib_dir)
if os.path.isdir(lib_path) and os.path.exists(lib_path):
lib_found = os.path.exists(os.path.join(lib_path, "libwarpctc.so"))
if lib_found:
dirs.append(lib_path)
break
if not lib_found:
raise RuntimeError('libwarpctc.so could not be found. ',
'Please check the config.ctc.root variable.')
return dirs
def c_libraries(self):
......
from __future__ import (division, absolute_import, print_function)
import os
import os.path
import theano.tensor as T
from theano import config
from theano import gof
......@@ -53,9 +54,19 @@ class ConnectionistTemporalClassification(gof.COp, gof.OpenMPOp):
def c_lib_dirs(self):
dirs = []
if ctc_enabled:
# We assume here that the compiled library (libwarpctc.so) is available
# at the build directory of the CTC root directory.
dirs.append(os.path.join(config.ctc.root, "build"))
# Find the directory that contains libwarpctc.so
lib_found = False
for lib_dir in ["build", "lib", "lib64"]:
lib_path = os.path.join(config.ctc.root, lib_dir)
if os.path.isdir(lib_path) and os.path.exists(lib_path):
lib_found = os.path.exists(os.path.join(lib_path, "libwarpctc.so"))
if lib_found:
dirs.append(lib_path)
break
if not lib_found:
raise RuntimeError('libwarpctc.so could not be found. ',
'Please check the config.ctc.root variable.')
return dirs
def c_libraries(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论