提交 63f8d7c2 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix the C code so that it works.

上级 3df59ea7
...@@ -15,9 +15,10 @@ PyGpuArrayObject *rand_buf; ...@@ -15,9 +15,10 @@ PyGpuArrayObject *rand_buf;
int gemm16(PyGpuArrayObject *C, float alpha, int gemm16(PyGpuArrayObject *C, float alpha,
PyGpuArrayObject *A, PyGpuArrayObject *B, PyGpuArrayObject *A, PyGpuArrayObject *B,
float beta, PyGpuArrayObject **out) { float beta, PyGpuArrayObject **out) {
PyGpuArrayObject *AA = NULL; PyGpuArrayObject *_A = NULL;
PyGpuArrayObject *BB = NULL; PyGpuArrayObject *_B = NULL;
GpuKernel *gk; GpuKernel *gk;
char *prand, *pA, *pB, *pout;
void *params[13]; void *params[13];
size_t grid[2]; size_t grid[2];
size_t threads[2]; size_t threads[2];
...@@ -29,6 +30,7 @@ int gemm16(PyGpuArrayObject *C, float alpha, ...@@ -29,6 +30,7 @@ int gemm16(PyGpuArrayObject *C, float alpha,
int vec = 0; int vec = 0;
static unsigned int nprocs = 0; static unsigned int nprocs = 0;
char opA, opB; char opA, opB;
if (GpuArray_CHKFLAGS(&A->ga, GA_FARRAY) && if (GpuArray_CHKFLAGS(&A->ga, GA_FARRAY) &&
GpuArray_CHKFLAGS(&B->ga, GA_FARRAY)) { GpuArray_CHKFLAGS(&B->ga, GA_FARRAY)) {
/* /*
...@@ -38,21 +40,29 @@ int gemm16(PyGpuArrayObject *C, float alpha, ...@@ -38,21 +40,29 @@ int gemm16(PyGpuArrayObject *C, float alpha,
*/ */
if (PyGpuArray_DIM(A, 0) * PyGpuArray_DIM(A, 1) < if (PyGpuArray_DIM(A, 0) * PyGpuArray_DIM(A, 1) <
PyGpuArray_DIM(B, 0) * PyGpuArray_DIM(B, 1)) { PyGpuArray_DIM(B, 0) * PyGpuArray_DIM(B, 1)) {
AA = pygpu_copy(A, GA_C_ORDER); _A = pygpu_copy(A, GA_C_ORDER);
if (AA == NULL) { if (_A == NULL) {
res = 1; res = 1;
goto cleanup; goto cleanup;
} }
BB = B; /*
Py_INCREF(BB); * This is not an extra reference on _A so don't add an INCREF.
* Also, we don't lose the ref on A since our caller will deal
* with it.
*/
A = _A;
} else { } else {
BB = pygpu_copy(B, GA_C_ORDER); _B = pygpu_copy(B, GA_C_ORDER);
if (BB == NULL) { if (_B == NULL) {
res = 1; res = 1;
goto cleanup; goto cleanup;
} }
AA = A; /*
Py_INCREF(AA); * This is not an extra reference on _B so don't add an INCREF
* Also, we don't lose the ref on B since our caller will deal
* with it.
*/
B = _B;
} }
} }
if (GEMM16_INPLACE && GpuArray_CHKFLAGS(&C->ga, GA_CARRAY)) { if (GEMM16_INPLACE && GpuArray_CHKFLAGS(&C->ga, GA_CARRAY)) {
...@@ -67,18 +77,31 @@ int gemm16(PyGpuArrayObject *C, float alpha, ...@@ -67,18 +77,31 @@ int gemm16(PyGpuArrayObject *C, float alpha,
} }
} }
if (GpuArray_CHKFLAGS(&A->ga, GA_FARRAY)) if (GpuArray_CHKFLAGS(&A->ga, GA_FARRAY)) {
opA = 't'; opA = 't';
else lda = PyGpuArray_STRIDE(A, 1);
} else {
opA = 'n'; opA = 'n';
lda = PyGpuArray_STRIDE(A, 0);
}
if (GpuArray_CHKFLAGS(&B->ga, GA_FARRAY)) if (GpuArray_CHKFLAGS(&B->ga, GA_FARRAY)) {
opB = 't'; opB = 't';
else ldb = PyGpuArray_STRIDE(B, 1);
} else {
opB = 'n'; opB = 'n';
ldb = PyGpuArray_STRIDE(B, 0);
}
ldc = PyGpuArray_STRIDE(*out, 0);
m = PyGpuArray_DIM(C, 0); /* lda and friend are in number of elements, not bytes */
n = PyGpuArray_DIM(C, 1); lda /= 2;
ldb /= 2;
ldc /= 2;
m = PyGpuArray_DIM(*out, 0);
n = PyGpuArray_DIM(*out, 1);
k = PyGpuArray_DIM(B, 0); k = PyGpuArray_DIM(B, 0);
/* Tuning code adapted from the python version */ /* Tuning code adapted from the python version */
...@@ -93,7 +116,7 @@ int gemm16(PyGpuArrayObject *C, float alpha, ...@@ -93,7 +116,7 @@ int gemm16(PyGpuArrayObject *C, float alpha,
if (48 < n128 && n128 <= 64) { if (48 < n128 && n128 <= 64) {
n64 = n / 64; n64 = n / 64;
if (nprocs == 0) if (nprocs == 0)
if (C->ga.ops->property(C->context->ctx, NULL, NULL, if (A->ga.ops->property(A->context->ctx, NULL, NULL,
GA_CTX_PROP_NUMPROCS, &nprocs)) { GA_CTX_PROP_NUMPROCS, &nprocs)) {
nprocs = 0; nprocs = 0;
res = 1; res = 1;
...@@ -124,7 +147,7 @@ int gemm16(PyGpuArrayObject *C, float alpha, ...@@ -124,7 +147,7 @@ int gemm16(PyGpuArrayObject *C, float alpha,
if ((opA == 't' && opB == 'n' && m % 8 == 0 && n % 8 == 0) || if ((opA == 't' && opB == 'n' && m % 8 == 0 && n % 8 == 0) ||
(opA == 'n' && opB == 'n' && k % 16 == 0 && n % 8 == 0) || (opA == 'n' && opB == 'n' && k % 16 == 0 && n % 8 == 0) ||
(opA == 'n' && opB == '1' && k % 16 == 0)) (opA == 'n' && opB == 't' && k % 16 == 0))
vec = 1; vec = 1;
switch (size) { switch (size) {
...@@ -178,10 +201,18 @@ int gemm16(PyGpuArrayObject *C, float alpha, ...@@ -178,10 +201,18 @@ int gemm16(PyGpuArrayObject *C, float alpha,
goto cleanup; goto cleanup;
} }
params[0] = ((char *)rand_buf->ga.data) + rand_buf->ga.offset; prand = *((char **)rand_buf->ga.data);
params[1] = ((char *)A->ga.data) + A->ga.offset; prand += rand_buf->ga.offset;
params[2] = ((char *)B->ga.data) + B->ga.offset; pA = *((char **)A->ga.data);
params[3] = ((char *)C->ga.data) + C->ga.offset; pA += A->ga.offset;
pB = *((char **)B->ga.data);
pB += B->ga.offset;
pout = *((char **)(*out)->ga.data);
pout += (*out)->ga.offset;
params[0] = &prand;
params[1] = &pA;
params[2] = &pB;
params[3] = &pout;
params[4] = &lda; params[4] = &lda;
params[5] = &ldb; params[5] = &ldb;
params[6] = &ldc; params[6] = &ldc;
...@@ -192,17 +223,13 @@ int gemm16(PyGpuArrayObject *C, float alpha, ...@@ -192,17 +223,13 @@ int gemm16(PyGpuArrayObject *C, float alpha,
params[11] = &beta; params[11] = &beta;
params[12] = &flags; params[12] = &flags;
printf("%c%c_%s128x%d\n", opA, opB, vec ? "vec_" : "", size);
printf("%p %p %p %p\n", *param[0], *param[1], *param[2], *param[3]);
if (GpuKernel_call(gk, 2, threads, grid, 0, params) != GA_NO_ERROR) { if (GpuKernel_call(gk, 2, threads, grid, 0, params) != GA_NO_ERROR) {
PyErr_SetString(PyExc_RuntimeError, "error in gemm16 kernel call"); PyErr_SetString(PyExc_RuntimeError, "error in gemm16 kernel call");
res = 1; res = 1;
} }
cleanup: cleanup:
Py_XDECREF(AA); Py_XDECREF(_A);
Py_XDECREF(BB); Py_XDECREF(_B);
return res; return res;
} }
...@@ -58,7 +58,6 @@ class Gemm16(COp): ...@@ -58,7 +58,6 @@ class Gemm16(COp):
self.inplace = inplace self.inplace = inplace
if self.inplace: if self.inplace:
self.destroy_map = {0: [0]} self.destroy_map = {0: [0]}
self._use_c_code = False
def make_node(self, C, alpha, A, B, beta): def make_node(self, C, alpha, A, B, beta):
if GPUTensor is None: if GPUTensor is None:
...@@ -87,7 +86,7 @@ class Gemm16(COp): ...@@ -87,7 +86,7 @@ class Gemm16(COp):
else: else:
B = B.copy() B = B.copy()
inplace = self.inplace inplace = self.inplace
if inplace and not C.flags.forc: if inplace and not C.flags.c_contiguous:
inplace = False inplace = False
if not inplace: if not inplace:
C = C.copy() C = C.copy()
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论