Fix return value of ctc_check_result in ctc wrapper

上级 407758bc
#section support_code #section support_code
int ctc_check_result(ctcStatus_t retcode, const char * msg, int * status) int ctc_check_result(ctcStatus_t retcode, const char * msg)
{ {
if( CTC_STATUS_SUCCESS != retcode ) if( CTC_STATUS_SUCCESS != retcode )
{ {
...@@ -11,9 +11,9 @@ int ctc_check_result(ctcStatus_t retcode, const char * msg, int * status) ...@@ -11,9 +11,9 @@ int ctc_check_result(ctcStatus_t retcode, const char * msg, int * status)
"%s | CTC library error message: %s", "%s | CTC library error message: %s",
msg, msg,
ctc_msg ); ctc_msg );
*status = 1; return 1;
} }
*status = 0; return 0;
} }
void create_contiguous_input_lengths( PyArrayObject * input_lengths_arr, void create_contiguous_input_lengths( PyArrayObject * input_lengths_arr,
...@@ -32,7 +32,7 @@ void create_contiguous_input_lengths( PyArrayObject * input_lengths_arr, ...@@ -32,7 +32,7 @@ void create_contiguous_input_lengths( PyArrayObject * input_lengths_arr,
} }
} }
void create_flat_labels( PyArrayObject * label_matrix, int ** flat_labels, void create_flat_labels( PyArrayObject * label_matrix, int ** flat_labels,
int ** label_lengths ) int ** label_lengths )
{ {
int rows = PyArray_DIMS( label_matrix )[0]; int rows = PyArray_DIMS( label_matrix )[0];
...@@ -93,7 +93,7 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations, ...@@ -93,7 +93,7 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations,
{ {
PyErr_Format( PyExc_MemoryError, PyErr_Format( PyExc_MemoryError,
"Could not create a contiguous copy of activations array." ); "Could not create a contiguous copy of activations array." );
return 1; return 1;
} }
} }
...@@ -107,7 +107,7 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations, ...@@ -107,7 +107,7 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations,
{ {
PyErr_Format( PyExc_MemoryError, PyErr_Format( PyExc_MemoryError,
"Could not allocate storage for input lengths" ); "Could not allocate storage for input lengths" );
return 1; return 1;
} }
// flatten labels to conform with library memory layout // flatten labels to conform with library memory layout
...@@ -133,7 +133,7 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations, ...@@ -133,7 +133,7 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations,
PyArray_NDIM( *out_costs ) != 1 || // or, matrix has the wrong size PyArray_NDIM( *out_costs ) != 1 || // or, matrix has the wrong size
PyArray_DIMS( *out_costs )[0] != cost_size ) PyArray_DIMS( *out_costs )[0] != cost_size )
{ {
Py_XDECREF( *out_costs ); Py_XDECREF( *out_costs );
// Allocate new matrix // Allocate new matrix
*out_costs = (PyArrayObject *) PyArray_ZEROS( 1, &cost_size, NPY_FLOAT32, 0 ); *out_costs = (PyArrayObject *) PyArray_ZEROS( 1, &cost_size, NPY_FLOAT32, 0 );
...@@ -143,7 +143,7 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations, ...@@ -143,7 +143,7 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations,
// labels // labels
free( input_lengths ); free( input_lengths );
free( label_lengths ); free( label_lengths );
free( flat_labels ); free( flat_labels );
PyErr_Format( PyExc_MemoryError, PyErr_Format( PyExc_MemoryError,
"Could not allocate storage for CTC costs" ); "Could not allocate storage for CTC costs" );
...@@ -161,7 +161,7 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations, ...@@ -161,7 +161,7 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations,
{ {
// Existing matrix is the wrong size. Make a new one. // Existing matrix is the wrong size. Make a new one.
// Decrement ref counter to existing array // Decrement ref counter to existing array
Py_XDECREF( *out_gradients ); Py_XDECREF( *out_gradients );
// Allocate new array // Allocate new array
*out_gradients = (PyArrayObject *) PyArray_ZEROS(3, PyArray_DIMS( in_activations ), *out_gradients = (PyArrayObject *) PyArray_ZEROS(3, PyArray_DIMS( in_activations ),
NPY_FLOAT32, 0); NPY_FLOAT32, 0);
...@@ -177,7 +177,7 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations, ...@@ -177,7 +177,7 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations,
PyErr_Format( PyExc_MemoryError, PyErr_Format( PyExc_MemoryError,
"Could not allocate storage for CTC gradients!" ); "Could not allocate storage for CTC gradients!" );
return 1; return 1;
} }
} }
...@@ -196,10 +196,9 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations, ...@@ -196,10 +196,9 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations,
size_t cpu_workspace_size; size_t cpu_workspace_size;
int ctc_error; int ctc_error;
ctc_check_result( get_workspace_size( label_lengths, input_lengths, ctc_error = ctc_check_result( get_workspace_size( label_lengths, input_lengths,
alphabet_size, minibatch_size, ctc_options, &cpu_workspace_size ), alphabet_size, minibatch_size, ctc_options, &cpu_workspace_size ),
"Failed to obtain CTC workspace size!", "Failed to obtain CTC workspace size!" );
&ctc_error );
if ( ctc_error ) // Exception is set by ctc_check_result, return error here if ( ctc_error ) // Exception is set by ctc_check_result, return error here
return 1; return 1;
...@@ -214,19 +213,18 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations, ...@@ -214,19 +213,18 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations,
free( label_lengths ); free( label_lengths );
free( flat_labels ); free( flat_labels );
Py_XDECREF( *out_costs ); Py_XDECREF( *out_costs );
Py_XDECREF( *out_gradients ); Py_XDECREF( *out_gradients );
PyErr_Format( PyExc_MemoryError, PyErr_Format( PyExc_MemoryError,
"Failed to allocate memory for CTC workspace!" ); "Failed to allocate memory for CTC workspace!" );
return 1; return 1;
} }
ctc_check_result( compute_ctc_loss( activations, gradients, flat_labels, ctc_error = ctc_check_result( compute_ctc_loss( activations, gradients, flat_labels,
label_lengths, input_lengths, alphabet_size, minibatch_size, costs, label_lengths, input_lengths, alphabet_size, minibatch_size, costs,
ctc_cpu_workspace, ctc_options ), ctc_cpu_workspace, ctc_options ),
"Failed to compute CTC loss function!", "Failed to compute CTC loss function!" );
&ctc_error );
if ( ctc_error ) // Exception is set by ctc_check_result, return error here if ( ctc_error ) // Exception is set by ctc_check_result, return error here
return 1; return 1;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论