Fix CUDA synchronization in spatialtf_sampler

上级 a2fd5048
......@@ -190,6 +190,10 @@ spatialtf_sampler(PyGpuArrayObject * input,
return -1;
}
cuda_wait( input->ga.data, GPUARRAY_CUDA_WAIT_READ );
cuda_wait( grid->ga.data, GPUARRAY_CUDA_WAIT_READ );
cuda_wait( (*output)->ga.data, GPUARRAY_CUDA_WAIT_WRITE );
const void * input_data = PyGpuArray_DEV_DATA( input );
const void * grid_data = PyGpuArray_DEV_DATA( grid );
void * out_data = PyGpuArray_DEV_DATA( *output );
......@@ -197,6 +201,10 @@ spatialtf_sampler(PyGpuArrayObject * input,
err = cudnnSpatialTfSamplerForward( _handle, desc, alpha_p, spatialtf_ctx.xdesc,
input_data, grid_data, beta_p, spatialtf_ctx.ydesc, out_data );
cuda_record( input->ga.data, GPUARRAY_CUDA_WAIT_READ );
cuda_record( grid->ga.data, GPUARRAY_CUDA_WAIT_READ );
cuda_record( (*output)->ga.data, GPUARRAY_CUDA_WAIT_WRITE );
if ( CUDNN_STATUS_SUCCESS != err )
{
spatialtf_context_destroy( &spatialtf_ctx );
......@@ -204,10 +212,6 @@ spatialtf_sampler(PyGpuArrayObject * input,
return -1;
}
cuda_record(input->ga.data, GPUARRAY_CUDA_WAIT_READ);
cuda_record(grid->ga.data, GPUARRAY_CUDA_WAIT_READ);
cuda_record((*output)->ga.data, GPUARRAY_CUDA_WAIT_WRITE);
spatialtf_context_destroy( &spatialtf_ctx );
cuda_exit( gpu_ctx->ctx );
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论