Fix data layout in spatialtf_sampler and add CUDA synchronization

上级 e350433d
......@@ -24,7 +24,6 @@ void spatialtf_context_destroy( spatialtf_context_t * ctx )
int
spatialtf_sampler(PyGpuArrayObject * input,
PyGpuArrayObject * om,
PyGpuArrayObject * grid,
PyArrayObject * grid_dimensions,
cudnnSpatialTransformerDescriptor_t desc,
......@@ -38,14 +37,16 @@ spatialtf_sampler(PyGpuArrayObject * input,
float af = alpha, bf = beta;
spatialtf_context_t spatialtf_ctx;
cudnnDataType_t dt;
// Number of color channels (feature maps) is the innermost dimension
cudnnTensorFormat_t tf = CUDNN_TENSOR_NHWC;
cudnnStatus_t err = CUDNN_STATUS_SUCCESS;
// Obtain grid dimensions
npy_int * dimensions_data = (npy_int *)PyArray_DATA( grid_dimensions );
const int width = dimensions_data[0];
const int num_images = dimensions_data[0];
const int height = dimensions_data[1];
const int num_channels = dimensions_data[2];
const int num_images = dimensions_data[3];
const int width = dimensions_data[2];
const int num_channels = dimensions_data[3];
switch (grid->ga.typecode)
{
......@@ -65,8 +66,8 @@ spatialtf_sampler(PyGpuArrayObject * input,
dt = CUDNN_DATA_HALF;
break;
default:
PyErr_SetString(PyExc_TypeError,
"Unsupported type in spatial transformer sampler");
PyErr_SetString( PyExc_TypeError,
"Unsupported type in spatial transformer sampler" );
return -1;
}
......@@ -87,8 +88,8 @@ spatialtf_sampler(PyGpuArrayObject * input,
return -1;
}
err = cudnnSetTensor4dDescriptor( spatialtf_ctx.xdesc, CUDNN_TENSOR_NCHW, dt,
num_images, num_channels, height, width );
err = cudnnSetTensor4dDescriptor( spatialtf_ctx.xdesc, tf, dt, num_images,
num_channels, height, width );
if ( err != CUDNN_STATUS_SUCCESS )
{
......@@ -114,8 +115,8 @@ spatialtf_sampler(PyGpuArrayObject * input,
return -1;
}
err = cudnnSetTensor4dDescriptor( spatialtf_ctx.ydesc, CUDNN_TENSOR_NCHW, dt,
num_images, num_channels, height, width );
err = cudnnSetTensor4dDescriptor( spatialtf_ctx.ydesc, tf, dt, num_images,
num_channels, height, width );
if ( err != CUDNN_STATUS_SUCCESS )
{
......@@ -130,8 +131,11 @@ spatialtf_sampler(PyGpuArrayObject * input,
if ( NULL == *output )
{
*output = pygpu_zeros( PyGpuArray_NDIM(om), PyGpuArray_DIMS(om), input->ga.typecode,
GA_C_ORDER, gpu_ctx, Py_None );
// (num_images, height, width, num_channels )
const size_t out_dims[4] = { num_images, height, width, num_channels };
*output = pygpu_zeros( 4, &(out_dims[0]), input->ga.typecode, GA_C_ORDER,
gpu_ctx, Py_None );
if ( NULL == *output )
{
......@@ -145,8 +149,8 @@ spatialtf_sampler(PyGpuArrayObject * input,
}
const void * input_data = PyGpuArray_DEV_DATA( input );
const void * grid_data = PyGpuArray_DEV_DATA( grid );
void * out_data = PyGpuArray_DEV_DATA( *output );
const void * grid_data = PyGpuArray_DEV_DATA( grid );
void * out_data = PyGpuArray_DEV_DATA( *output );
err = cudnnSpatialTfSamplerForward( _handle, desc, alpha_p, spatialtf_ctx.xdesc,
input_data, grid_data, beta_p, spatialtf_ctx.ydesc, out_data );
......@@ -158,6 +162,10 @@ spatialtf_sampler(PyGpuArrayObject * input,
return -1;
}
cuda_record(input->ga.data, GPUARRAY_CUDA_WAIT_READ);
cuda_record(grid->ga.data, GPUARRAY_CUDA_WAIT_READ);
cuda_record((*output)->ga.data, GPUARRAY_CUDA_WAIT_WRITE);
spatialtf_context_destroy( &spatialtf_ctx );
cuda_exit( gpu_ctx->ctx );
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论