Change type of labels and lengths to int *

Change type of input_lengths, flat_labels, label_lengths to int * in CTC C wrapper, in order to maintain compatibility with Baidu CTC's API types, regardless of platform, operating system and so on. Signed-off-by: 's avatarJoão Victor Tozatti Risso <joaovictor.risso@gmail.com>
上级 95c17e0d
...@@ -3,9 +3,9 @@ ...@@ -3,9 +3,9 @@
typedef struct ctc_context { typedef struct ctc_context {
struct ctcOptions options; struct ctcOptions options;
void * workspace; void * workspace;
npy_int * input_lengths; int * input_lengths;
npy_int * flat_labels; int * flat_labels;
npy_int * label_lengths; int * label_lengths;
} ctc_context_t; } ctc_context_t;
void ctc_context_init(ctc_context_t * context) void ctc_context_init(ctc_context_t * context)
...@@ -53,11 +53,11 @@ int ctc_check_result(ctcStatus_t retcode, const char * msg) ...@@ -53,11 +53,11 @@ int ctc_check_result(ctcStatus_t retcode, const char * msg)
} }
void create_contiguous_input_lengths( PyArrayObject * input_lengths_arr, void create_contiguous_input_lengths( PyArrayObject * input_lengths_arr,
npy_int ** input_lengths ) int ** input_lengths )
{ {
npy_int num_elements = PyArray_DIMS( input_lengths_arr )[0]; npy_int num_elements = PyArray_DIMS( input_lengths_arr )[0];
*input_lengths = (npy_int *) malloc( num_elements * sizeof(npy_int) ); *input_lengths = (int *) malloc( num_elements * sizeof(int) );
if ( NULL == (*input_lengths) ) if ( NULL == (*input_lengths) )
return; return;
...@@ -68,17 +68,17 @@ void create_contiguous_input_lengths( PyArrayObject * input_lengths_arr, ...@@ -68,17 +68,17 @@ void create_contiguous_input_lengths( PyArrayObject * input_lengths_arr,
} }
} }
void create_flat_labels( PyArrayObject * label_matrix, npy_int ** flat_labels, void create_flat_labels( PyArrayObject * label_matrix, int ** flat_labels,
npy_int ** label_lengths ) int ** label_lengths )
{ {
npy_int rows = PyArray_DIMS( label_matrix )[0]; npy_int rows = PyArray_DIMS( label_matrix )[0];
npy_int cols = PyArray_DIMS( label_matrix )[1]; npy_int cols = PyArray_DIMS( label_matrix )[1];
*flat_labels = (npy_int *) malloc( rows * cols * sizeof(npy_int) ); *flat_labels = (int *) malloc( rows * cols * sizeof(int) );
if ( NULL == (*flat_labels) ) if ( NULL == (*flat_labels) )
return; return;
*label_lengths = (npy_int *) malloc( rows * sizeof(npy_int) ); *label_lengths = (int *) malloc( rows * sizeof(int) );
if ( NULL == (*label_lengths) ) if ( NULL == (*label_lengths) )
{ {
free( *flat_labels ); free( *flat_labels );
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论