提交 9ce9aa3e authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Pass offsets for CorrMM3d.

上级 0641c6e3
#section kernels #section kernels
#kernel dilated_im3d2col_kernel : size, *, size, size, size, size, size, size, size, size, size, size, size, size, size, size, size, size, size, size, size, * : #kernel dilated_im3d2col_kernel : size, *, size, size, size, size, size, size, size, size, size, size, size, size, size, size, size, size, size, size, size, size, *, size :
// TODO check kernel flags // TODO check kernel flags
// This uses a lot of code from Caffe (http://caffe.berkeleyvision.org/); // This uses a lot of code from Caffe (http://caffe.berkeleyvision.org/);
// sources are clearly marked. Below we reproduce the original license of // sources are clearly marked. Below we reproduce the original license of
...@@ -35,6 +35,7 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ...@@ -35,6 +35,7 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// GPU kernel for the case of dilation // GPU kernel for the case of dilation
KERNEL void dilated_im3d2col_kernel(const ga_size n, KERNEL void dilated_im3d2col_kernel(const ga_size n,
GLOBAL_MEM const DTYPE_INPUT_0 * data_im, GLOBAL_MEM const DTYPE_INPUT_0 * data_im,
const ga_size offset_im,
const ga_size data_im_offset, const ga_size data_im_offset,
const ga_size height, const ga_size width, const ga_size depth, const ga_size height, const ga_size width, const ga_size depth,
const ga_size kernel_h, const ga_size kernel_w, const ga_size kernel_d, const ga_size kernel_h, const ga_size kernel_w, const ga_size kernel_d,
...@@ -42,7 +43,10 @@ KERNEL void dilated_im3d2col_kernel(const ga_size n, ...@@ -42,7 +43,10 @@ KERNEL void dilated_im3d2col_kernel(const ga_size n,
const ga_size pad_h, const ga_size pad_w, const ga_size pad_d, const ga_size pad_h, const ga_size pad_w, const ga_size pad_d,
const ga_size stride_h, const ga_size stride_w, const ga_size stride_d, const ga_size stride_h, const ga_size stride_w, const ga_size stride_d,
const ga_size height_col, const ga_size width_col, const ga_size depth_col, const ga_size height_col, const ga_size width_col, const ga_size depth_col,
GLOBAL_MEM DTYPE_INPUT_0 * data_col) { GLOBAL_MEM DTYPE_INPUT_0 * data_col,
const ga_size offset_col) {
data_im = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_im) + offset_im);
data_col = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_col) + offset_col);
// grid stride looping // grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0; for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < (n); index += LDIM_0 * GDIM_0) { index < (n); index += LDIM_0 * GDIM_0) {
...@@ -80,16 +84,20 @@ KERNEL void dilated_im3d2col_kernel(const ga_size n, ...@@ -80,16 +84,20 @@ KERNEL void dilated_im3d2col_kernel(const ga_size n,
} }
} }
#kernel im3d2col_kernel : size, *, size, size, size, size, size, size, size, size, size, size, size, size, size, size, size, size, * : #kernel im3d2col_kernel : size, *, size, size, size, size, size, size, size, size, size, size, size, size, size, size, size, size, size, *, size :
KERNEL void im3d2col_kernel(const ga_size n, KERNEL void im3d2col_kernel(const ga_size n,
GLOBAL_MEM const DTYPE_INPUT_0 * data_im, GLOBAL_MEM const DTYPE_INPUT_0 * data_im,
const ga_size offset_im,
const ga_size data_im_offset, const ga_size data_im_offset,
const ga_size height, const ga_size width, const ga_size depth, const ga_size height, const ga_size width, const ga_size depth,
const ga_size kernel_h, const ga_size kernel_w, const ga_size kernel_d, const ga_size kernel_h, const ga_size kernel_w, const ga_size kernel_d,
const ga_size pad_h, const ga_size pad_w, const ga_size pad_d, const ga_size pad_h, const ga_size pad_w, const ga_size pad_d,
const ga_size stride_h, const ga_size stride_w, const ga_size stride_d, const ga_size stride_h, const ga_size stride_w, const ga_size stride_d,
const ga_size height_col, const ga_size width_col, const ga_size depth_col, const ga_size height_col, const ga_size width_col, const ga_size depth_col,
GLOBAL_MEM DTYPE_INPUT_0 * data_col) { GLOBAL_MEM DTYPE_INPUT_0 * data_col,
const ga_size data_im_offset) {
data_im = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_im) + offset_im);
data_col = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_col) + offset_col);
// grid stride looping // grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0; for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < (n); index += LDIM_0 * GDIM_0) { index < (n); index += LDIM_0 * GDIM_0) {
...@@ -126,9 +134,10 @@ KERNEL void im3d2col_kernel(const ga_size n, ...@@ -126,9 +134,10 @@ KERNEL void im3d2col_kernel(const ga_size n,
} }
// GPU kernel for the case of dilation // GPU kernel for the case of dilation
#kernel dilated_col2im3d_kernel : size, *, size, size, size, size, size, size, size, size, size, size, size, size, size, size, size, size, size, size, size, *, size : #kernel dilated_col2im3d_kernel : size, *, size, size, size, size, size, size, size, size, size, size, size, size, size, size, size, size, size, size, size, size, *, size, size :
KERNEL void dilated_col2im3d_kernel(const ga_size n, KERNEL void dilated_col2im3d_kernel(const ga_size n,
GLOBAL_MEM const DTYPE_INPUT_0 * data_col, GLOBAL_MEM const DTYPE_INPUT_0 * data_col,
const ga_size offset_col,
const ga_size height, const ga_size width, const ga_size depth, const ga_size height, const ga_size width, const ga_size depth,
const ga_size channels, const ga_size channels,
const ga_size kernel_h, const ga_size kernel_w, const ga_size kernel_d, const ga_size kernel_h, const ga_size kernel_w, const ga_size kernel_d,
...@@ -137,7 +146,10 @@ KERNEL void dilated_col2im3d_kernel(const ga_size n, ...@@ -137,7 +146,10 @@ KERNEL void dilated_col2im3d_kernel(const ga_size n,
const ga_size stride_h, const ga_size stride_w, const ga_size stride_d, const ga_size stride_h, const ga_size stride_w, const ga_size stride_d,
const ga_size height_col, const ga_size width_col, const ga_size depth_col, const ga_size height_col, const ga_size width_col, const ga_size depth_col,
GLOBAL_MEM DTYPE_INPUT_0 * data_im, GLOBAL_MEM DTYPE_INPUT_0 * data_im,
const ga_size offset_im,
const ga_size data_im_offset) { const ga_size data_im_offset) {
data_im = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_im) + offset_im);
data_col = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_col) + offset_col);
// grid stride looping // grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0; for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < (n); index += LDIM_0 * GDIM_0) { index < (n); index += LDIM_0 * GDIM_0) {
...@@ -188,9 +200,10 @@ KERNEL void dilated_col2im3d_kernel(const ga_size n, ...@@ -188,9 +200,10 @@ KERNEL void dilated_col2im3d_kernel(const ga_size n,
} }
} }
#kernel col2im3d_kernel : size, *, size, size, size, size, size, size, size, size, size, size, size, size, size, size, size, size, *, size : #kernel col2im3d_kernel : size, *, size, size, size, size, size, size, size, size, size, size, size, size, size, size, size, size, size, *, size, size :
KERNEL void col2im3d_kernel(const ga_size n, KERNEL void col2im3d_kernel(const ga_size n,
GLOBAL_MEM const DTYPE_INPUT_0 * data_col, GLOBAL_MEM const DTYPE_INPUT_0 * data_col,
const ga_size offset_col,
const ga_size height, const ga_size width, const ga_size depth, const ga_size height, const ga_size width, const ga_size depth,
const ga_size channels, const ga_size channels,
const ga_size kernel_h, const ga_size kernel_w, const ga_size kernel_d, const ga_size kernel_h, const ga_size kernel_w, const ga_size kernel_d,
...@@ -198,7 +211,10 @@ KERNEL void col2im3d_kernel(const ga_size n, ...@@ -198,7 +211,10 @@ KERNEL void col2im3d_kernel(const ga_size n,
const ga_size stride_h, const ga_size stride_w, const ga_size stride_d, const ga_size stride_h, const ga_size stride_w, const ga_size stride_d,
const ga_size height_col, const ga_size width_col, const ga_size depth_col, const ga_size height_col, const ga_size width_col, const ga_size depth_col,
GLOBAL_MEM DTYPE_INPUT_0 * data_im, GLOBAL_MEM DTYPE_INPUT_0 * data_im,
const ga_size offset_im,
const ga_size data_im_offset) { const ga_size data_im_offset) {
data_im = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_im) + offset_im);
data_col = (GLOBAL_MEM DTYPE_INPUT_0 *)(((char *)data_col) + offset_col);
// grid stride looping // grid stride looping
for (ga_size index = GID_0 * LDIM_0 + LID_0; for (ga_size index = GID_0 * LDIM_0 + LID_0;
index < (n); index += LDIM_0 * GDIM_0) { index < (n); index += LDIM_0 * GDIM_0) {
...@@ -239,13 +255,13 @@ KERNEL void col2im3d_kernel(const ga_size n, ...@@ -239,13 +255,13 @@ KERNEL void col2im3d_kernel(const ga_size n,
#section support_code_struct #section support_code_struct
int im3d2col( int im3d2col(
gpudata * data_im, const size_t data_im_offset, const size_t channels, GpuArray *data_im, const size_t data_im_offset, const size_t channels,
const size_t height, const size_t width, const size_t depth, const size_t height, const size_t width, const size_t depth,
const size_t kernel_h, const size_t kernel_w, const size_t kernel_d, const size_t kernel_h, const size_t kernel_w, const size_t kernel_d,
const size_t dilation_h, const size_t dilation_w, const size_t dilation_d, const size_t dilation_h, const size_t dilation_w, const size_t dilation_d,
const size_t pad_h, const size_t pad_w, const size_t pad_d, const size_t pad_h, const size_t pad_w, const size_t pad_d,
const size_t stride_h, const size_t stride_w, const size_t stride_d, const size_t stride_h, const size_t stride_w, const size_t stride_d,
gpudata * data_col) { GpuArray *data_col) {
// We are going to launch channels * height_col * width_col * depth_col // We are going to launch channels * height_col * width_col * depth_col
// kernels, each kernel responsible for copying a single-channel grid. // kernels, each kernel responsible for copying a single-channel grid.
size_t dil_kernel_h = (kernel_h - 1) * dilation_h + 1; size_t dil_kernel_h = (kernel_h - 1) * dilation_h + 1;
...@@ -259,10 +275,11 @@ int im3d2col( ...@@ -259,10 +275,11 @@ int im3d2col(
if (dilation_h != 1 || dilation_w != 1 || dilation_d != 1) { if (dilation_h != 1 || dilation_w != 1 || dilation_d != 1) {
err = dilated_im3d2col_kernel_scall( err = dilated_im3d2col_kernel_scall(
1, &num_kernels, 0, 1, &num_kernels, 0,
num_kernels, data_im, data_im_offset, height, width, depth, num_kernels, data_im->data, data_im->offset,
data_im_offset, height, width, depth,
kernel_h, kernel_w, kernel_d, dilation_h, dilation_w, dilation_d, kernel_h, kernel_w, kernel_d, dilation_h, dilation_w, dilation_d,
pad_h, pad_w, pad_d, stride_h, stride_w, stride_d, height_col, pad_h, pad_w, pad_d, stride_h, stride_w, stride_d, height_col,
width_col, depth_col, data_col); width_col, depth_col, data_col->data, data_col->offset);
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
"gpuarray error: dilated_im3d2col_kernel: %s.", "gpuarray error: dilated_im3d2col_kernel: %s.",
...@@ -271,10 +288,11 @@ int im3d2col( ...@@ -271,10 +288,11 @@ int im3d2col(
} else { } else {
err = im3d2col_kernel_scall( err = im3d2col_kernel_scall(
1, &num_kernels, 0, 1, &num_kernels, 0,
num_kernels, data_im, data_im_offset, height, width, depth, num_kernels, data_im->data, data_im->offset,
data_im_offset, height, width, depth,
kernel_h, kernel_w, kernel_d, pad_h, pad_w, pad_d, kernel_h, kernel_w, kernel_d, pad_h, pad_w, pad_d,
stride_h, stride_w, stride_d, height_col, width_col, depth_col, stride_h, stride_w, stride_d, height_col, width_col, depth_col,
data_col); data_col->data, data_col->offset);
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
"gpuarray error: im3d2col_kernel: %s.", "gpuarray error: im3d2col_kernel: %s.",
...@@ -284,13 +302,13 @@ int im3d2col( ...@@ -284,13 +302,13 @@ int im3d2col(
return err; return err;
} }
int col2im3d(gpudata * data_col, const size_t channels, int col2im3d(GpuArray *data_col, const size_t channels,
const size_t height, const size_t width, const size_t depth, const size_t height, const size_t width, const size_t depth,
const size_t patch_h, const size_t patch_w, const size_t patch_d, const size_t patch_h, const size_t patch_w, const size_t patch_d,
const size_t dilation_h, const size_t dilation_w, const size_t dilation_d, const size_t dilation_h, const size_t dilation_w, const size_t dilation_d,
const size_t pad_h, const size_t pad_w, const size_t pad_d, const size_t pad_h, const size_t pad_w, const size_t pad_d,
const size_t stride_h, const size_t stride_w, const size_t stride_d, const size_t stride_h, const size_t stride_w, const size_t stride_d,
gpudata * data_im, const size_t data_im_offset) { GpuArray *data_im, const size_t data_im_offset) {
size_t dil_patch_h = (patch_h - 1) * dilation_h + 1; size_t dil_patch_h = (patch_h - 1) * dilation_h + 1;
size_t dil_patch_w = (patch_w - 1) * dilation_w + 1; size_t dil_patch_w = (patch_w - 1) * dilation_w + 1;
size_t dil_patch_d = (patch_d - 1) * dilation_d + 1; size_t dil_patch_d = (patch_d - 1) * dilation_d + 1;
...@@ -304,10 +322,11 @@ int col2im3d(gpudata * data_col, const size_t channels, ...@@ -304,10 +322,11 @@ int col2im3d(gpudata * data_col, const size_t channels,
if (dilation_h != 1 || dilation_w != 1 || dilation_d != 1) { if (dilation_h != 1 || dilation_w != 1 || dilation_d != 1) {
err = dilated_col2im3d_kernel_scall( err = dilated_col2im3d_kernel_scall(
1, &num_kernels, 0, 1, &num_kernels, 0,
num_kernels, data_col, height, width, depth, channels, patch_h, patch_w, num_kernels, data_col->data, data_col->offset,
height, width, depth, channels, patch_h, patch_w,
patch_d, dilation_h, dilation_w, dilation_d, pad_h, pad_w, pad_d, patch_d, dilation_h, dilation_w, dilation_d, pad_h, pad_w, pad_d,
stride_h, stride_w, stride_d, height_col, width_col, depth_col, stride_h, stride_w, stride_d, height_col, width_col, depth_col,
data_im, data_im_offset); data_im->data, data_im->offset, data_im_offset);
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
"gpuarray error: dilated_col2im3d_kernel: %s.", "gpuarray error: dilated_col2im3d_kernel: %s.",
...@@ -317,9 +336,11 @@ int col2im3d(gpudata * data_col, const size_t channels, ...@@ -317,9 +336,11 @@ int col2im3d(gpudata * data_col, const size_t channels,
else{ else{
err = col2im3d_kernel_scall( err = col2im3d_kernel_scall(
1, &num_kernels, 0, 1, &num_kernels, 0,
num_kernels, data_col, height, width, depth, channels, patch_h, patch_w, num_kernels, data_col->data, data_col->offset,
height, width, depth, channels, patch_h, patch_w,
patch_d, pad_h, pad_w, pad_d, stride_h, stride_w, stride_d, patch_d, pad_h, pad_w, pad_d, stride_h, stride_w, stride_d,
height_col, width_col, depth_col, data_im, data_im_offset); height_col, width_col, depth_col,
data_im->data, data_im->offset, data_im_offset);
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
"gpuarray error: col2im3d_kernel: %s.", "gpuarray error: col2im3d_kernel: %s.",
...@@ -503,9 +524,9 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom, ...@@ -503,9 +524,9 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom,
for (size_t n = 0; n < batchSize; n++) { for (size_t n = 0; n < batchSize; n++) {
// First, im3d2col // First, im3d2col
err = im3d2col( err = im3d2col(
bottom->ga.data, n * bottom_stride, nChannels, bottomHeight, &bottom->ga, n * bottom_stride, nChannels, bottomHeight,
bottomWidth, bottomDepth, kH, kW, kD, dilH, dilW, dilD, bottomWidth, bottomDepth, kH, kW, kD, dilH, dilW, dilD,
padH, padW, padD, dH, dW, dD, col->ga.data); padH, padW, padD, dH, dW, dD, &col->ga);
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
Py_DECREF(col); Py_DECREF(col);
return NULL; return NULL;
...@@ -565,9 +586,9 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom, ...@@ -565,9 +586,9 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom,
for (size_t n = 0; n < batchSize; n++) { for (size_t n = 0; n < batchSize; n++) {
// First, im3d2col // First, im3d2col
err = im3d2col( err = im3d2col(
bottom->ga.data, n * bottom_stride, nChannels, bottomHeight, &bottom->ga, n * bottom_stride, nChannels, bottomHeight,
bottomWidth, bottomDepth, kH, kW, kD, dilH, dilW, dilD, bottomWidth, bottomDepth, kH, kW, kD, dilH, dilW, dilD,
padH, padW, padD, dH, dW, dD, col->ga.data); padH, padW, padD, dH, dW, dD, &col->ga);
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
Py_DECREF(col); Py_DECREF(col);
return NULL; return NULL;
...@@ -673,10 +694,10 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom, ...@@ -673,10 +694,10 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom,
return NULL; return NULL;
} }
// col2im3d back to the data // col2im3d back to the data
err = col2im3d(col->ga.data, nChannels, err = col2im3d(&col->ga, nChannels,
bottomHeight, bottomWidth, bottomDepth, bottomHeight, bottomWidth, bottomDepth,
kH, kW, kD, dilH, dilW, dilD, padH, padW, padD, kH, kW, kD, dilH, dilW, dilD, padH, padW, padD,
dH, dW, dD, bottom->ga.data, n * bottom_stride); dH, dW, dD, &bottom->ga, n * bottom_stride);
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
Py_DECREF(col); Py_DECREF(col);
return NULL; return NULL;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论