Fix memory initialization for grid of coordinates in spatialtf_grid

上级 1c155e5f
#section support_code #section support_code
int int
spatialtf_grid(cudnnSpatialTransformerDescriptor_t desc, spatialtf_grid(PyGpuArrayObject * theta,
PyGpuArrayObject * theta,
PyArrayObject * grid_dimensions, PyArrayObject * grid_dimensions,
cudnnSpatialTransformerDescriptor_t desc,
PyGpuArrayObject ** grid, PyGpuArrayObject ** grid,
cudnnHandle_t _handle) cudnnHandle_t _handle)
{ {
...@@ -18,7 +18,7 @@ spatialtf_grid(cudnnSpatialTransformerDescriptor_t desc, ...@@ -18,7 +18,7 @@ spatialtf_grid(cudnnSpatialTransformerDescriptor_t desc,
return -1; return -1;
} }
if ( PyArray_NDIM( grid_dimensions ) < 4 ) if ( PyArray_DIM( grid_dimensions, 0 ) < 4 )
{ {
PyErr_Format( PyExc_RuntimeError, PyErr_Format( PyExc_RuntimeError,
"Grid dimensions array must have at least 4 dimensions!" ); "Grid dimensions array must have at least 4 dimensions!" );
...@@ -34,14 +34,25 @@ spatialtf_grid(cudnnSpatialTransformerDescriptor_t desc, ...@@ -34,14 +34,25 @@ spatialtf_grid(cudnnSpatialTransformerDescriptor_t desc,
// Grid of coordinates is of size num_images * height * width * 2 for a 2D transformation // Grid of coordinates is of size num_images * height * width * 2 for a 2D transformation
const size_t grid_dims[4] = { width, height, 2, num_images }; const size_t grid_dims[4] = { width, height, 2, num_images };
if ( theano_prep_output( grid, 4, &(grid_dims[0]), theta->ga.typecode, if ( width == 0 || height == 0 || num_images == 0 )
GA_C_ORDER, gpu_ctx ) != 0 )
{ {
PyErr_SetString( PyExc_MemoryError, PyErr_Format( PyExc_RuntimeError,
"Could not allocate memory for the grid of coordinates" ); "Invalid grid dimensions!" );
return -1; return -1;
} }
if ( NULL == *grid )
{
*grid = pygpu_zeros( 4, &(grid_dims[0]), theta->ga.typecode, GA_C_ORDER, gpu_ctx, Py_None );
if ( NULL == *grid )
{
PyErr_SetString( PyExc_MemoryError,
"Could not allocate memory for grid of coordinates" );
return -1;
}
}
const void * theta_data = PyGpuArray_DEV_DATA( theta ); const void * theta_data = PyGpuArray_DEV_DATA( theta );
void * grid_data = PyGpuArray_DEV_DATA( *grid ); void * grid_data = PyGpuArray_DEV_DATA( *grid );
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论