Fix setup of input tensor in spatialtf_sampler

上级 a59c408d
...@@ -94,8 +94,25 @@ spatialtf_sampler(PyGpuArrayObject * input, ...@@ -94,8 +94,25 @@ spatialtf_sampler(PyGpuArrayObject * input,
return -1; return -1;
} }
// In the input tensor, we must use its width and height, instead
// of the grid's width and height. The number of images and channels
// should be the same as the grid dimensions
const int input_num_images = (int) PyGpuArray_DIM( input, 0 );
const int input_height = (int) PyGpuArray_DIM( input, 1 );
const int input_width = (int) PyGpuArray_DIM( input, 2 );
const int input_num_channels = (int) PyGpuArray_DIM( input, 3 );
if ( input_num_images != num_images ||
input_num_channels != num_channels )
{
PyErr_Format( PyExc_RuntimeError,
"Input should have %d images and %d channels, got %d images and %d channels.",
num_images, num_channels, input_num_images, input_num_channels );
return -1;
}
err = cudnnSetTensor4dDescriptor( spatialtf_ctx.xdesc, tf, dt, num_images, err = cudnnSetTensor4dDescriptor( spatialtf_ctx.xdesc, tf, dt, num_images,
num_channels, height, width ); num_channels, input_height, input_width );
if ( err != CUDNN_STATUS_SUCCESS ) if ( err != CUDNN_STATUS_SUCCESS )
{ {
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论