Change CTC context allocation from stack to heap using malloc

上级 1069b835
...@@ -115,8 +115,8 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations, ...@@ -115,8 +115,8 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations,
PyArrayObject ** out_costs, PyArrayObject ** out_costs,
PyArrayObject ** out_gradients) PyArrayObject ** out_gradients)
{ {
ctc_context_t context; ctc_context_t * context = (ctc_context_t *)malloc( sizeof( ctc_context_t ) );
ctc_context_init( &context ); ctc_context_init( context );
npy_float32 * activations = NULL; npy_float32 * activations = NULL;
...@@ -126,10 +126,10 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations, ...@@ -126,10 +126,10 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations,
} }
else else
{ {
context.activations_copy = PyArray_GETCONTIGUOUS( in_activations ); context->activations_copy = PyArray_GETCONTIGUOUS( in_activations );
if ( NULL != context.activations_copy ) if ( NULL != context->activations_copy )
{ {
activations = (npy_float32 *) PyArray_DATA( context.activations_copy ); activations = (npy_float32 *) PyArray_DATA( context->activations_copy );
} }
else else
{ {
...@@ -139,9 +139,9 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations, ...@@ -139,9 +139,9 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations,
} }
} }
create_contiguous_input_lengths( in_input_lengths, &(context.input_lengths) ); create_contiguous_input_lengths( in_input_lengths, &(context->input_lengths) );
if ( NULL == context.input_lengths ) if ( NULL == context->input_lengths )
{ {
PyErr_Format( PyExc_MemoryError, PyErr_Format( PyExc_MemoryError,
"Could not allocate storage for input lengths" ); "Could not allocate storage for input lengths" );
...@@ -149,12 +149,12 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations, ...@@ -149,12 +149,12 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations,
} }
// flatten labels to conform with library memory layout // flatten labels to conform with library memory layout
create_flat_labels( in_labels, &(context.flat_labels), &(context.label_lengths) ); create_flat_labels( in_labels, &(context->flat_labels), &(context->label_lengths) );
if ( ( NULL == context.label_lengths ) || ( NULL == context.flat_labels ) ) if ( ( NULL == context->label_lengths ) || ( NULL == context->flat_labels ) )
{ {
// Destroy previous CTC context before returning exception // Destroy previous CTC context before returning exception
ctc_context_destroy( &context ); ctc_context_destroy( context );
PyErr_Format( PyExc_MemoryError, PyErr_Format( PyExc_MemoryError,
"Could not allocate storage for labels and their lengths" ); "Could not allocate storage for labels and their lengths" );
...@@ -178,7 +178,7 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations, ...@@ -178,7 +178,7 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations,
if ( NULL == (*out_costs) ) if ( NULL == (*out_costs) )
{ {
// Destroy previous CTC context before returning exception // Destroy previous CTC context before returning exception
ctc_context_destroy( &context ); ctc_context_destroy( context );
PyErr_Format( PyExc_MemoryError, PyErr_Format( PyExc_MemoryError,
"Could not allocate storage for CTC costs" ); "Could not allocate storage for CTC costs" );
...@@ -204,7 +204,7 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations, ...@@ -204,7 +204,7 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations,
if ( NULL == (*out_gradients) ) if ( NULL == (*out_gradients) )
{ {
// Destroy previous CTC context before returning exception // Destroy previous CTC context before returning exception
ctc_context_destroy( &context ); ctc_context_destroy( context );
PyErr_Format( PyExc_MemoryError, PyErr_Format( PyExc_MemoryError,
"Could not allocate storage for CTC gradients!" ); "Could not allocate storage for CTC gradients!" );
...@@ -217,20 +217,20 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations, ...@@ -217,20 +217,20 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations,
size_t cpu_workspace_size; size_t cpu_workspace_size;
int ctc_error; int ctc_error;
ctc_error = ctc_check_result( get_workspace_size( context.label_lengths, ctc_error = ctc_check_result( get_workspace_size( context->label_lengths,
context.input_lengths, alphabet_size, minibatch_size, context.options, context->input_lengths, alphabet_size, minibatch_size, context->options,
&cpu_workspace_size ), &cpu_workspace_size ),
"Failed to obtain CTC workspace size!" ); "Failed to obtain CTC workspace size!" );
if ( ctc_error ) // Exception is set by ctc_check_result, return error here if ( ctc_error ) // Exception is set by ctc_check_result, return error here
return 1; return 1;
context.workspace = malloc( cpu_workspace_size ); context->workspace = malloc( cpu_workspace_size );
if ( NULL == context.workspace ) if ( NULL == context->workspace )
{ {
// Destroy previous CTC context before returning exception // Destroy previous CTC context before returning exception
ctc_context_destroy( &context ); ctc_context_destroy( context );
PyErr_Format( PyExc_MemoryError, PyErr_Format( PyExc_MemoryError,
"Failed to allocate memory for CTC workspace!" ); "Failed to allocate memory for CTC workspace!" );
...@@ -238,14 +238,15 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations, ...@@ -238,14 +238,15 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations,
} }
ctc_error = ctc_check_result( compute_ctc_loss( activations, gradients, ctc_error = ctc_check_result( compute_ctc_loss( activations, gradients,
context.flat_labels, context.label_lengths, context.input_lengths, context->flat_labels, context->label_lengths, context->input_lengths,
alphabet_size, minibatch_size, costs, context.workspace, context.options ), alphabet_size, minibatch_size, costs, context->workspace, context->options ),
"Failed to compute CTC loss function!" ); "Failed to compute CTC loss function!" );
if ( ctc_error ) // Exception is set by ctc_check_result, return error here if ( ctc_error ) // Exception is set by ctc_check_result, return error here
return 1; return 1;
ctc_context_destroy( &context ); ctc_context_destroy( context );
free( context );
return 0; return 0;
} }
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论