Fix CUDA synchronization in spatialtf_grid

上级 f1a7f50f
...@@ -86,11 +86,17 @@ spatialtf_grid(PyArrayObject * grid_dimensions, ...@@ -86,11 +86,17 @@ spatialtf_grid(PyArrayObject * grid_dimensions,
return -1; return -1;
} }
cuda_wait( theta->ga.data, GPUARRAY_CUDA_WAIT_READ );
cuda_wait( (*grid)->ga.data, GPUARRAY_CUDA_WAIT_WRITE );
const void * theta_data = PyGpuArray_DEV_DATA( theta ); const void * theta_data = PyGpuArray_DEV_DATA( theta );
void * grid_data = PyGpuArray_DEV_DATA( *grid ); void * grid_data = PyGpuArray_DEV_DATA( *grid );
err = cudnnSpatialTfGridGeneratorForward( _handle, desc, theta_data, grid_data ); err = cudnnSpatialTfGridGeneratorForward( _handle, desc, theta_data, grid_data );
cuda_record( theta->ga.data, GPUARRAY_CUDA_WAIT_READ );
cuda_record( (*grid)->ga.data, GPUARRAY_CUDA_WAIT_WRITE );
if ( CUDNN_STATUS_SUCCESS != err ) if ( CUDNN_STATUS_SUCCESS != err )
{ {
PyErr_Format( PyExc_RuntimeError, PyErr_Format( PyExc_RuntimeError,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论