Fix CUDA synchronization in spatialtf_sampler

上级 a2fd5048
...@@ -190,6 +190,10 @@ spatialtf_sampler(PyGpuArrayObject * input, ...@@ -190,6 +190,10 @@ spatialtf_sampler(PyGpuArrayObject * input,
return -1; return -1;
} }
cuda_wait( input->ga.data, GPUARRAY_CUDA_WAIT_READ );
cuda_wait( grid->ga.data, GPUARRAY_CUDA_WAIT_READ );
cuda_wait( (*output)->ga.data, GPUARRAY_CUDA_WAIT_WRITE );
const void * input_data = PyGpuArray_DEV_DATA( input ); const void * input_data = PyGpuArray_DEV_DATA( input );
const void * grid_data = PyGpuArray_DEV_DATA( grid ); const void * grid_data = PyGpuArray_DEV_DATA( grid );
void * out_data = PyGpuArray_DEV_DATA( *output ); void * out_data = PyGpuArray_DEV_DATA( *output );
...@@ -197,6 +201,10 @@ spatialtf_sampler(PyGpuArrayObject * input, ...@@ -197,6 +201,10 @@ spatialtf_sampler(PyGpuArrayObject * input,
err = cudnnSpatialTfSamplerForward( _handle, desc, alpha_p, spatialtf_ctx.xdesc, err = cudnnSpatialTfSamplerForward( _handle, desc, alpha_p, spatialtf_ctx.xdesc,
input_data, grid_data, beta_p, spatialtf_ctx.ydesc, out_data ); input_data, grid_data, beta_p, spatialtf_ctx.ydesc, out_data );
cuda_record( input->ga.data, GPUARRAY_CUDA_WAIT_READ );
cuda_record( grid->ga.data, GPUARRAY_CUDA_WAIT_READ );
cuda_record( (*output)->ga.data, GPUARRAY_CUDA_WAIT_WRITE );
if ( CUDNN_STATUS_SUCCESS != err ) if ( CUDNN_STATUS_SUCCESS != err )
{ {
spatialtf_context_destroy( &spatialtf_ctx ); spatialtf_context_destroy( &spatialtf_ctx );
...@@ -204,10 +212,6 @@ spatialtf_sampler(PyGpuArrayObject * input, ...@@ -204,10 +212,6 @@ spatialtf_sampler(PyGpuArrayObject * input,
return -1; return -1;
} }
cuda_record(input->ga.data, GPUARRAY_CUDA_WAIT_READ);
cuda_record(grid->ga.data, GPUARRAY_CUDA_WAIT_READ);
cuda_record((*output)->ga.data, GPUARRAY_CUDA_WAIT_WRITE);
spatialtf_context_destroy( &spatialtf_ctx ); spatialtf_context_destroy( &spatialtf_ctx );
cuda_exit( gpu_ctx->ctx ); cuda_exit( gpu_ctx->ctx );
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论