提交 a35e0c11 authored 作者: João Victor Tozatti Risso's avatar João Victor Tozatti Risso 提交者: João Victor Tozatti Risso

Add pygpu and CKernelBase headers in GPU CTC implementation

上级 e2c9abc4
......@@ -12,6 +12,9 @@ from .type import GpuArrayType
from .opt import register_opt, op_lifter, register_opt2
from theano.gradient import grad_undefined
import os
import pygpu
ctc_enabled = config.ctc.enabled
......@@ -30,6 +33,7 @@ class GpuConnectionistTemporalClassification(CGpuKernelBase, Op):
self.compute_grad = compute_grad
self.context_name = context_name
Op.__init__(self)
CGpuKernelBase.__init__(self, self.func_file, self.func_name)
self.costs_type = GpuArrayType(dtype='float32',
......@@ -63,15 +67,20 @@ class GpuConnectionistTemporalClassification(CGpuKernelBase, Op):
# We assume here that the header is available at the include directory
# of the CTC root directory.
dirs.append(os.path.join(config.ctc.root, "include"))
return dirs + CGpuKernelBase.c_header_dirs(self)
dirs = dirs + list(pygpu.get_include())
dirs = dirs + list(super(CGpuKernelBase, self).c_header_dirs())
return dirs
def c_headers(self):
return ["ctc.h"] + CGpuKernelBase.c_headers(self)
headers = ['ctc.h']
headers = headers + super(CGpuKernelBase, self).c_headers()
headers = headers + ['<numpy_compat.h>', '<gpuarray_helper.h>']
return headers
def make_node(self, activations, labels, input_lengths):
if not ctc_enabled:
raise RuntimeError('Baidu CTC is not enabled and '
'ConnectionistTemporalClassification Op '
'GpuConnectionistTemporalClassification Op '
'can not be constructed.')
context = infer_context_name(activations, labels, input_lengths)
......@@ -106,10 +115,10 @@ class GpuConnectionistTemporalClassification(CGpuKernelBase, Op):
outputs=out_params)
def grad(self, inputs, output_grads):
return [grad_undefined(self, 0, inputs[0]),
return [self.grads_type(),
grad_undefined(self, 1, inputs[1]),
grad_undefined(self, 2, inputs[2])]
def ctc(activations, labels, input_lengths):
return GpuConnectionistTemporalClassification()(activations, labels,
input_lengths)
\ No newline at end of file
input_lengths)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论