Fix return value of ctc_check_result in ctc wrapper

上级 407758bc
#section support_code
int ctc_check_result(ctcStatus_t retcode, const char * msg, int * status)
int ctc_check_result(ctcStatus_t retcode, const char * msg)
{
if( CTC_STATUS_SUCCESS != retcode )
{
......@@ -11,9 +11,9 @@ int ctc_check_result(ctcStatus_t retcode, const char * msg, int * status)
"%s | CTC library error message: %s",
msg,
ctc_msg );
*status = 1;
return 1;
}
*status = 0;
return 0;
}
void create_contiguous_input_lengths( PyArrayObject * input_lengths_arr,
......@@ -196,10 +196,9 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations,
size_t cpu_workspace_size;
int ctc_error;
ctc_check_result( get_workspace_size( label_lengths, input_lengths,
ctc_error = ctc_check_result( get_workspace_size( label_lengths, input_lengths,
alphabet_size, minibatch_size, ctc_options, &cpu_workspace_size ),
"Failed to obtain CTC workspace size!",
&ctc_error );
"Failed to obtain CTC workspace size!" );
if ( ctc_error ) // Exception is set by ctc_check_result, return error here
return 1;
......@@ -221,11 +220,10 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations,
return 1;
}
ctc_check_result( compute_ctc_loss( activations, gradients, flat_labels,
ctc_error = ctc_check_result( compute_ctc_loss( activations, gradients, flat_labels,
label_lengths, input_lengths, alphabet_size, minibatch_size, costs,
ctc_cpu_workspace, ctc_options ),
"Failed to compute CTC loss function!",
&ctc_error );
"Failed to compute CTC loss function!" );
if ( ctc_error ) // Exception is set by ctc_check_result, return error here
return 1;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论