Add ctc_context_t and its auxiliary functions to handle the CTC context

上级 fc61cad5
#section support_code #section support_code
/*
 * Aggregates the CTC options and every buffer used during one CTC cost
 * computation, so that they can be released together through
 * ctc_context_destroy().
 */
typedef struct ctc_context
{
    /* Options handed to get_workspace_size() and compute_ctc_loss(). */
    struct ctcOptions options;

    /* Scratch buffer obtained with malloc() for the CTC computation. */
    void *workspace;

    /* Buffer filled in by create_contiguous_input_lengths(). */
    int *input_lengths;

    /* Buffers filled in by create_flat_labels(). */
    int *flat_labels;
    int *label_lengths;

    /* Contiguous copy of the activations; stays NULL when no copy is made. */
    PyArrayObject *activations_copy;
} ctc_context_t;
/*
 * Put a CTC context into a well-defined initial state.
 *
 * The options stored in the context are zeroed and then configured for
 * CPU execution with a single thread; every buffer pointer is set to NULL
 * so that ctc_context_destroy() can be called safely at any later point.
 *
 * context: context to initialize; must point to valid storage.
 */
void ctc_context_init(ctc_context_t * context)
{
    // Operate on the options *inside* the context. Taking a by-value copy
    // here would configure a temporary and leave context->options
    // uninitialized.
    struct ctcOptions * options = &(context->options);

    memset( options, 0, sizeof(struct ctcOptions) );
    options->loc = CTC_CPU;
    options->num_threads = 1;

    context->workspace = NULL;
    context->input_lengths = NULL;
    context->flat_labels = NULL;
    context->label_lengths = NULL;
    context->activations_copy = NULL;
}
/*
 * Release every resource held by a CTC context.
 *
 * Frees the workspace, input lengths, flat labels and label lengths
 * buffers, and drops the reference to the activations copy. Each member
 * is reset to NULL afterwards, so destroying the same context more than
 * once is harmless.
 *
 * context: context previously set up with ctc_context_init().
 */
void ctc_context_destroy(ctc_context_t * context)
{
    // free() accepts NULL, so no guards are needed around these calls.
    free( context->workspace );
    context->workspace = NULL;

    free( context->input_lengths );
    context->input_lengths = NULL;

    free( context->flat_labels );
    context->flat_labels = NULL;

    free( context->label_lengths );
    context->label_lengths = NULL;

    // Py_XDECREF tolerates NULL; clear the member so the reference
    // cannot be dropped twice.
    Py_XDECREF( context->activations_copy );
    context->activations_copy = NULL;
}
int ctc_check_result(ctcStatus_t retcode, const char * msg) int ctc_check_result(ctcStatus_t retcode, const char * msg)
{ {
if( CTC_STATUS_SUCCESS != retcode ) if( CTC_STATUS_SUCCESS != retcode )
...@@ -75,8 +115,10 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations, ...@@ -75,8 +115,10 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations,
PyArrayObject ** out_costs, PyArrayObject ** out_costs,
PyArrayObject ** out_gradients) PyArrayObject ** out_gradients)
{ {
ctc_context_t context;
ctc_context_init( &context );
npy_float32 * activations = NULL; npy_float32 * activations = NULL;
PyArrayObject * activations_copy = NULL;
if ( PyArray_IS_C_CONTIGUOUS( in_activations ) ) if ( PyArray_IS_C_CONTIGUOUS( in_activations ) )
{ {
...@@ -84,10 +126,10 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations, ...@@ -84,10 +126,10 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations,
} }
else else
{ {
activations_copy = PyArray_GETCONTIGUOUS( in_activations ); context.activations_copy = PyArray_GETCONTIGUOUS( in_activations );
if ( NULL != activations_copy ) if ( NULL != context.activations_copy )
{ {
activations = (npy_float32 *) PyArray_DATA( activations_copy ); activations = (npy_float32 *) PyArray_DATA( context.activations_copy );
} }
else else
{ {
...@@ -97,13 +139,9 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations, ...@@ -97,13 +139,9 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations,
} }
} }
int * input_lengths = NULL, create_contiguous_input_lengths( in_input_lengths, &(context.input_lengths) );
* flat_labels = NULL,
* label_lengths = NULL;
create_contiguous_input_lengths( in_input_lengths, &input_lengths ); if ( NULL == context.input_lengths )
if ( NULL == input_lengths )
{ {
PyErr_Format( PyExc_MemoryError, PyErr_Format( PyExc_MemoryError,
"Could not allocate storage for input lengths" ); "Could not allocate storage for input lengths" );
...@@ -111,12 +149,12 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations, ...@@ -111,12 +149,12 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations,
} }
// flatten labels to conform with library memory layout // flatten labels to conform with library memory layout
create_flat_labels( in_labels, &flat_labels, &label_lengths ); create_flat_labels( in_labels, &(context.flat_labels), &(context.label_lengths) );
if ( ( NULL == label_lengths ) || ( NULL == flat_labels ) ) if ( ( NULL == context.label_lengths ) || ( NULL == context.flat_labels ) )
{ {
// Free previously allocated memory for input lengths // Destroy previous CTC context before returning exception
free( input_lengths ); ctc_context_destroy( &context );
PyErr_Format( PyExc_MemoryError, PyErr_Format( PyExc_MemoryError,
"Could not allocate storage for labels and their lengths" ); "Could not allocate storage for labels and their lengths" );
...@@ -139,11 +177,8 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations, ...@@ -139,11 +177,8 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations,
if ( NULL == (*out_costs) ) if ( NULL == (*out_costs) )
{ {
// Free previously allocated memory for input and label lengths, and // Destroy previous CTC context before returning exception
// labels ctc_context_destroy( &context );
free( input_lengths );
free( label_lengths );
free( flat_labels );
PyErr_Format( PyExc_MemoryError, PyErr_Format( PyExc_MemoryError,
"Could not allocate storage for CTC costs" ); "Could not allocate storage for CTC costs" );
...@@ -168,12 +203,8 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations, ...@@ -168,12 +203,8 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations,
if ( NULL == (*out_gradients) ) if ( NULL == (*out_gradients) )
{ {
// Free previously allocated memory for input and label lengths, // Destroy previous CTC context before returning exception
// labels and output costs ctc_context_destroy( &context );
free( input_lengths );
free( label_lengths );
free( flat_labels );
Py_XDECREF( *out_costs );
PyErr_Format( PyExc_MemoryError, PyErr_Format( PyExc_MemoryError,
"Could not allocate storage for CTC gradients!" ); "Could not allocate storage for CTC gradients!" );
...@@ -183,54 +214,38 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations, ...@@ -183,54 +214,38 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations,
npy_float32 * gradients = (npy_float32 *) PyArray_DATA( *out_gradients ); npy_float32 * gradients = (npy_float32 *) PyArray_DATA( *out_gradients );
// setup CTC computation parameters
ctcOptions ctc_options;
memset( &ctc_options, 0, sizeof(ctcOptions) );
ctc_options.loc = CTC_CPU;
ctc_options.num_threads = 1;
size_t cpu_workspace_size; size_t cpu_workspace_size;
int ctc_error; int ctc_error;
ctc_error = ctc_check_result( get_workspace_size( label_lengths, input_lengths, ctc_error = ctc_check_result( get_workspace_size( context.label_lengths,
alphabet_size, minibatch_size, ctc_options, &cpu_workspace_size ), context.input_lengths, alphabet_size, minibatch_size, context.options,
&cpu_workspace_size ),
"Failed to obtain CTC workspace size!" ); "Failed to obtain CTC workspace size!" );
if ( ctc_error ) // Exception is set by ctc_check_result, return error here if ( ctc_error ) // Exception is set by ctc_check_result, return error here
return 1; return 1;
void * ctc_cpu_workspace = malloc( cpu_workspace_size ); context.workspace = malloc( cpu_workspace_size );
if ( NULL == ctc_cpu_workspace ) if ( NULL == context.workspace )
{ {
// Free previously allocated memory for input and label lengths, // Destroy previous CTC context before returning exception
// labels, output costs and gradients ctc_context_destroy( &context );
free( input_lengths );
free( label_lengths );
free( flat_labels );
Py_XDECREF( *out_costs );
Py_XDECREF( *out_gradients );
PyErr_Format( PyExc_MemoryError, PyErr_Format( PyExc_MemoryError,
"Failed to allocate memory for CTC workspace!" ); "Failed to allocate memory for CTC workspace!" );
return 1; return 1;
} }
ctc_error = ctc_check_result( compute_ctc_loss( activations, gradients, flat_labels, ctc_error = ctc_check_result( compute_ctc_loss( activations, gradients,
label_lengths, input_lengths, alphabet_size, minibatch_size, costs, context.flat_labels, context.label_lengths, context.input_lengths,
ctc_cpu_workspace, ctc_options ), alphabet_size, minibatch_size, costs, context.workspace, context.options ),
"Failed to compute CTC loss function!" ); "Failed to compute CTC loss function!" );
if ( ctc_error ) // Exception is set by ctc_check_result, return error here if ( ctc_error ) // Exception is set by ctc_check_result, return error here
return 1; return 1;
Py_XDECREF( activations_copy ); ctc_context_destroy( &context );
free( input_lengths );
free( flat_labels );
free( label_lengths );
free( ctc_cpu_workspace );
return 0; return 0;
} }
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论