Add ctc_context_t and its auxiliary functions to handle the CTC context

上级 fc61cad5
#section support_code
/*
 * Bundles the resources used during one CTC cost computation so that they
 * can be released together through ctc_context_destroy().
 */
typedef struct ctc_context
{
    /* Options handed to get_workspace_size() and compute_ctc_loss(). */
    struct ctcOptions options;
    /* Scratch buffer obtained with malloc() for the CTC computation. */
    void *workspace;
    /* Buffer filled in by create_contiguous_input_lengths(). */
    int *input_lengths;
    /* Flattened labels filled in by create_flat_labels(). */
    int *flat_labels;
    /* Label lengths filled in by create_flat_labels(). */
    int *label_lengths;
    /* Contiguous copy of the activations; stays NULL when no copy is made. */
    PyArrayObject *activations_copy;
} ctc_context_t;
/*
 * Put a CTC context into a known initial state.
 *
 * The options stored in the context are zeroed and then configured for CPU
 * execution with a single thread. Every resource pointer is set to NULL so
 * that ctc_context_destroy() can be called safely at any later point, even
 * if only some of the resources end up being allocated.
 *
 * context: context to initialize; must point to valid storage.
 */
void ctc_context_init(ctc_context_t * context)
{
    // Work through a pointer to the options stored in the context. Taking a
    // by-value copy here would configure a temporary and leave
    // context->options uninitialized.
    struct ctcOptions * options = &(context->options);

    memset( options, 0, sizeof(struct ctcOptions) );
    options->loc = CTC_CPU;
    options->num_threads = 1;

    context->workspace = NULL;
    context->input_lengths = NULL;
    context->flat_labels = NULL;
    context->label_lengths = NULL;
    context->activations_copy = NULL;
}
/*
 * Release every resource held by a CTC context.
 *
 * The heap buffers are released with free() and the reference to the
 * activations copy (if any) is dropped. Each field is reset to NULL
 * afterwards, so calling this function again on the same context is
 * harmless instead of causing a double free or an extra decref.
 *
 * The context structure itself is not freed.
 *
 * context: context previously set up with ctc_context_init().
 */
void ctc_context_destroy(ctc_context_t * context)
{
    // free(NULL) is a no-op, so no NULL guards are required here.
    free( context->workspace );
    context->workspace = NULL;

    free( context->input_lengths );
    context->input_lengths = NULL;

    free( context->flat_labels );
    context->flat_labels = NULL;

    free( context->label_lengths );
    context->label_lengths = NULL;

    // Py_XDECREF tolerates NULL; clear the field so the reference cannot be
    // dropped a second time.
    Py_XDECREF( context->activations_copy );
    context->activations_copy = NULL;
}
int ctc_check_result(ctcStatus_t retcode, const char * msg)
{
if( CTC_STATUS_SUCCESS != retcode )
......@@ -75,8 +115,10 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations,
PyArrayObject ** out_costs,
PyArrayObject ** out_gradients)
{
ctc_context_t context;
ctc_context_init( &context );
npy_float32 * activations = NULL;
PyArrayObject * activations_copy = NULL;
if ( PyArray_IS_C_CONTIGUOUS( in_activations ) )
{
......@@ -84,10 +126,10 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations,
}
else
{
activations_copy = PyArray_GETCONTIGUOUS( in_activations );
if ( NULL != activations_copy )
context.activations_copy = PyArray_GETCONTIGUOUS( in_activations );
if ( NULL != context.activations_copy )
{
activations = (npy_float32 *) PyArray_DATA( activations_copy );
activations = (npy_float32 *) PyArray_DATA( context.activations_copy );
}
else
{
......@@ -97,13 +139,9 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations,
}
}
int * input_lengths = NULL,
* flat_labels = NULL,
* label_lengths = NULL;
create_contiguous_input_lengths( in_input_lengths, &input_lengths );
create_contiguous_input_lengths( in_input_lengths, &(context.input_lengths) );
if ( NULL == input_lengths )
if ( NULL == context.input_lengths )
{
PyErr_Format( PyExc_MemoryError,
"Could not allocate storage for input lengths" );
......@@ -111,12 +149,12 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations,
}
// flatten labels to conform with library memory layout
create_flat_labels( in_labels, &flat_labels, &label_lengths );
create_flat_labels( in_labels, &(context.flat_labels), &(context.label_lengths) );
if ( ( NULL == label_lengths ) || ( NULL == flat_labels ) )
if ( ( NULL == context.label_lengths ) || ( NULL == context.flat_labels ) )
{
// Free previously allocated memory for input lengths
free( input_lengths );
// Destroy previous CTC context before returning exception
ctc_context_destroy( &context );
PyErr_Format( PyExc_MemoryError,
"Could not allocate storage for labels and their lengths" );
......@@ -139,11 +177,8 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations,
if ( NULL == (*out_costs) )
{
// Free previously allocated memory for input and label lengths, and
// labels
free( input_lengths );
free( label_lengths );
free( flat_labels );
// Destroy previous CTC context before returning exception
ctc_context_destroy( &context );
PyErr_Format( PyExc_MemoryError,
"Could not allocate storage for CTC costs" );
......@@ -168,12 +203,8 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations,
if ( NULL == (*out_gradients) )
{
// Free previously allocated memory for input and label lengths,
// labels and output costs
free( input_lengths );
free( label_lengths );
free( flat_labels );
Py_XDECREF( *out_costs );
// Destroy previous CTC context before returning exception
ctc_context_destroy( &context );
PyErr_Format( PyExc_MemoryError,
"Could not allocate storage for CTC gradients!" );
......@@ -183,54 +214,38 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations,
npy_float32 * gradients = (npy_float32 *) PyArray_DATA( *out_gradients );
// setup CTC computation parameters
ctcOptions ctc_options;
memset( &ctc_options, 0, sizeof(ctcOptions) );
ctc_options.loc = CTC_CPU;
ctc_options.num_threads = 1;
size_t cpu_workspace_size;
int ctc_error;
ctc_error = ctc_check_result( get_workspace_size( label_lengths, input_lengths,
alphabet_size, minibatch_size, ctc_options, &cpu_workspace_size ),
ctc_error = ctc_check_result( get_workspace_size( context.label_lengths,
context.input_lengths, alphabet_size, minibatch_size, context.options,
&cpu_workspace_size ),
"Failed to obtain CTC workspace size!" );
if ( ctc_error ) // Exception is set by ctc_check_result, return error here
return 1;
void * ctc_cpu_workspace = malloc( cpu_workspace_size );
context.workspace = malloc( cpu_workspace_size );
if ( NULL == ctc_cpu_workspace )
if ( NULL == context.workspace )
{
// Free previously allocated memory for input and label lengths,
// labels, output costs and gradients
free( input_lengths );
free( label_lengths );
free( flat_labels );
Py_XDECREF( *out_costs );
Py_XDECREF( *out_gradients );
// Destroy previous CTC context before returning exception
ctc_context_destroy( &context );
PyErr_Format( PyExc_MemoryError,
"Failed to allocate memory for CTC workspace!" );
return 1;
}
ctc_error = ctc_check_result( compute_ctc_loss( activations, gradients, flat_labels,
label_lengths, input_lengths, alphabet_size, minibatch_size, costs,
ctc_cpu_workspace, ctc_options ),
ctc_error = ctc_check_result( compute_ctc_loss( activations, gradients,
context.flat_labels, context.label_lengths, context.input_lengths,
alphabet_size, minibatch_size, costs, context.workspace, context.options ),
"Failed to compute CTC loss function!" );
if ( ctc_error ) // Exception is set by ctc_check_result, return error here
return 1;
Py_XDECREF( activations_copy );
free( input_lengths );
free( flat_labels );
free( label_lengths );
free( ctc_cpu_workspace );
ctc_context_destroy( &context );
return 0;
}
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论