Add grid coordinates memory allocation to spatialtf_grid

上级 b8c83c58
...@@ -2,10 +2,11 @@ ...@@ -2,10 +2,11 @@
int spatialtf_grid(cudnnSpatialTransformerDescriptor_t desc, int spatialtf_grid(cudnnSpatialTransformerDescriptor_t desc,
PyGpuArrayObject * theta, PyGpuArrayObject * theta,
PyGpuArrayObject * num_dimensions, PyArrayObject * dimensions,
PyGpuArrayObject ** grid, PyGpuArrayObject ** grid,
cudnnHandle_t _handle) cudnnHandle_t _handle)
{ {
PyGpuContextObject * gpu_ctx = theta->context;
cudnnDataType_t dt; cudnnDataType_t dt;
cudnnStatus_t err; cudnnStatus_t err;
...@@ -26,12 +27,19 @@ int spatialtf_grid(cudnnSpatialTransformerDescriptor_t desc, ...@@ -26,12 +27,19 @@ int spatialtf_grid(cudnnSpatialTransformerDescriptor_t desc,
return -1; return -1;
} }
switch( num_dimensions->ga.typecode ) if ( NULL == *grid )
{ {
case GA_INT: npy_int * dimensions_data = (npy_int *)PyArray_DATA( dimensions );
break; const size_t grid_dims[4] = { dimensions_data[0], dimensions_data[1],
default: dimensions_data[2], dimensions_data[3] };
PyErr_SetString( PyExc_TypeError, "Unsupported data type for the number of dimensions" ); *grid = pygpu_empty( 4, &(grid_dims[0]), theta->ga.typecode, GA_C_ORDER,
gpu_ctx, Py_None );
if ( *grid == NULL )
{
PyErr_Format( PyExc_MemoryError,
"Could not allocate memory for grid coordinates" );
return 1;
}
} }
return 0; return 0;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论