Change access do grid_dimensions array to use PyArray_GETPTR1

上级 7350d62f
...@@ -34,19 +34,18 @@ spatialtf_grid(PyArrayObject * grid_dimensions, ...@@ -34,19 +34,18 @@ spatialtf_grid(PyArrayObject * grid_dimensions,
return -1; return -1;
} }
if ( PyArray_DIM( grid_dimensions, 0 ) < 3 ) if ( PyArray_DIM( grid_dimensions, 0 ) != 4 )
{ {
PyErr_Format( PyExc_RuntimeError, PyErr_Format( PyExc_RuntimeError,
"Grid dimensions array must have at least 3 dimensions!" ); "grid_dimensions must have 4 dimensions!" );
return -1; return -1;
} }
// Obtain grid dimensions // Obtain grid dimensions
npy_int * dimensions_data = (npy_int *)PyArray_DATA( grid_dimensions ); const size_t num_images = (size_t) *( (npy_int *) PyArray_GETPTR1( grid_dimensions, 0 ) );
const size_t height = (size_t) *( (npy_int *) PyArray_GETPTR1( grid_dimensions, 1 ) );
const size_t width = (size_t) *( (npy_int *) PyArray_GETPTR1( grid_dimensions, 2 ) );
const size_t num_images = dimensions_data[0];
const size_t height = dimensions_data[1];
const size_t width = dimensions_data[2];
// Grid of coordinates is of size num_images * height * width * 2 for a 2D transformation // Grid of coordinates is of size num_images * height * width * 2 for a 2D transformation
const size_t grid_dims[4] = { num_images, height, width, 2 }; const size_t grid_dims[4] = { num_images, height, width, 2 };
......
...@@ -41,12 +41,18 @@ spatialtf_sampler(PyGpuArrayObject * input, ...@@ -41,12 +41,18 @@ spatialtf_sampler(PyGpuArrayObject * input,
cudnnTensorFormat_t tf = CUDNN_TENSOR_NHWC; cudnnTensorFormat_t tf = CUDNN_TENSOR_NHWC;
cudnnStatus_t err = CUDNN_STATUS_SUCCESS; cudnnStatus_t err = CUDNN_STATUS_SUCCESS;
if ( PyArray_DIM( grid_dimensions, 0 ) != 4 )
{
PyErr_SetString( PyExc_RuntimeError,
"grid_dimensions must have 4 dimensions" );
return -1;
}
// Obtain grid dimensions // Obtain grid dimensions
npy_int * dimensions_data = (npy_int *)PyArray_DATA( grid_dimensions ); const int num_images = (int) *( (npy_int *) PyArray_GETPTR1( grid_dimensions, 0 ) );
const int num_images = dimensions_data[0]; const int height = (int) *( (npy_int *) PyArray_GETPTR1( grid_dimensions, 1 ) );
const int height = dimensions_data[1]; const int width = (int) *( (npy_int *) PyArray_GETPTR1( grid_dimensions, 2 ) );
const int width = dimensions_data[2]; const int num_channels = (int) *( (npy_int *) PyArray_GETPTR1( grid_dimensions, 3 ) );
const int num_channels = dimensions_data[3];
switch (grid->ga.typecode) switch (grid->ga.typecode)
{ {
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论