Add more checks and fix data layout in spatialtf grid generator

上级 e2bc4954
...@@ -18,33 +18,49 @@ spatialtf_grid(PyGpuArrayObject * theta, ...@@ -18,33 +18,49 @@ spatialtf_grid(PyGpuArrayObject * theta,
return -1; return -1;
} }
if ( PyArray_DIM( grid_dimensions, 0 ) < 4 ) if ( PyGpuArray_NDIM( theta ) != 3 )
{ {
PyErr_Format( PyExc_RuntimeError, PyErr_Format( PyExc_RuntimeError,
"Grid dimensions array must have at least 4 dimensions!" ); "theta must have three dimensions!" );
return -1;
}
if ( PyGpuArray_DIM( theta, 1 ) != 2 && PyGpuArray_DIM( theta, 2 ) != 3 )
{
PyErr_Format( PyExc_RuntimeError,
"Incorrect dimensions for theta, should be (%d, %d, %d), got (%d, %d, %d)",
PyGpuArray_DIMS( theta )[0], 2, 3, PyGpuArray_DIMS( theta )[0],
PyGpuArray_DIMS( theta )[1], PyGpuArray_DIMS( theta )[2] );
return -1;
}
if ( PyArray_DIM( grid_dimensions, 0 ) < 3 )
{
PyErr_Format( PyExc_RuntimeError,
"Grid dimensions array must have at least 3 dimensions!" );
return -1; return -1;
} }
// Obtain grid dimensions // Obtain grid dimensions
npy_int * dimensions_data = (npy_int *)PyArray_DATA( grid_dimensions ); npy_int * dimensions_data = (npy_int *)PyArray_DATA( grid_dimensions );
const size_t width = dimensions_data[0]; const size_t num_images = dimensions_data[0];
const size_t height = dimensions_data[1]; const size_t height = dimensions_data[1];
const size_t num_images = dimensions_data[3]; const size_t width = dimensions_data[2];
// Grid of coordinates is of size num_images * height * width * 2 for a 2D transformation // Grid of coordinates is of size num_images * height * width * 2 for a 2D transformation
const size_t grid_dims[4] = { width, height, 2, num_images }; const size_t grid_dims[4] = { num_images, height, width, 2 };
if ( width == 0 || height == 0 || num_images == 0 ) if ( width == 0 || height == 0 || num_images == 0 )
{ {
PyErr_Format( PyExc_RuntimeError, PyErr_Format( PyExc_RuntimeError,
"Invalid grid dimensions!" ); "One of the grid dimensions is zero" );
return -1; return -1;
} }
if ( NULL == *grid ) if ( NULL == *grid )
{ {
*grid = pygpu_zeros( 4, &(grid_dims[0]), theta->ga.typecode, GA_C_ORDER, gpu_ctx, Py_None ); *grid = pygpu_zeros( 4, &(grid_dims[0]), theta->ga.typecode, GA_C_ORDER,
gpu_ctx, Py_None );
if ( NULL == *grid ) if ( NULL == *grid )
{ {
PyErr_SetString( PyExc_MemoryError, PyErr_SetString( PyExc_MemoryError,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论