Fix ctc_enabled detection and correct headers and library paths of CTC Ops

上级 ecabd259
...@@ -50,14 +50,17 @@ class GpuConnectionistTemporalClassification(gof.COp): ...@@ -50,14 +50,17 @@ class GpuConnectionistTemporalClassification(gof.COp):
gof.COp.__init__(self, self.func_file, self.func_name) gof.COp.__init__(self, self.func_file, self.func_name)
def c_lib_dirs(self): def c_lib_dirs(self):
assert ctc_available.path is not None lib_dirs = []
return [ctc_available.path] if ctc_available.path is not None:
lib_dirs += [ctc_available.path]
return lib_dirs
def c_libraries(self): def c_libraries(self):
return ["warpctc", "gpuarray"] return ["warpctc", "gpuarray"]
def c_header_dirs(self): def c_header_dirs(self):
dirs = [os.path.dirname(__file__), pygpu.get_include()] dirs = [os.path.dirname(__file__), pygpu.get_include()]
if config.ctc.root != '':
dirs.append(os.path.join(config.ctc.root, "include")) dirs.append(os.path.join(config.ctc.root, "include"))
return dirs return dirs
......
...@@ -39,9 +39,9 @@ options.loc = CTC_CPU; ...@@ -39,9 +39,9 @@ options.loc = CTC_CPU;
options.num_threads = 1; options.num_threads = 1;
""" """
params = ["-I%s" % (os.path.join(config.ctc.root, "include"))] params = ['-I%s' % (os.path.dirname(__file__))]
params.extend(['-I%s' % (os.path.dirname(__file__))])
if ctc_lib_path is not None: if ctc_lib_path is not None:
params.extend(["-I%s" % (os.path.join(config.ctc.root, "include"))])
params.extend(["-L%s" % (ctc_lib_path)]) params.extend(["-L%s" % (ctc_lib_path)])
params.extend(["-l", "warpctc"]) params.extend(["-l", "warpctc"])
compiler_res = GCC_compiler.try_flags( compiler_res = GCC_compiler.try_flags(
...@@ -70,12 +70,7 @@ ctc_present.path = None ...@@ -70,12 +70,7 @@ ctc_present.path = None
def ctc_available(): def ctc_available():
if config.ctc.root == '': if os.name == 'nt':
ctc_available.msg = 'ctc.root variable is not set, please set it ',
'to the root directory of the CTC library in ',
'your system.'
return False
elif os.name == 'nt':
ctc_available.msg = 'Windows platforms are currently not supported ', ctc_available.msg = 'Windows platforms are currently not supported ',
'by underlying CTC library (warp-ctc).' 'by underlying CTC library (warp-ctc).'
return False return False
...@@ -127,16 +122,21 @@ class ConnectionistTemporalClassification(gof.COp, gof.OpenMPOp): ...@@ -127,16 +122,21 @@ class ConnectionistTemporalClassification(gof.COp, gof.OpenMPOp):
self.default_output = 0 self.default_output = 0
def c_lib_dirs(self): def c_lib_dirs(self):
assert ctc_available.path is not None lib_dirs = []
return [ctc_available.path] if ctc_available.path is not None:
lib_dirs += [ctc_available.path]
return lib_dirs
def c_libraries(self): def c_libraries(self):
return ["warpctc"] return ["warpctc"]
def c_header_dirs(self): def c_header_dirs(self):
header_dirs = []
if config.ctc.root != '':
# We assume here that the header is available at the include directory # We assume here that the header is available at the include directory
# of the CTC root directory. # of the CTC root directory.
return [os.path.join(config.ctc.root, "include")] header_dirs += [os.path.join(config.ctc.root, "include")]
return header_dirs
def c_compile_args(self): def c_compile_args(self):
return gof.OpenMPOp.c_compile_args(self) return gof.OpenMPOp.c_compile_args(self)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论