Add grid of coordinates' instantiation in spatialtf_grid

上级 9133c74c
......@@ -29,18 +29,37 @@ int spatialtf_grid(cudnnSpatialTransformerDescriptor_t desc,
if ( NULL == *grid )
{
// Obtain grid dimensions
npy_int * dimensions_data = (npy_int *)PyArray_DATA( dimensions );
const size_t grid_dims[4] = { dimensions_data[0], dimensions_data[1],
dimensions_data[2], dimensions_data[3] };
const size_t width = dimensions_data[0];
const size_t height = dimensions_data[1];
const size_t num_images = dimensions_data[3];
// Grid of coordinates is of size num_images * height * width * 2 for a 2D transformation
const size_t grid_dims[4] = { width, height, 2, num_images };
*grid = pygpu_empty( 4, &(grid_dims[0]), theta->ga.typecode, GA_C_ORDER,
gpu_ctx, Py_None );
if ( *grid == NULL )
if ( NULL == *grid )
{
PyErr_Format( PyExc_MemoryError,
"Could not allocate memory for grid coordinates" );
"Could not allocate memory for grid coordinates" );
return 1;
}
}
const void * theta_data = PyGpuArray_DEV_DATA( theta );
void * grid_data = PyGpuArray_DEV_DATA( *grid );
err = cudnnSpatialTfGridGeneratorForward( _handle, desc, theta_data, grid_data );
if ( CUDNN_STATUS_SUCCESS != err )
{
PyErr_Format( PyExc_RuntimeError,
"Failed to create grid of coordinates: %s",
cudnnGetErrorString( err ) );
return 1;
}
return 0;
}
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论