Change CTC workspace type from PyGpuArrayObject to gpudata

上级 e74a2149
......@@ -2,7 +2,7 @@
typedef struct ctc_context {
struct ctcOptions options;
PyGpuArrayObject * workspace;
gpudata * workspace;
int * input_lengths;
int * flat_labels;
int * label_lengths;
......@@ -25,7 +25,7 @@ void ctc_context_init(ctc_context_t * context, PyGpuContextObject * gpu_context)
void ctc_context_destroy(ctc_context_t * context)
{
Py_XDECREF( context->workspace );
gpudata_release( context->workspace );
if ( NULL != context->input_lengths )
free( context->input_lengths );
......@@ -250,8 +250,7 @@ int APPLY_SPECIFIC(ctc_cost_gpu)(PyGpuArrayObject * in_activations,
return 1;
}
context->workspace = pygpu_empty(1, &gpu_workspace_size, GA_BYTE,
GA_C_ORDER, gpu_context, Py_None );
context->workspace = gpudata_alloc( gpu_context->ctx, gpu_workspace_size, NULL, 0, NULL );
if ( NULL == context->workspace )
{
......@@ -264,7 +263,7 @@ int APPLY_SPECIFIC(ctc_cost_gpu)(PyGpuArrayObject * in_activations,
ctc_error = ctc_check_result( compute_ctc_loss( activations, gradients,
context->flat_labels, context->label_lengths, context->input_lengths,
alphabet_size, minibatch_size, costs, PyGpuArray_DEV_DATA(context->workspace),
alphabet_size, minibatch_size, costs, *(void **)context->workspace,
context->options ), "Failed to compute CTC loss function!" );
if ( ctc_error ) // Exception is set by ctc_check_result, return error here
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论