提交 56333682 作者: Arnaud Bergeron

Fix offsets for blas.

上级 9ce9aa3e
...@@ -252,6 +252,40 @@ KERNEL void col2im3d_kernel(const ga_size n, ...@@ -252,6 +252,40 @@ KERNEL void col2im3d_kernel(const ga_size n,
} }
} }
#section support_code
/*
 * Type-dispatching GEMM wrapper over the gpublas_{s,d,h}gemm routines.
 *
 * The routine is chosen from A->typecode only (GA_FLOAT, GA_DOUBLE or
 * GA_HALF); the typecodes of B and C are not inspected here.
 *
 * For each array, the `offset` field is divided by the element width
 * (4, 8 or 2) and added to the caller-supplied offA/offB/offC before
 * being handed to the BLAS call.
 * NOTE(review): this presumably converts a byte offset into an element
 * offset, so offA/offB/offC would be expressed in elements -- confirm.
 *
 * Returns whatever the selected gpublas call returns, or
 * GA_UNSUPPORTED_ERROR when A->typecode is none of the three handled types.
 */
int rgemm(cb_order o, cb_transpose tA, cb_transpose tB,
          size_t M, size_t N, size_t K, double alpha,
          GpuArray *A, size_t offA, size_t lda,
          GpuArray *B, size_t offB, size_t ldb,
          double beta, GpuArray *C, size_t offC, size_t ldc) {
  size_t width;  /* divisor applied to each array's `offset` field */

  /* Pick the element width; bail out early on unsupported types. */
  switch (A->typecode) {
  case GA_FLOAT:
    width = 4;
    break;
  case GA_DOUBLE:
    width = 8;
    break;
  case GA_HALF:
    width = 2;
    break;
  default:
    return GA_UNSUPPORTED_ERROR;
  }

  /* Rescaled array offset plus the caller-supplied extra offset. */
  const size_t startA = (A->offset / width) + offA;
  const size_t startB = (B->offset / width) + offB;
  const size_t startC = (C->offset / width) + offC;

  if (A->typecode == GA_FLOAT) {
    return gpublas_sgemm(o, tA, tB, M, N, K, alpha,
                         A->data, startA, lda,
                         B->data, startB, ldb,
                         beta,
                         C->data, startC, ldc);
  }
  if (A->typecode == GA_DOUBLE) {
    return gpublas_dgemm(o, tA, tB, M, N, K, alpha,
                         A->data, startA, lda,
                         B->data, startB, ldb,
                         beta,
                         C->data, startC, ldc);
  }
  /* Only GA_HALF can reach this point. */
  return gpublas_hgemm(o, tA, tB, M, N, K, alpha,
                       A->data, startA, lda,
                       B->data, startB, ldb,
                       beta,
                       C->data, startC, ldc);
}
#section support_code_struct #section support_code_struct
int im3d2col( int im3d2col(
...@@ -532,34 +566,12 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom, ...@@ -532,34 +566,12 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom,
return NULL; return NULL;
} }
// Second, gemm // Second, gemm
switch (col->ga.typecode) { err = rgemm(cb_fortran, cb_no_trans, cb_no_trans,
case GA_FLOAT: N_, M_, K_, 1,
err = gpublas_sgemm(cb_fortran, cb_no_trans, cb_no_trans, &col->ga, 0, N_,
N_, M_, K_, 1, &weight->ga, 0, K_,
col->ga.data, 0, N_, 0,
weight->ga.data, 0, K_, &top->ga, n * top_stride, N_);
0,
top->ga.data, n * top_stride, N_);
break;
case GA_DOUBLE:
err = gpublas_dgemm(cb_fortran, cb_no_trans, cb_no_trans,
N_, M_, K_, 1,
col->ga.data, 0, N_,
weight->ga.data, 0, K_,
0,
top->ga.data, n * top_stride, N_);
break;
case GA_HALF:
err = gpublas_hgemm(cb_fortran, cb_no_trans, cb_no_trans,
N_, M_, K_, 1,
col->ga.data, 0, N_,
weight->ga.data, 0, K_,
0,
top->ga.data, n * top_stride, N_);
break;
default:
err = GA_UNSUPPORTED_ERROR;
}
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
"GpuCorr3dMM forward encountered an error running gemm."); "GpuCorr3dMM forward encountered an error running gemm.");
...@@ -597,34 +609,12 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom, ...@@ -597,34 +609,12 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom,
// Note that we accumulate into weight. We do so by setting beta = 0 // Note that we accumulate into weight. We do so by setting beta = 0
// for the first iteration and beta = 1 for subsequent ones. (This // for the first iteration and beta = 1 for subsequent ones. (This
// is faster than setting weight to all zeros before the loop.) // is faster than setting weight to all zeros before the loop.)
switch (col->ga.typecode) { err = rgemm(cb_fortran, cb_trans, cb_no_trans,
case GA_FLOAT: K_, M_, N_, 1,
err = gpublas_sgemm(cb_fortran, cb_trans, cb_no_trans, &col->ga, 0, N_,
K_, M_, N_, 1, &top->ga, n * top_stride, N_,
col->ga.data, 0, N_, (n == 0) ? 0 : 1,
top->ga.data, n * top_stride, N_, &weight->ga, 0, K_);
(n == 0) ? 0 : 1,
weight->ga.data, 0, K_);
break;
case GA_DOUBLE:
err = gpublas_dgemm(cb_fortran, cb_trans, cb_no_trans,
K_, M_, N_, 1,
col->ga.data, 0, N_,
top->ga.data, n * top_stride, N_,
(n == 0) ? 0 : 1,
weight->ga.data, 0, K_);
break;
case GA_HALF:
err = gpublas_hgemm(cb_fortran, cb_trans, cb_no_trans,
K_, M_, N_, 1,
col->ga.data, 0, N_,
top->ga.data, n * top_stride, N_,
(n == 0) ? 0 : 1,
weight->ga.data, 0, K_);
break;
default:
err = GA_UNSUPPORTED_ERROR;
}
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
"GpuCorr3dMM grad weights encountered an error running gemm."); "GpuCorr3dMM grad weights encountered an error running gemm.");
...@@ -659,34 +649,12 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom, ...@@ -659,34 +649,12 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom,
// Iterate over batch // Iterate over batch
for (size_t n = 0; n < batchSize; n++) { for (size_t n = 0; n < batchSize; n++) {
// gemm into columns // gemm into columns
switch (top->ga.typecode) { err = rgemm(cb_fortran, cb_no_trans, cb_trans,
case GA_FLOAT: N_, K_, M_, 1,
err = gpublas_sgemm(cb_fortran, cb_no_trans, cb_trans, &top->ga, n * top_stride, N_,
N_, K_, M_, 1, &weight->ga, 0, K_,
top->ga.data, n * top_stride, N_, 0,
weight->ga.data, 0, K_, &col->ga, 0, N_);
0,
col->ga.data, 0, N_);
break;
case GA_DOUBLE:
err = gpublas_dgemm(cb_fortran, cb_no_trans, cb_trans,
N_, K_, M_, 1,
top->ga.data, n * top_stride, N_,
weight->ga.data, 0, K_,
0,
col->ga.data, 0, N_);
break;
case GA_HALF:
err = gpublas_hgemm(cb_fortran, cb_no_trans, cb_trans,
N_, K_, M_, 1,
top->ga.data, n * top_stride, N_,
weight->ga.data, 0, K_,
0,
col->ga.data, 0, N_);
break;
default:
err = GA_UNSUPPORTED_ERROR;
}
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
"GpuCorr3dMM grad inputs encountered an error running gemm."); "GpuCorr3dMM grad inputs encountered an error running gemm.");
......
...@@ -205,7 +205,39 @@ KERNEL void col2im_kernel(const ga_size n, ...@@ -205,7 +205,39 @@ KERNEL void col2im_kernel(const ga_size n,
} }
} }
#section support_code
/*
 * Type-dispatching GEMM wrapper over the gpublas_{s,d,h}gemm routines.
 *
 * The routine is chosen from A->typecode only (GA_FLOAT, GA_DOUBLE or
 * GA_HALF); the typecodes of B and C are not inspected here.
 *
 * For each array, the `offset` field is divided by the element width
 * (4, 8 or 2) and added to the caller-supplied offA/offB/offC before
 * being handed to the BLAS call.
 * NOTE(review): this presumably converts a byte offset into an element
 * offset, so offA/offB/offC would be expressed in elements -- confirm.
 *
 * Returns whatever the selected gpublas call returns, or
 * GA_UNSUPPORTED_ERROR when A->typecode is none of the three handled types.
 */
int rgemm(cb_order o, cb_transpose tA, cb_transpose tB,
          size_t M, size_t N, size_t K, double alpha,
          GpuArray *A, size_t offA, size_t lda,
          GpuArray *B, size_t offB, size_t ldb,
          double beta, GpuArray *C, size_t offC, size_t ldc) {
  size_t width;  /* divisor applied to each array's `offset` field */

  /* Pick the element width; bail out early on unsupported types. */
  switch (A->typecode) {
  case GA_FLOAT:
    width = 4;
    break;
  case GA_DOUBLE:
    width = 8;
    break;
  case GA_HALF:
    width = 2;
    break;
  default:
    return GA_UNSUPPORTED_ERROR;
  }

  /* Rescaled array offset plus the caller-supplied extra offset. */
  const size_t startA = (A->offset / width) + offA;
  const size_t startB = (B->offset / width) + offB;
  const size_t startC = (C->offset / width) + offC;

  if (A->typecode == GA_FLOAT) {
    return gpublas_sgemm(o, tA, tB, M, N, K, alpha,
                         A->data, startA, lda,
                         B->data, startB, ldb,
                         beta,
                         C->data, startC, ldc);
  }
  if (A->typecode == GA_DOUBLE) {
    return gpublas_dgemm(o, tA, tB, M, N, K, alpha,
                         A->data, startA, lda,
                         B->data, startB, ldb,
                         beta,
                         C->data, startC, ldc);
  }
  /* Only GA_HALF can reach this point. */
  return gpublas_hgemm(o, tA, tB, M, N, K, alpha,
                       A->data, startA, lda,
                       B->data, startB, ldb,
                       beta,
                       C->data, startC, ldc);
}
#section support_code_struct #section support_code_struct
...@@ -460,34 +492,12 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom, ...@@ -460,34 +492,12 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom,
return NULL; return NULL;
} }
// Second, gemm // Second, gemm
switch (col->ga.typecode) { err = rgemm(cb_fortran, cb_no_trans, cb_no_trans,
case GA_FLOAT: N_, M_, K_, 1,
err = gpublas_sgemm(cb_fortran, cb_no_trans, cb_no_trans, &col->ga, 0, N_,
N_, M_, K_, 1, &weight->ga, 0, K_,
col->ga.data, 0, N_, 0,
weight->ga.data, 0, K_, &top->ga, n * top_stride, N_);
0,
top->ga.data, n * top_stride, N_);
break;
case GA_DOUBLE:
err = gpublas_dgemm(cb_fortran, cb_no_trans, cb_no_trans,
N_, M_, K_, 1,
col->ga.data, 0, N_,
weight->ga.data, 0, K_,
0,
top->ga.data, n * top_stride, N_);
break;
case GA_HALF:
err = gpublas_hgemm(cb_fortran, cb_no_trans, cb_no_trans,
N_, M_, K_, 1,
col->ga.data, 0, N_,
weight->ga.data, 0, K_,
0,
top->ga.data, n * top_stride, N_);
break;
default:
err = GA_UNSUPPORTED_ERROR;
}
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
"GpuCorrMM forward encountered an error running gemm: %d", err); "GpuCorrMM forward encountered an error running gemm: %d", err);
...@@ -525,34 +535,12 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom, ...@@ -525,34 +535,12 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom,
// Note that we accumulate into weight. We do so by setting beta = 0 // Note that we accumulate into weight. We do so by setting beta = 0
// for the first iteration and beta = 1 for subsequent ones. (This // for the first iteration and beta = 1 for subsequent ones. (This
// is faster than setting weight to all zeros before the loop.) // is faster than setting weight to all zeros before the loop.)
switch (col->ga.typecode) { err = rgemm(cb_fortran, cb_trans, cb_no_trans,
case GA_FLOAT: K_, M_, N_, 1,
err = gpublas_sgemm(cb_fortran, cb_trans, cb_no_trans, &col->ga, 0, N_,
K_, M_, N_, 1, &top->ga, n * top_stride, N_,
col->ga.data, 0, N_, (n == 0) ? 0 : 1,
top->ga.data, n * top_stride, N_, &weight->ga, 0, K_);
(n == 0) ? 0 : 1,
weight->ga.data, 0, K_);
break;
case GA_DOUBLE:
err = gpublas_dgemm(cb_fortran, cb_trans, cb_no_trans,
K_, M_, N_, 1,
col->ga.data, 0, N_,
top->ga.data, n * top_stride, N_,
(n == 0) ? 0 : 1,
weight->ga.data, 0, K_);
break;
case GA_HALF:
err = gpublas_hgemm(cb_fortran, cb_trans, cb_no_trans,
K_, M_, N_, 1,
col->ga.data, 0, N_,
top->ga.data, n * top_stride, N_,
(n == 0) ? 0 : 1,
weight->ga.data, 0, K_);
break;
default:
err = GA_UNSUPPORTED_ERROR;
}
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
"GpuCorrMM grad weights encountered an error running gemm: %d", err); "GpuCorrMM grad weights encountered an error running gemm: %d", err);
...@@ -577,35 +565,13 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom, ...@@ -577,35 +565,13 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom,
// full convolution: gemm, then col2im // full convolution: gemm, then col2im
// Iterate over batch // Iterate over batch
for (size_t n = 0; n < batchSize; n++) { for (size_t n = 0; n < batchSize; n++) {
// gemm into columns // gemm into columns
switch (top->ga.typecode) { err = rgemm(cb_fortran, cb_no_trans, cb_trans,
case GA_FLOAT: N_, K_, M_, 1,
err = gpublas_sgemm(cb_fortran, cb_no_trans, cb_trans, &top->ga, n * top_stride, N_,
N_, K_, M_, 1, &weight->ga, 0, K_,
top->ga.data, n * top_stride, N_, 0,
weight->ga.data, 0, K_, &col->ga, 0, N_);
0,
col->ga.data, 0, N_);
break;
case GA_DOUBLE:
err = gpublas_dgemm(cb_fortran, cb_no_trans, cb_trans,
N_, K_, M_, 1,
top->ga.data, n * top_stride, N_,
weight->ga.data, 0, K_,
0,
col->ga.data, 0, N_);
break;
case GA_HALF:
err = gpublas_hgemm(cb_fortran, cb_no_trans, cb_trans,
N_, K_, M_, 1,
top->ga.data, n * top_stride, N_,
weight->ga.data, 0, K_,
0,
col->ga.data, 0, N_);
break;
default:
err = GA_UNSUPPORTED_ERROR;
}
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
"GpuCorrMM grad inputs encountered an error running gemm: %d", err); "GpuCorrMM grad inputs encountered an error running gemm: %d", err);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论