Fix grad computation and add basic docstring to ctc wrapper

上级 634001c0
...@@ -13,6 +13,16 @@ ctc_enabled = config.ctc.enabled ...@@ -13,6 +13,16 @@ ctc_enabled = config.ctc.enabled
class ConnectionistTemporalClassification(gof.COp, gof.OpenMPOp): class ConnectionistTemporalClassification(gof.COp, gof.OpenMPOp):
"""
CTC loss function wrapper.
Notes
-----
Using the wrapper requires that Baidu's warp-ctc library is installed and the
configuration variables `config.ctc.enabled` and `config.ctc.root` be properly
set.
"""
__props__ = ('compute_grad',) __props__ = ('compute_grad',)
func_file = "./ctc_wrapper.c" func_file = "./ctc_wrapper.c"
...@@ -61,6 +71,23 @@ class ConnectionistTemporalClassification(gof.COp, gof.OpenMPOp): ...@@ -61,6 +71,23 @@ class ConnectionistTemporalClassification(gof.COp, gof.OpenMPOp):
return ["ctc.h"] + gof.OpenMPOp.c_headers(self) return ["ctc.h"] + gof.OpenMPOp.c_headers(self)
def make_node(self, activations, labels, input_lengths): def make_node(self, activations, labels, input_lengths):
"""
Parameters
----------
activations
Three-dimensional tensor, which has a shape of (t, m, p), where
t is the time index, m is the minibatch index, and p is the index
over the probabilities of each symbol in the alphabet. The memory
layout is assumed to be in C-order, which consists in the slowest
to the fastest changing dimension, from left to right. In this case,
p is the fastest changing dimension.
labels
A 1-D tensor of all the labels for the minibatch.
input_lengths
A 1-D tensor with the number of time steps for each sequence in
the minibatch.
"""
if not ctc_enabled: if not ctc_enabled:
raise RuntimeError('Baidu CTC is not enabled and ' raise RuntimeError('Baidu CTC is not enabled and '
'ConnectionistTemporalClassification Op ' 'ConnectionistTemporalClassification Op '
...@@ -92,7 +119,9 @@ class ConnectionistTemporalClassification(gof.COp, gof.OpenMPOp): ...@@ -92,7 +119,9 @@ class ConnectionistTemporalClassification(gof.COp, gof.OpenMPOp):
raise RuntimeError('Baidu CTC is not enabled and ' raise RuntimeError('Baidu CTC is not enabled and '
'ConnectionistTemporalClassification Op ' 'ConnectionistTemporalClassification Op '
'can not be constructed.') 'can not be constructed.')
return [self.gradients, grad_op = output_grads[0]
total_grad = T.basic.batched_dot(grad_op, self.gradients.dimshuffle(1, 0, 2)).dimshuffle(1, 0, 2)
return [total_grad,
grad_undefined(self, 1, inputs[1]), grad_undefined(self, 1, inputs[1]),
grad_undefined(self, 2, inputs[2])] grad_undefined(self, 2, inputs[2])]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论