Add ctc_available function to auto-detect warp-ctc availability

上级 d131cf93
...@@ -11,13 +11,12 @@ from theano.gradient import grad_undefined ...@@ -11,13 +11,12 @@ from theano.gradient import grad_undefined
from theano.gof import local_optimizer from theano.gof import local_optimizer
from theano.tensor.opt import register_canonicalize from theano.tensor.opt import register_canonicalize
from theano.tensor.opt import register_stabilize from theano.tensor.opt import register_stabilize
from theano.tensor.nnet.ctc import ctc_available
import os import os
import os.path import os.path
from . import pygpu from . import pygpu
ctc_enabled = config.ctc.enabled
class GpuConnectionistTemporalClassification(gof.COp): class GpuConnectionistTemporalClassification(gof.COp):
""" """
...@@ -39,14 +38,10 @@ class GpuConnectionistTemporalClassification(gof.COp): ...@@ -39,14 +38,10 @@ class GpuConnectionistTemporalClassification(gof.COp):
params_type = gpu_context_type params_type = gpu_context_type
def __init__(self, compute_grad=True): def __init__(self, compute_grad=True):
if not ctc_enabled: if not ctc_available():
raise RuntimeError('Baidu CTC is not enabled and ' raise RuntimeError('Baidu CTC is not available and '
'GpuConnectionistTemporalClassification Op ' 'GpuConnectionistTemporalClassification Op '
'can not be constructed.') 'can not be constructed.')
elif config.ctc.root == "":
raise ValueError('ctc.root variable is not set, please set it '
'to the root directory of the CTC library in '
'your system.')
self.compute_grad = compute_grad self.compute_grad = compute_grad
# Return only the cost. Gradient will be returned by grad() # Return only the cost. Gradient will be returned by grad()
...@@ -55,32 +50,15 @@ class GpuConnectionistTemporalClassification(gof.COp): ...@@ -55,32 +50,15 @@ class GpuConnectionistTemporalClassification(gof.COp):
gof.COp.__init__(self, self.func_file, self.func_name) gof.COp.__init__(self, self.func_file, self.func_name)
def c_lib_dirs(self): def c_lib_dirs(self):
dirs = [] assert ctc_available.path is not None
if ctc_enabled: return [ctc_available.path]
# Find the directory that contains libwarpctc.so
lib_found = False
for lib_dir in ["build", "lib", "lib64"]:
lib_path = os.path.join(config.ctc.root, lib_dir)
if os.path.isdir(lib_path) and os.path.exists(lib_path):
lib_found = os.path.exists(os.path.join(lib_path, "libwarpctc.so"))
if lib_found:
dirs.append(lib_path)
break
if not lib_found:
raise RuntimeError('libwarpctc.so could not be found. ',
'Please check the config.ctc.root variable.')
return dirs
def c_libraries(self): def c_libraries(self):
return ["warpctc", "gpuarray"] return ["warpctc", "gpuarray"]
def c_header_dirs(self): def c_header_dirs(self):
dirs = [os.path.dirname(__file__), pygpu.get_include()] dirs = [os.path.dirname(__file__), pygpu.get_include()]
if ctc_enabled: dirs.append(os.path.join(config.ctc.root, "include"))
# We assume here that the header is available at the include directory
# of the CTC root directory.
dirs.append(os.path.join(config.ctc.root, "include"))
return dirs return dirs
def c_headers(self): def c_headers(self):
......
...@@ -7,14 +7,14 @@ import theano ...@@ -7,14 +7,14 @@ import theano
import theano.tensor as T import theano.tensor as T
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
import theano.gpuarray import theano.gpuarray
from theano.gpuarray.ctc import (ctc_enabled, gpu_ctc, GpuConnectionistTemporalClassification) from theano.gpuarray.ctc import (gpu_ctc, GpuConnectionistTemporalClassification)
from theano.tensor.nnet.ctc import (ctc, ConnectionistTemporalClassification) from theano.tensor.nnet.ctc import (ctc, ctc_available, ConnectionistTemporalClassification)
from .config import (mode_with_gpu, mode_without_gpu) from .config import (mode_with_gpu, mode_without_gpu)
class TestCTC(unittest.TestCase): class TestCTC(unittest.TestCase):
def setUp(self): def setUp(self):
if not ctc_enabled: if not ctc_available():
self.skipTest('Optional library warp-ctc not available') self.skipTest('Optional library warp-ctc not available')
def check_ctc(self, activations, labels, input_length, expected_costs, expected_grads): def check_ctc(self, activations, labels, input_length, expected_costs, expected_grads):
......
...@@ -5,12 +5,88 @@ import theano.tensor as T ...@@ -5,12 +5,88 @@ import theano.tensor as T
from theano import config from theano import config
from theano import gof from theano import gof
from theano.gof import local_optimizer from theano.gof import local_optimizer
from theano.gof.cmodule import GCC_compiler
from theano.tensor.opt import register_canonicalize from theano.tensor.opt import register_canonicalize
from theano.tensor.opt import register_stabilize from theano.tensor.opt import register_stabilize
from theano.tensor.extra_ops import cpu_contiguous from theano.tensor.extra_ops import cpu_contiguous
from theano.gradient import grad_undefined from theano.gradient import grad_undefined
ctc_enabled = config.ctc.enabled
def _ctc_find_lib():
"""
Find the directory that contains libwarpctc.so
"""
lib_found = False
for lib_dir in ["build", "lib", "lib64"]:
lib_path = os.path.join(config.ctc.root, lib_dir)
if os.path.isdir(lib_path) and os.path.exists(lib_path):
lib_found = os.path.exists(os.path.join(lib_path, "libwarpctc.so"))
if lib_found:
return True, lib_path
return False, None
def _ctc_check_compile(ctc_lib_path):
preambule = """
#include <string.h>
#include "ctc.h"
"""
body = """
ctcOptions options;
memset(&options, 0, sizeof(ctcOptions));
options.loc = CTC_CPU;
options.num_threads = 1;
"""
params = ["-I%s" % (os.path.join(config.ctc.root, "include"))]
params.extend(['-I%s' % (os.path.dirname(__file__))])
params.extend(["-L%s" % (ctc_lib_path)])
params.extend(["-l", "warpctc"])
compiler_res = GCC_compiler.try_flags(
params, preambule=preambule, body=body,
try_run=False, output=True)
avail, out, err = compiler_res if isinstance(compiler_res, tuple) else (compiler_res, None, None)
if not avail:
return False, ("cannot compile with warp-ctc. "
"We got this error:\n" + str(err))
return True, None
def ctc_present():
if ctc_present.avail is not None:
return ctc_present.avail
ctc_present.avail, ctc_lib_path = _ctc_find_lib()
if ctc_lib_path is None:
ctc_present.msg = 'libwarpctc.so could not be found. ',
'Please check your config.ctc.root variable.'
else:
ctc_present.path = ctc_lib_path
ctc_present.avail, ctc_present.msg = _ctc_check_compile(ctc_present.path)
return ctc_present.avail
ctc_present.avail = None
ctc_present.msg = None
ctc_present.path = None
def ctc_available():
if config.ctc.root == '':
ctc_available.msg = 'ctc.root variable is not set, please set it ',
'to the root directory of the CTC library in ',
'your system.'
return False
elif os.name == 'nt':
ctc_available.msg = 'Windows platforms are currently not supported ',
'by underlying CTC library (warp-ctc).'
return False
elif not ctc_present():
ctc_available.msg = ctc_present.msg
return False
ctc_available.path = ctc_present.path
return True
ctc_available.msg = None
ctc_available.path = None
class ConnectionistTemporalClassification(gof.COp, gof.OpenMPOp): class ConnectionistTemporalClassification(gof.COp, gof.OpenMPOp):
...@@ -37,14 +113,10 @@ class ConnectionistTemporalClassification(gof.COp, gof.OpenMPOp): ...@@ -37,14 +113,10 @@ class ConnectionistTemporalClassification(gof.COp, gof.OpenMPOp):
func_name = "APPLY_SPECIFIC(ctc_cost_cpu)" func_name = "APPLY_SPECIFIC(ctc_cost_cpu)"
def __init__(self, compute_grad=True): def __init__(self, compute_grad=True):
if not ctc_enabled: if not ctc_available():
raise RuntimeError('Baidu CTC is not enabled and ' raise RuntimeError('Baidu CTC is not available and '
'ConnectionistTemporalClassification Op ' 'ConnectionistTemporalClassification Op '
'can not be constructed.') 'can not be constructed.')
elif config.ctc.root == "":
raise ValueError('ctc.root variable is not set, please set it '
'to the root directory of the CTC library in '
'your system.')
gof.COp.__init__(self, self.func_file, self.func_name) gof.COp.__init__(self, self.func_file, self.func_name)
gof.OpenMPOp.__init__(self) gof.OpenMPOp.__init__(self)
...@@ -54,33 +126,16 @@ class ConnectionistTemporalClassification(gof.COp, gof.OpenMPOp): ...@@ -54,33 +126,16 @@ class ConnectionistTemporalClassification(gof.COp, gof.OpenMPOp):
self.default_output = 0 self.default_output = 0
def c_lib_dirs(self): def c_lib_dirs(self):
dirs = [] assert ctc_available.path is not None
if ctc_enabled: return [ctc_available.path]
# Find the directory that contains libwarpctc.so
lib_found = False
for lib_dir in ["build", "lib", "lib64"]:
lib_path = os.path.join(config.ctc.root, lib_dir)
if os.path.isdir(lib_path) and os.path.exists(lib_path):
lib_found = os.path.exists(os.path.join(lib_path, "libwarpctc.so"))
if lib_found:
dirs.append(lib_path)
break
if not lib_found:
raise RuntimeError('libwarpctc.so could not be found. ',
'Please check the config.ctc.root variable.')
return dirs
def c_libraries(self): def c_libraries(self):
return ["warpctc"] return ["warpctc"]
def c_header_dirs(self): def c_header_dirs(self):
dirs = [] # We assume here that the header is available at the include directory
if ctc_enabled: # of the CTC root directory.
# We assume here that the header is available at the include directory return [os.path.join(config.ctc.root, "include")]
# of the CTC root directory.
dirs.append(os.path.join(config.ctc.root, "include"))
return dirs
def c_compile_args(self): def c_compile_args(self):
return gof.OpenMPOp.c_compile_args(self) return gof.OpenMPOp.c_compile_args(self)
......
...@@ -6,7 +6,7 @@ import numpy as np ...@@ -6,7 +6,7 @@ import numpy as np
import theano import theano
import theano.tensor as T import theano.tensor as T
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.tensor.nnet.ctc import (ctc_enabled, ctc, ConnectionistTemporalClassification) from theano.tensor.nnet.ctc import (ctc_available, ctc, ConnectionistTemporalClassification)
class TestCTC(unittest.TestCase): class TestCTC(unittest.TestCase):
...@@ -18,7 +18,7 @@ class TestCTC(unittest.TestCase): ...@@ -18,7 +18,7 @@ class TestCTC(unittest.TestCase):
""" """
def setUp(self): def setUp(self):
if not ctc_enabled: if not ctc_available():
self.skipTest('Optional library warp-ctc not available') self.skipTest('Optional library warp-ctc not available')
def run_ctc(self, activations, labels, input_length, expected_costs, expected_grads): def run_ctc(self, activations, labels, input_length, expected_costs, expected_grads):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论