提交 e3e0e6d4 作者: Arnaud Bergeron

Add support for float64 and float16.

上级 d9dfffd4
...@@ -414,7 +414,7 @@ gpugemmbatch_no_inplace = GpuGemmBatch(inplace=False) ...@@ -414,7 +414,7 @@ gpugemmbatch_no_inplace = GpuGemmBatch(inplace=False)
gpugemmbatch_inplace = GpuGemmBatch(inplace=True) gpugemmbatch_inplace = GpuGemmBatch(inplace=True)
class BaseGpuCorrMM(CGpuKernelBase, BlasOp): class BaseGpuCorrMM(CGpuKernelBase):
""" """
Base class for `GpuCorrMM`, `GpuCorrMM_gradWeights` and Base class for `GpuCorrMM`, `GpuCorrMM_gradWeights` and
`GpuCorrMM_gradInputs`. Cannot be used directly. `GpuCorrMM_gradInputs`. Cannot be used directly.
...@@ -429,9 +429,9 @@ class BaseGpuCorrMM(CGpuKernelBase, BlasOp): ...@@ -429,9 +429,9 @@ class BaseGpuCorrMM(CGpuKernelBase, BlasOp):
filter_dilation filter_dilation
Perform subsampling of the input, also known as dilation (default: (1, 1)). Perform subsampling of the input, also known as dilation (default: (1, 1)).
""" """
check_broadcast = False check_broadcast = False
__props__ = ('border_mode', 'subsample', 'filter_dilation') __props__ = ('border_mode', 'subsample', 'filter_dilation')
_f16_ok = True
def __init__(self, border_mode="valid", subsample=(1, 1), def __init__(self, border_mode="valid", subsample=(1, 1),
filter_dilation=(1, 1)): filter_dilation=(1, 1)):
...@@ -489,9 +489,14 @@ class BaseGpuCorrMM(CGpuKernelBase, BlasOp): ...@@ -489,9 +489,14 @@ class BaseGpuCorrMM(CGpuKernelBase, BlasOp):
def get_params(self, node): def get_params(self, node):
return node.inputs[0].type.context return node.inputs[0].type.context
def c_headers(self):
return ["<gpuarray/array.h>", "<gpuarray/blas.h>", "gpuarray_helper.h"]
def c_header_dirs(self):
return [os.path.dirname(__file__)]
def c_code_cache_version(self): def c_code_cache_version(self):
# raise this whenever modifying any of the support_code_files return (2,)
return (0, 2)
def c_code_helper(self, bottom, weights, top, direction, sub, height=None, width=None): def c_code_helper(self, bottom, weights, top, direction, sub, height=None, width=None):
""" """
......
...@@ -407,8 +407,6 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom, ...@@ -407,8 +407,6 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom,
const size_t K_ = col_dim[0]; const size_t K_ = col_dim[0];
const size_t N_ = col_dim[1]; const size_t N_ = col_dim[1];
const size_t M_ = nFilters; const size_t M_ = nFilters;
const DTYPE_INPUT_0 one = 1.0f;
const DTYPE_INPUT_0 zero = 0.0f;
PyGpuArrayObject *output; PyGpuArrayObject *output;
if (direction == 0) { // forward pass if (direction == 0) { // forward pass
...@@ -426,15 +424,37 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom, ...@@ -426,15 +424,37 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom,
return NULL; return NULL;
} }
// Second, gemm // Second, gemm
switch (col->ga.typecode) {
case GA_FLOAT:
err = gpublas_sgemm(cb_fortran, cb_no_trans, cb_no_trans, err = gpublas_sgemm(cb_fortran, cb_no_trans, cb_no_trans,
N_, M_, K_, one, N_, M_, K_, 1,
col->ga.data, 0, N_, col->ga.data, 0, N_,
weight->ga.data, 0, K_, weight->ga.data, 0, K_,
zero, 0,
top->ga.data, n * top_stride, N_); top->ga.data, n * top_stride, N_);
break;
case GA_DOUBLE:
err = gpublas_dgemm(cb_fortran, cb_no_trans, cb_no_trans,
N_, M_, K_, 1,
col->ga.data, 0, N_,
weight->ga.data, 0, K_,
0,
top->ga.data, n * top_stride, N_);
break;
case GA_HALF:
err = gpublas_hgemm(cb_fortran, cb_no_trans, cb_no_trans,
N_, M_, K_, 1,
col->ga.data, 0, N_,
weight->ga.data, 0, K_,
0,
top->ga.data, n * top_stride, N_);
break;
default:
err = GA_UNSUPPORTED_ERROR;
}
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
"GpuCorrMM encountered an error running sgemm.\n"); "(0) GpuCorrMM encountered an error running gemm: %d", err);
Py_DECREF(col); Py_DECREF(col);
return NULL; return NULL;
} }
...@@ -458,15 +478,37 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom, ...@@ -458,15 +478,37 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom,
// Note that we accumulate into weight. We do so by setting beta = 0 // Note that we accumulate into weight. We do so by setting beta = 0
// for the first iteration and beta = 1 for subsequent ones. (This // for the first iteration and beta = 1 for subsequent ones. (This
// is faster than setting weight to all zeros before the loop.) // is faster than setting weight to all zeros before the loop.)
switch (col->ga.typecode) {
case GA_FLOAT:
err = gpublas_sgemm(cb_fortran, cb_trans, cb_no_trans, err = gpublas_sgemm(cb_fortran, cb_trans, cb_no_trans,
K_, M_, N_, one, K_, M_, N_, 1,
col->ga.data, 0, N_,
top->ga.data, n * top_stride, N_,
(n == 0) ? 0 : 1,
weight->ga.data, 0, K_);
break;
case GA_DOUBLE:
err = gpublas_dgemm(cb_fortran, cb_trans, cb_no_trans,
K_, M_, N_, 1,
col->ga.data, 0, N_,
top->ga.data, n * top_stride, N_,
(n == 0) ? 0 : 1,
weight->ga.data, 0, K_);
break;
case GA_HALF:
err = gpublas_hgemm(cb_fortran, cb_trans, cb_no_trans,
K_, M_, N_, 1,
col->ga.data, 0, N_, col->ga.data, 0, N_,
top->ga.data, n * top_stride, N_, top->ga.data, n * top_stride, N_,
(n == 0) ? zero : one, (n == 0) ? 0 : 1,
weight->ga.data, 0, K_); weight->ga.data, 0, K_);
break;
default:
err = GA_UNSUPPORTED_ERROR;
}
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
"GpuCorrMM encountered an error running sgemm.\n"); "(1) GpuCorrMM encountered an error running gemm: %d", err);
Py_DECREF(col); Py_DECREF(col);
return NULL; return NULL;
} }
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论