Check if output dimensions are correct in spatialtf_sampler

上级 9f8512d2
......@@ -152,9 +152,12 @@ spatialtf_sampler(PyGpuArrayObject * input,
return -1;
}
if ( NULL == *output )
const size_t out_dims[4] = { num_images, num_channels, height, width };
if ( NULL == *output ||
! theano_size_check( *output, 4, &(out_dims[0]), (*output)->ga.typecode ) )
{
const size_t out_dims[4] = { num_images, num_channels, height, width };
Py_XDECREF( *output );
*output = pygpu_zeros( 4, &(out_dims[0]), input->ga.typecode, GA_C_ORDER,
gpu_ctx, Py_None );
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论