提交 77bce880，作者：Arnaud Bergeron

Add support for float16/float64 to Corr3dMM.

上级 0c2eb3f0
...@@ -496,6 +496,7 @@ class BaseGpuCorrMM(CGpuKernelBase): ...@@ -496,6 +496,7 @@ class BaseGpuCorrMM(CGpuKernelBase):
return [os.path.dirname(__file__)] return [os.path.dirname(__file__)]
def c_code_cache_version(self): def c_code_cache_version(self):
# Raise this whenever modifying the code below.
return (2,) return (2,)
def c_code_helper(self, bottom, weights, top, direction, sub, height=None, width=None): def c_code_helper(self, bottom, weights, top, direction, sub, height=None, width=None):
...@@ -958,7 +959,7 @@ class GpuCorrMM_gradInputs(BaseGpuCorrMM): ...@@ -958,7 +959,7 @@ class GpuCorrMM_gradInputs(BaseGpuCorrMM):
return [[1], [1], [0], [0]] # no connection to height, width return [[1], [1], [0], [0]] # no connection to height, width
class BaseGpuCorr3dMM(CGpuKernelBase, BlasOp): class BaseGpuCorr3dMM(CGpuKernelBase):
""" """
Base class for `GpuCorr3dMM`, `GpuCorr3dMM_gradWeights` and Base class for `GpuCorr3dMM`, `GpuCorr3dMM_gradWeights` and
`GpuCorr3dMM_gradInputs`. Cannot be used directly. `GpuCorr3dMM_gradInputs`. Cannot be used directly.
...@@ -972,10 +973,11 @@ class BaseGpuCorr3dMM(CGpuKernelBase, BlasOp): ...@@ -972,10 +973,11 @@ class BaseGpuCorr3dMM(CGpuKernelBase, BlasOp):
Perform subsampling of the output (default: (1, 1, 1)). Perform subsampling of the output (default: (1, 1, 1)).
filter_dilation filter_dilation
Perform subsampling of the input, also known as dilation (default: (1, 1, 1)). Perform subsampling of the input, also known as dilation (default: (1, 1, 1)).
"""
"""
check_broadcast = False check_broadcast = False
__props__ = ('border_mode', 'subsample', 'filter_dilation') __props__ = ('border_mode', 'subsample', 'filter_dilation')
_f16_ok = True
def __init__(self, border_mode="valid", subsample=(1, 1, 1), def __init__(self, border_mode="valid", subsample=(1, 1, 1),
filter_dilation=(1, 1, 1)): filter_dilation=(1, 1, 1)):
...@@ -1033,9 +1035,15 @@ class BaseGpuCorr3dMM(CGpuKernelBase, BlasOp): ...@@ -1033,9 +1035,15 @@ class BaseGpuCorr3dMM(CGpuKernelBase, BlasOp):
def get_params(self, node): def get_params(self, node):
return node.inputs[0].type.context return node.inputs[0].type.context
# Return the list of header names for this op: the two angle-bracket gpuarray
# headers (array.h and blas.h) and the quoted local header gpuarray_helper.h.
# NOTE(review): presumably consumed by the C code generator to emit #include
# lines -- confirm against the base class contract.
def c_headers(self):
return ["<gpuarray/array.h>", "<gpuarray/blas.h>", "gpuarray_helper.h"]
# Return a one-element list holding the directory that contains this Python
# module (os.path.dirname(__file__)).
# NOTE(review): presumably used as an include search path so that headers
# shipped next to this module can be found -- confirm against the base class.
def c_header_dirs(self):
return [os.path.dirname(__file__)]
def c_code_cache_version(self): def c_code_cache_version(self):
# raise this whenever modifying any of the support_code_files # raise this whenever modifying the code below.
return (0, 2) return (2,)
def c_code_helper(self, bottom, weights, top, direction, sub, def c_code_helper(self, bottom, weights, top, direction, sub,
height=None, width=None, depth=None): height=None, width=None, depth=None):
......
...@@ -236,11 +236,9 @@ KERNEL void col2im3d_kernel(const ga_size n, ...@@ -236,11 +236,9 @@ KERNEL void col2im3d_kernel(const ga_size n,
} }
} }
#section support_code_struct #section support_code_struct
int im3d2col(const size_t max_threads_dim, int im3d2col(
gpudata * data_im, const size_t data_im_offset, const size_t channels, gpudata * data_im, const size_t data_im_offset, const size_t channels,
const size_t height, const size_t width, const size_t depth, const size_t height, const size_t width, const size_t depth,
const size_t kernel_h, const size_t kernel_w, const size_t kernel_d, const size_t kernel_h, const size_t kernel_w, const size_t kernel_d,
...@@ -257,13 +255,10 @@ int im3d2col(const size_t max_threads_dim, ...@@ -257,13 +255,10 @@ int im3d2col(const size_t max_threads_dim,
size_t width_col = (width + 2 * pad_w - dil_kernel_w) / stride_w + 1; size_t width_col = (width + 2 * pad_w - dil_kernel_w) / stride_w + 1;
size_t depth_col = (depth + 2 * pad_d - dil_kernel_d) / stride_d + 1; size_t depth_col = (depth + 2 * pad_d - dil_kernel_d) / stride_d + 1;
size_t num_kernels = channels * height_col * width_col * depth_col; size_t num_kernels = channels * height_col * width_col * depth_col;
size_t threads_per_block = max_threads_dim;
size_t n_blocks = (num_kernels + threads_per_block - 1) / threads_per_block;
int err; int err;
GpuKernel *kernel; if (dilation_h != 1 || dilation_w != 1 || dilation_d != 1) {
if(dilation_h != 1 || dilation_w != 1 || dilation_d != 1){ err = dilated_im3d2col_kernel_scall(
err = dilated_im3d2col_kernel_call( 1, &num_kernels, 0,
1, &n_blocks, &threads_per_block, 0,
num_kernels, data_im, data_im_offset, height, width, depth, num_kernels, data_im, data_im_offset, height, width, depth,
kernel_h, kernel_w, kernel_d, dilation_h, dilation_w, dilation_d, kernel_h, kernel_w, kernel_d, dilation_h, dilation_w, dilation_d,
pad_h, pad_w, pad_d, stride_h, stride_w, stride_d, height_col, pad_h, pad_w, pad_d, stride_h, stride_w, stride_d, height_col,
...@@ -273,10 +268,9 @@ int im3d2col(const size_t max_threads_dim, ...@@ -273,10 +268,9 @@ int im3d2col(const size_t max_threads_dim,
"gpuarray error: dilated_im3d2col_kernel: %s.", "gpuarray error: dilated_im3d2col_kernel: %s.",
GpuKernel_error(&k_dilated_im3d2col_kernel, err)); GpuKernel_error(&k_dilated_im3d2col_kernel, err));
} }
} } else {
else{ err = im3d2col_kernel_scall(
err = im3d2col_kernel_call( 1, &num_kernels, 0,
1, &n_blocks, &threads_per_block, 0,
num_kernels, data_im, data_im_offset, height, width, depth, num_kernels, data_im, data_im_offset, height, width, depth,
kernel_h, kernel_w, kernel_d, pad_h, pad_w, pad_d, kernel_h, kernel_w, kernel_d, pad_h, pad_w, pad_d,
stride_h, stride_w, stride_d, height_col, width_col, depth_col, stride_h, stride_w, stride_d, height_col, width_col, depth_col,
...@@ -290,7 +284,7 @@ int im3d2col(const size_t max_threads_dim, ...@@ -290,7 +284,7 @@ int im3d2col(const size_t max_threads_dim,
return err; return err;
} }
int col2im3d(const size_t max_threads_dim, gpudata * data_col, const size_t channels, int col2im3d(gpudata * data_col, const size_t channels,
const size_t height, const size_t width, const size_t depth, const size_t height, const size_t width, const size_t depth,
const size_t patch_h, const size_t patch_w, const size_t patch_d, const size_t patch_h, const size_t patch_w, const size_t patch_d,
const size_t dilation_h, const size_t dilation_w, const size_t dilation_d, const size_t dilation_h, const size_t dilation_w, const size_t dilation_d,
...@@ -304,14 +298,12 @@ int col2im3d(const size_t max_threads_dim, gpudata * data_col, const size_t chan ...@@ -304,14 +298,12 @@ int col2im3d(const size_t max_threads_dim, gpudata * data_col, const size_t chan
size_t width_col = (width + 2 * pad_w - dil_patch_w) / stride_w + 1; size_t width_col = (width + 2 * pad_w - dil_patch_w) / stride_w + 1;
size_t depth_col = (depth + 2 * pad_d - dil_patch_d) / stride_d + 1; size_t depth_col = (depth + 2 * pad_d - dil_patch_d) / stride_d + 1;
size_t num_kernels = channels * height * width * depth; size_t num_kernels = channels * height * width * depth;
size_t threads_per_block = max_threads_dim;
size_t n_blocks = (num_kernels + threads_per_block - 1) / threads_per_block;
// To avoid involving atomic operations, we will launch one kernel per // To avoid involving atomic operations, we will launch one kernel per
// bottom dimension, and then in the kernel add up the top dimensions. // bottom dimension, and then in the kernel add up the top dimensions.
int err; int err;
if(dilation_h != 1 || dilation_w != 1 || dilation_d != 1){ if (dilation_h != 1 || dilation_w != 1 || dilation_d != 1) {
err = dilated_col2im3d_kernel_call( err = dilated_col2im3d_kernel_scall(
1, &n_blocks, &threads_per_block, 0, 1, &num_kernels, 0,
num_kernels, data_col, height, width, depth, channels, patch_h, patch_w, num_kernels, data_col, height, width, depth, channels, patch_h, patch_w,
patch_d, dilation_h, dilation_w, dilation_d, pad_h, pad_w, pad_d, patch_d, dilation_h, dilation_w, dilation_d, pad_h, pad_w, pad_d,
stride_h, stride_w, stride_d, height_col, width_col, depth_col, stride_h, stride_w, stride_d, height_col, width_col, depth_col,
...@@ -323,8 +315,8 @@ int col2im3d(const size_t max_threads_dim, gpudata * data_col, const size_t chan ...@@ -323,8 +315,8 @@ int col2im3d(const size_t max_threads_dim, gpudata * data_col, const size_t chan
} }
} }
else{ else{
err = col2im3d_kernel_call( err = col2im3d_kernel_scall(
1, &n_blocks, &threads_per_block, 0, 1, &num_kernels, 0,
num_kernels, data_col, height, width, depth, channels, patch_h, patch_w, num_kernels, data_col, height, width, depth, channels, patch_h, patch_w,
patch_d, pad_h, pad_w, pad_d, stride_h, stride_w, stride_d, patch_d, pad_h, pad_w, pad_d, stride_h, stride_w, stride_d,
height_col, width_col, depth_col, data_im, data_im_offset); height_col, width_col, depth_col, data_im, data_im_offset);
...@@ -460,15 +452,6 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom, ...@@ -460,15 +452,6 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom,
return NULL; return NULL;
} }
// Get the max threads per blocks
size_t max_threads_dim;
err = gpucontext_property(bottom->context->ctx, GA_CTX_PROP_MAXLSIZE, &max_threads_dim);
if (err != GA_NO_ERROR){
PyErr_Format(PyExc_RuntimeError,
"Could not fetch max_threads_dim.");
return NULL;
}
// Create temporary columns // Create temporary columns
size_t col_dim[2]; size_t col_dim[2];
col_dim[0] = nChannels * kW * kH * kD; col_dim[0] = nChannels * kW * kH * kD;
...@@ -492,8 +475,6 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom, ...@@ -492,8 +475,6 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom,
const size_t K_ = col_dim[0]; const size_t K_ = col_dim[0];
const size_t N_ = col_dim[1]; const size_t N_ = col_dim[1];
const size_t M_ = nFilters; const size_t M_ = nFilters;
const DTYPE_INPUT_0 one = 1.0f;
const DTYPE_INPUT_0 zero = 0.0f;
PyGpuArrayObject *output; PyGpuArrayObject *output;
if (direction == 0) { // forward pass if (direction == 0) { // forward pass
...@@ -502,7 +483,7 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom, ...@@ -502,7 +483,7 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom,
// Iterate over batch // Iterate over batch
for (size_t n = 0; n < batchSize; n++) { for (size_t n = 0; n < batchSize; n++) {
// First, im3d2col // First, im3d2col
err = im3d2col(max_threads_dim, err = im3d2col(
bottom->ga.data, n * bottom_stride, nChannels, bottomHeight, bottom->ga.data, n * bottom_stride, nChannels, bottomHeight,
bottomWidth, bottomDepth, kH, kW, kD, dilH, dilW, dilD, bottomWidth, bottomDepth, kH, kW, kD, dilH, dilW, dilD,
padH, padW, padD, dH, dW, dD, col->ga.data); padH, padW, padD, dH, dW, dD, col->ga.data);
...@@ -511,15 +492,37 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom, ...@@ -511,15 +492,37 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom,
return NULL; return NULL;
} }
// Second, gemm // Second, gemm
switch (col->ga.typecode) {
case GA_FLOAT:
err = gpublas_sgemm(cb_fortran, cb_no_trans, cb_no_trans, err = gpublas_sgemm(cb_fortran, cb_no_trans, cb_no_trans,
N_, M_, K_, one, N_, M_, K_, 1,
col->ga.data, 0, N_, col->ga.data, 0, N_,
weight->ga.data, 0, K_, weight->ga.data, 0, K_,
zero, 0,
top->ga.data, n * top_stride, N_); top->ga.data, n * top_stride, N_);
break;
case GA_DOUBLE:
err = gpublas_dgemm(cb_fortran, cb_no_trans, cb_no_trans,
N_, M_, K_, 1,
col->ga.data, 0, N_,
weight->ga.data, 0, K_,
0,
top->ga.data, n * top_stride, N_);
break;
case GA_HALF:
err = gpublas_hgemm(cb_fortran, cb_no_trans, cb_no_trans,
N_, M_, K_, 1,
col->ga.data, 0, N_,
weight->ga.data, 0, K_,
0,
top->ga.data, n * top_stride, N_);
break;
default:
err = GA_UNSUPPORTED_ERROR;
}
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
"GpuCorr3dMM encountered an error running sgemm.\n"); "(0) GpuCorr3dMM encountered an error running gemm.");
Py_DECREF(col); Py_DECREF(col);
return NULL; return NULL;
} }
...@@ -531,7 +534,7 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom, ...@@ -531,7 +534,7 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom,
// Iterate over batch // Iterate over batch
for (size_t n = 0; n < batchSize; n++) { for (size_t n = 0; n < batchSize; n++) {
// First, im3d2col // First, im3d2col
err = im3d2col(max_threads_dim, err = im3d2col(
bottom->ga.data, n * bottom_stride, nChannels, bottomHeight, bottom->ga.data, n * bottom_stride, nChannels, bottomHeight,
bottomWidth, bottomDepth, kH, kW, kD, dilH, dilW, dilD, bottomWidth, bottomDepth, kH, kW, kD, dilH, dilW, dilD,
padH, padW, padD, dH, dW, dD, col->ga.data); padH, padW, padD, dH, dW, dD, col->ga.data);
...@@ -543,15 +546,37 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom, ...@@ -543,15 +546,37 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom,
// Note that we accumulate into weight. We do so by setting beta = 0 // Note that we accumulate into weight. We do so by setting beta = 0
// for the first iteration and beta = 1 for subsequent ones. (This // for the first iteration and beta = 1 for subsequent ones. (This
// is faster than setting weight to all zeros before the loop.) // is faster than setting weight to all zeros before the loop.)
switch (col->ga.typecode) {
case GA_FLOAT:
err = gpublas_sgemm(cb_fortran, cb_trans, cb_no_trans, err = gpublas_sgemm(cb_fortran, cb_trans, cb_no_trans,
K_, M_, N_, one, K_, M_, N_, 1,
col->ga.data, 0, N_, col->ga.data, 0, N_,
top->ga.data, n * top_stride, N_, top->ga.data, n * top_stride, N_,
(n == 0) ? zero : one, (n == 0) ? 0 : 1,
weight->ga.data, 0, K_); weight->ga.data, 0, K_);
break;
case GA_DOUBLE:
err = gpublas_dgemm(cb_fortran, cb_trans, cb_no_trans,
K_, M_, N_, 1,
col->ga.data, 0, N_,
top->ga.data, n * top_stride, N_,
(n == 0) ? 0 : 1,
weight->ga.data, 0, K_);
break;
case GA_HALF:
err = gpublas_hgemm(cb_fortran, cb_trans, cb_no_trans,
K_, M_, N_, 1,
col->ga.data, 0, N_,
top->ga.data, n * top_stride, N_,
(n == 0) ? 0 : 1,
weight->ga.data, 0, K_);
break;
default:
err = GA_UNSUPPORTED_ERROR;
}
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
"GpuCorr3dMM encountered an error running sgemm.\n"); "(1) GpuCorr3dMM encountered an error running gemm.");
Py_DECREF(col); Py_DECREF(col);
return NULL; return NULL;
} }
...@@ -563,21 +588,42 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom, ...@@ -563,21 +588,42 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom,
// Iterate over batch // Iterate over batch
for (size_t n = 0; n < batchSize; n++) { for (size_t n = 0; n < batchSize; n++) {
// gemm into columns // gemm into columns
switch (top->ga.typecode) {
case GA_FLOAT:
err = gpublas_sgemm(cb_fortran, cb_no_trans, cb_trans, err = gpublas_sgemm(cb_fortran, cb_no_trans, cb_trans,
N_, K_, M_, one, N_, K_, M_, 1,
top->ga.data, n * top_stride, N_, top->ga.data, n * top_stride, N_,
weight->ga.data, 0, K_, weight->ga.data, 0, K_,
zero, 0,
col->ga.data, 0, N_); col->ga.data, 0, N_);
break;
case GA_DOUBLE:
err = gpublas_dgemm(cb_fortran, cb_no_trans, cb_trans,
N_, K_, M_, 1,
top->ga.data, n * top_stride, N_,
weight->ga.data, 0, K_,
0,
col->ga.data, 0, N_);
break;
case GA_HALF:
err = gpublas_hgemm(cb_fortran, cb_no_trans, cb_trans,
N_, K_, M_, 1,
top->ga.data, n * top_stride, N_,
weight->ga.data, 0, K_,
0,
col->ga.data, 0, N_);
break;
default:
err = GA_UNSUPPORTED_ERROR;
}
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
"GpuCorr3dMM encountered an error running sgemm.\n"); "(2) GpuCorr3dMM encountered an error running gemm.");
Py_DECREF(col); Py_DECREF(col);
return NULL; return NULL;
} }
// col2im3d back to the data // col2im3d back to the data
err = col2im3d(max_threads_dim, err = col2im3d(col->ga.data, nChannels,
col->ga.data, nChannels,
bottomHeight, bottomWidth, bottomDepth, bottomHeight, bottomWidth, bottomDepth,
kH, kW, kD, dilH, dilW, dilD, padH, padW, padD, kH, kW, kD, dilH, dilW, dilD, padH, padW, padD,
dH, dW, dD, bottom->ga.data, n * bottom_stride); dH, dW, dD, bottom->ga.data, n * bottom_stride);
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论