Move docstring from CTC class to ctc function and fix docstring of CTC class

上级 3bc8cd31
......@@ -24,24 +24,12 @@ class ConnectionistTemporalClassification(gof.COp, gof.OpenMPOp):
Parameters
----------
activations
Three-dimensional tensor, which has a shape of (t, m, p), where
t is the time index, m is the minibatch index, and p is the index
over the probabilities of each symbol in the alphabet. The memory
layout is assumed to be in C-order, which consists in the slowest
to the fastest changing dimension, from left to right. In this case,
p is the fastest changing dimension.
labels
A 1-D tensor of all the labels for the minibatch.
input_lengths
A 1-D tensor with the number of time steps for each sequence in
the minibatch.
compute_grad
If set to True, enables the computation of gradients of the CTC loss function.
Returns
-------
1-D tensor
Cost of each example in the minibatch. Tensor is of shape
(time index, minibatch index, probabilities).
Op
An instance of the CTC loss computation Op
"""
__props__ = ('compute_grad',)
......@@ -134,6 +122,36 @@ class ConnectionistTemporalClassification(gof.COp, gof.OpenMPOp):
def ctc(activations, labels, input_lengths):
"""
Compute CTC loss function.
Notes
-----
Using the wrapper requires that Baidu's warp-ctc library is installed and the
configuration variables `config.ctc.enabled` and `config.ctc.root` be properly
set.
Parameters
----------
activations
Three-dimensional tensor, which has a shape of (t, m, p), where
t is the time index, m is the minibatch index, and p is the index
over the probabilities of each symbol in the alphabet. The memory
layout is assumed to be in C-order, which consists in the slowest
to the fastest changing dimension, from left to right. In this case,
p is the fastest changing dimension.
labels
A 1-D tensor of all the labels for the minibatch.
input_lengths
A 1-D tensor with the number of time steps for each sequence in
the minibatch.
Returns
-------
1-D tensor
Cost of each example in the minibatch. Tensor is of shape
(time index, minibatch index, probabilities).
"""
return ConnectionistTemporalClassification()(activations, labels, input_lengths)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论