Remove explicit cuda runtime dependencies from ctc wrapper

上级 da400367
...@@ -72,12 +72,11 @@ class GpuConnectionistTemporalClassification(CGpuKernelBase, Op): ...@@ -72,12 +72,11 @@ class GpuConnectionistTemporalClassification(CGpuKernelBase, Op):
# We assume here that the header is available at the include directory # We assume here that the header is available at the include directory
# of the CTC root directory. # of the CTC root directory.
dirs.append(os.path.join(config.ctc.root, "include")) dirs.append(os.path.join(config.ctc.root, "include"))
dirs.append('/usr/local/cuda/include')
return dirs return dirs
def c_headers(self): def c_headers(self):
return ['ctc.h', 'numpy_compat.h', 'gpuarray_helper.h', 'gpuarray/types.h', return ['ctc.h', 'numpy_compat.h', 'gpuarray_helper.h', 'gpuarray/types.h',
'gpuarray_api.h', 'gpuarray/array.h', 'gpuarray/util.h', '<cuda_runtime.h>'] 'gpuarray_api.h', 'gpuarray/array.h', 'gpuarray/util.h']
def make_node(self, activations, labels, input_lengths): def make_node(self, activations, labels, input_lengths):
if not ctc_enabled: if not ctc_enabled:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论