提交 d9d2adce authored 作者: João Victor Risso's avatar João Victor Risso

Rename grid_dims to out_dims in dnn_sptf_grid

上级 c02d556a
...@@ -2,13 +2,13 @@ ...@@ -2,13 +2,13 @@
int int
APPLY_SPECIFIC(dnn_sptf_grid)(PyGpuArrayObject * theta, APPLY_SPECIFIC(dnn_sptf_grid)(PyGpuArrayObject * theta,
PyArrayObject * grid_dims, PyArrayObject * out_dims,
cudnnSpatialTransformerDescriptor_t desc, cudnnSpatialTransformerDescriptor_t desc,
PyGpuArrayObject ** grid, PyGpuArrayObject ** grid,
cudnnHandle_t _handle) cudnnHandle_t _handle)
{ {
PyGpuContextObject * gpu_ctx = theta->context; PyGpuContextObject * gpu_ctx = theta->context;
size_t gpu_grid_dims[4]; size_t grid_dims[4];
int num_images, num_channels, height, width; int num_images, num_channels, height, width;
cudnnStatus_t err = CUDNN_STATUS_SUCCESS; cudnnStatus_t err = CUDNN_STATUS_SUCCESS;
...@@ -20,7 +20,7 @@ APPLY_SPECIFIC(dnn_sptf_grid)(PyGpuArrayObject * theta, ...@@ -20,7 +20,7 @@ APPLY_SPECIFIC(dnn_sptf_grid)(PyGpuArrayObject * theta,
"GpuDnnTransformerGrid: unsupported data type for theta in spatial transformer." ); "GpuDnnTransformerGrid: unsupported data type for theta in spatial transformer." );
return 1; return 1;
} }
else if ( PyGpuArray_DIM( theta, 1 ) != 2 && PyGpuArray_DIM( theta, 2 ) != 3 ) else if ( PyGpuArray_DIM( theta, 1 ) != 2 || PyGpuArray_DIM( theta, 2 ) != 3 )
{ {
PyErr_Format( PyExc_RuntimeError, PyErr_Format( PyExc_RuntimeError,
"GpuDnnTransformerGrid: incorrect dimensions for theta, expected (%d, %d, %d), got (%d, %d, %d)", "GpuDnnTransformerGrid: incorrect dimensions for theta, expected (%d, %d, %d), got (%d, %d, %d)",
...@@ -29,24 +29,24 @@ APPLY_SPECIFIC(dnn_sptf_grid)(PyGpuArrayObject * theta, ...@@ -29,24 +29,24 @@ APPLY_SPECIFIC(dnn_sptf_grid)(PyGpuArrayObject * theta,
return 1; return 1;
} }
if ( PyArray_NDIM( grid_dims ) != 1 || PyArray_SIZE( grid_dims ) != 4 ) if ( PyArray_NDIM( out_dims ) != 1 || PyArray_SIZE( out_dims ) != 4 )
{ {
PyErr_SetString( PyExc_MemoryError, PyErr_SetString( PyExc_MemoryError,
"GpuDnnTransformerGrid: grid_dims must have 4 elements." ); "GpuDnnTransformerGrid: out_dims must have 4 elements." );
return 1; return 1;
} }
// Obtain grid dimensions // Obtain output dimensions
num_images = (int) *( (npy_int64 *) PyArray_GETPTR1( grid_dims, 0 ) ); num_images = (int) *( (npy_int64 *) PyArray_GETPTR1( out_dims, 0 ) );
height = (int) *( (npy_int64 *) PyArray_GETPTR1( grid_dims, 2 ) ); height = (int) *( (npy_int64 *) PyArray_GETPTR1( out_dims, 2 ) );
width = (int) *( (npy_int64 *) PyArray_GETPTR1( grid_dims, 3 ) ); width = (int) *( (npy_int64 *) PyArray_GETPTR1( out_dims, 3 ) );
// Set grid dimensions
grid_dims[0] = num_images;
grid_dims[1] = height;
grid_dims[2] = width;
grid_dims[3] = 2;
gpu_grid_dims[0] = num_images; if ( theano_prep_output( grid, 4, grid_dims, theta->ga.typecode,
gpu_grid_dims[1] = height;
gpu_grid_dims[2] = width;
gpu_grid_dims[3] = 2;
if ( theano_prep_output( grid, 4, gpu_grid_dims, theta->ga.typecode,
GA_C_ORDER, gpu_ctx ) != 0 ) GA_C_ORDER, gpu_ctx ) != 0 )
{ {
PyErr_SetString( PyExc_RuntimeError, PyErr_SetString( PyExc_RuntimeError,
...@@ -76,4 +76,5 @@ APPLY_SPECIFIC(dnn_sptf_grid)(PyGpuArrayObject * theta, ...@@ -76,4 +76,5 @@ APPLY_SPECIFIC(dnn_sptf_grid)(PyGpuArrayObject * theta,
} }
return 0; return 0;
} }
\ No newline at end of file
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论