Fix ctc_enabled detection and correct headers and library paths of CTC Ops

上级 ecabd259
......@@ -50,15 +50,18 @@ class GpuConnectionistTemporalClassification(gof.COp):
gof.COp.__init__(self, self.func_file, self.func_name)
def c_lib_dirs(self):
assert ctc_available.path is not None
return [ctc_available.path]
lib_dirs = []
if ctc_available.path is not None:
lib_dirs += [ctc_available.path]
return lib_dirs
def c_libraries(self):
return ["warpctc", "gpuarray"]
def c_header_dirs(self):
dirs = [os.path.dirname(__file__), pygpu.get_include()]
dirs.append(os.path.join(config.ctc.root, "include"))
if config.ctc.root != '':
dirs.append(os.path.join(config.ctc.root, "include"))
return dirs
def c_headers(self):
......
......@@ -39,9 +39,9 @@ options.loc = CTC_CPU;
options.num_threads = 1;
"""
params = ["-I%s" % (os.path.join(config.ctc.root, "include"))]
params.extend(['-I%s' % (os.path.dirname(__file__))])
params = ['-I%s' % (os.path.dirname(__file__))]
if ctc_lib_path is not None:
params.extend(["-I%s" % (os.path.join(config.ctc.root, "include"))])
params.extend(["-L%s" % (ctc_lib_path)])
params.extend(["-l", "warpctc"])
compiler_res = GCC_compiler.try_flags(
......@@ -70,12 +70,7 @@ ctc_present.path = None
def ctc_available():
if config.ctc.root == '':
ctc_available.msg = 'ctc.root variable is not set, please set it ',
'to the root directory of the CTC library in ',
'your system.'
return False
elif os.name == 'nt':
if os.name == 'nt':
ctc_available.msg = 'Windows platforms are currently not supported ',
'by underlying CTC library (warp-ctc).'
return False
......@@ -127,16 +122,21 @@ class ConnectionistTemporalClassification(gof.COp, gof.OpenMPOp):
self.default_output = 0
def c_lib_dirs(self):
assert ctc_available.path is not None
return [ctc_available.path]
lib_dirs = []
if ctc_available.path is not None:
lib_dirs += [ctc_available.path]
return lib_dirs
def c_libraries(self):
return ["warpctc"]
def c_header_dirs(self):
# We assume here that the header is available at the include directory
# of the CTC root directory.
return [os.path.join(config.ctc.root, "include")]
header_dirs = []
if config.ctc.root != '':
# We assume here that the header is available at the include directory
# of the CTC root directory.
header_dirs += [os.path.join(config.ctc.root, "include")]
return header_dirs
def c_compile_args(self):
return gof.OpenMPOp.c_compile_args(self)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论