提交 9ded03f3 authored 作者: Tim Cooijmans's avatar Tim Cooijmans

GpuBatchedDot: streams implementation for large matrices

上级 e088e2a8
...@@ -39,26 +39,17 @@ class GpuBatchedDot(GpuOp): ...@@ -39,26 +39,17 @@ class GpuBatchedDot(GpuOp):
bx, by = input_names bx, by = input_names
bz, = output_names bz, = output_names
fail = sub['fail'] fail = sub['fail']
return """ return ("""
float alpha = 1.0; float alpha = 1.0;
float beta = 0.0; float beta = 0.0;
int i, x_dim0, x_dim1, x_dim2, y_dim0, y_dim1, y_dim2; int i, x_dim0, x_dim1, x_dim2, y_dim0, y_dim1, y_dim2;
int x_stride, y_stride, z_stride, total_size; int x_stride, y_stride, z_stride, total_size;
int ptr_array_size = 3 * CudaNdarray_HOST_DIMS(%(bx)s)[0] * sizeof(float *);
int out_dim[3]; int out_dim[3];
cublasStatus_t err; cublasStatus_t err;
cudaError_t err1; cudaError_t err1;
float **host_x = NULL;
float **host_z = NULL;
float **host_y = NULL;
float **gpu_x = NULL;
float **gpu_y = NULL;
float **gpu_z = NULL;
x_dim0 = CudaNdarray_HOST_DIMS(%(bx)s)[0]; x_dim0 = CudaNdarray_HOST_DIMS(%(bx)s)[0];
x_dim1 = CudaNdarray_HOST_DIMS(%(bx)s)[1]; x_dim1 = CudaNdarray_HOST_DIMS(%(bx)s)[1];
x_dim2 = CudaNdarray_HOST_DIMS(%(bx)s)[2]; x_dim2 = CudaNdarray_HOST_DIMS(%(bx)s)[2];
...@@ -67,6 +58,9 @@ class GpuBatchedDot(GpuOp): ...@@ -67,6 +58,9 @@ class GpuBatchedDot(GpuOp):
y_dim1 = CudaNdarray_HOST_DIMS(%(by)s)[1]; y_dim1 = CudaNdarray_HOST_DIMS(%(by)s)[1];
y_dim2 = CudaNdarray_HOST_DIMS(%(by)s)[2]; y_dim2 = CudaNdarray_HOST_DIMS(%(by)s)[2];
// use parallel cublasSgemm calls rather than cublasSgemmBatched for large products
bool use_cublas_sgemm_batched = x_dim1 * x_dim2 * y_dim2 < 128 * 128 * 128;
if (x_dim0 != y_dim0) if (x_dim0 != y_dim0)
{ {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
...@@ -105,10 +99,28 @@ class GpuBatchedDot(GpuOp): ...@@ -105,10 +99,28 @@ class GpuBatchedDot(GpuOp):
} }
} }
if (x_dim0 != 0 && y_dim0 != 0 && if (x_dim0 == 0 || y_dim0 == 0 || x_dim1 == 0 || y_dim1 == 0 || x_dim2 == 0 || y_dim2 == 0)
x_dim1 != 0 && y_dim1 != 0 && {
x_dim2 != 0 && y_dim2 != 0) total_size = x_dim0 * x_dim1 * y_dim2 * sizeof(float);
if (cudaSuccess != cudaMemset(CudaNdarray_DEV_DATA(%(bz)s), 0, total_size))
{
PyErr_Format(PyExc_RuntimeError,
"Failed to fill output with zeros");
%(fail)s;
}
}
else if (use_cublas_sgemm_batched)
{ {
int ptr_array_size = 3 * CudaNdarray_HOST_DIMS(%(bx)s)[0] * sizeof(float *);
float **host_x = NULL;
float **host_z = NULL;
float **host_y = NULL;
float **gpu_x = NULL;
float **gpu_y = NULL;
float **gpu_z = NULL;
x_stride = CudaNdarray_HOST_STRIDES(%(bx)s)[0]; x_stride = CudaNdarray_HOST_STRIDES(%(bx)s)[0];
y_stride = CudaNdarray_HOST_STRIDES(%(by)s)[0]; y_stride = CudaNdarray_HOST_STRIDES(%(by)s)[0];
z_stride = CudaNdarray_HOST_STRIDES(%(bz)s)[0]; z_stride = CudaNdarray_HOST_STRIDES(%(bz)s)[0];
...@@ -171,19 +183,130 @@ class GpuBatchedDot(GpuOp): ...@@ -171,19 +183,130 @@ class GpuBatchedDot(GpuOp):
err, cublasGetErrorString(err)); err, cublasGetErrorString(err));
%(fail)s; %(fail)s;
} }
} } else {
else // copy inputs if not contiguous
{ """ +
total_size = x_dim0 * x_dim1 * y_dim2 * sizeof(float); ("\n".join("""
if (cudaSuccess != cudaMemset(CudaNdarray_DEV_DATA(%(bz)s), 0, total_size)) if (( CudaNdarray_HOST_DIMS(%(var)s)[0] > 1 && CudaNdarray_HOST_STRIDES(%(var)s)[0] != 1
&& CudaNdarray_HOST_DIMS(%(var)s)[1] > 1 && CudaNdarray_HOST_STRIDES(%(var)s)[1] != 1
&& CudaNdarray_HOST_DIMS(%(var)s)[2] > 1 && CudaNdarray_HOST_STRIDES(%(var)s)[2] != 1)
|| CudaNdarray_HOST_STRIDES(%(var)s)[0] < 0
|| CudaNdarray_HOST_STRIDES(%(var)s)[1] < 0
|| CudaNdarray_HOST_STRIDES(%(var)s)[2] < 0)
{
CudaNdarray *_copy = (CudaNdarray*) CudaNdarray_Copy(%(var)s);
if (!_copy)
%(fail)s;
Py_XDECREF(%(var)s);
%(var)s = _copy;
}
""" % dict(var=var, fail=fail) for var in (bx, by)))
+ """
// fail if the output is not contiguous; we can't copy it because we
// need to write to the original memory
if (( CudaNdarray_HOST_DIMS(%(bz)s)[0] > 1 && CudaNdarray_HOST_STRIDES(%(bz)s)[0] != 1
&& CudaNdarray_HOST_DIMS(%(bz)s)[1] > 1 && CudaNdarray_HOST_STRIDES(%(bz)s)[1] != 1
&& CudaNdarray_HOST_DIMS(%(bz)s)[2] > 1 && CudaNdarray_HOST_STRIDES(%(bz)s)[2] != 1)
|| CudaNdarray_HOST_STRIDES(%(bz)s)[0] < 0
|| CudaNdarray_HOST_STRIDES(%(bz)s)[1] < 0
|| CudaNdarray_HOST_STRIDES(%(bz)s)[2] < 0)
{ {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_AssertionError,
"Failed to fill output with zeros"); "non-unit or negative stride in output arg %(bz)s (%%i, %%i, %%i) of shape (%%i, %%i, %%i)",
CudaNdarray_HOST_STRIDES(%(bz)s)[0],
CudaNdarray_HOST_STRIDES(%(bz)s)[1],
CudaNdarray_HOST_STRIDES(%(bz)s)[2],
CudaNdarray_HOST_DIMS(%(bz)s)[0],
CudaNdarray_HOST_DIMS(%(bz)s)[1],
CudaNdarray_HOST_DIMS(%(bz)s)[2]);
%(fail)s; %(fail)s;
} }
}
""" % locals() const int *Nx = CudaNdarray_HOST_DIMS(%(bx)s), *Sx = CudaNdarray_HOST_STRIDES(%(bx)s);
const int *Ny = CudaNdarray_HOST_DIMS(%(by)s), *Sy = CudaNdarray_HOST_STRIDES(%(by)s);
const int *Nz = CudaNdarray_HOST_DIMS(%(bz)s), *Sz = CudaNdarray_HOST_STRIDES(%(bz)s);
/* encode the stride structure of _x,_y,_z into a single integer. */
int unit = 0;
unit |= ((Sx[2] == 1 || Nx[2] == 1) ? 0x0 : (Sx[1] == 1 || Nx[1] == 1) ? 0x1 : 0x2) << 8;
unit |= ((Sy[2] == 1 || Ny[2] == 1) ? 0x0 : (Sy[1] == 1 || Ny[1] == 1) ? 0x1 : 0x2) << 4;
unit |= ((Sz[2] == 1 || Nz[2] == 1) ? 0x0 : (Sz[1] == 1 || Nz[1] == 1) ? 0x1 : 0x2) << 0;
/* create appropriate strides for malformed matrices that are row or column
* vectors, or empty matrices.
* In that case, the value of the stride does not really matter, but
* some versions of BLAS insist that:
* - they are not smaller than the number of elements in the array,
* - they are not 0.
*/
int sx_1 = (Nx[1] > 1) ? Sx[1] : (Nx[2] + 1);
int sx_2 = (Nx[2] > 1) ? Sx[2] : (Nx[1] + 1);
int sy_1 = (Ny[1] > 1) ? Sy[1] : (Ny[2] + 1);
int sy_2 = (Ny[2] > 1) ? Sy[2] : (Ny[1] + 1);
int sz_1 = (Nz[1] > 1) ? Sz[1] : (Nz[2] + 1);
int sz_2 = (Nz[2] > 1) ? Sz[2] : (Nz[1] + 1);
cublasOperation_t N = CUBLAS_OP_N, T = CUBLAS_OP_T;
float* x = CudaNdarray_DEV_DATA(%(bx)s);
float* y = CudaNdarray_DEV_DATA(%(by)s);
float* z = CudaNdarray_DEV_DATA(%(bz)s);
float* xend = x + CudaNdarray_SIZE(%(bx)s);
float* yend = y + CudaNdarray_SIZE(%(by)s);
float* zend = z + CudaNdarray_SIZE(%(bz)s);
float alpha = 1, beta = 0;
#define N_STREAMS 32
cudaStream_t streams[N_STREAMS];
for (int i = 0; i < N_STREAMS; i++) {
cudaStreamCreate(&streams[i]);
}
cudaStreamSynchronize(0);
for (int i = 0; i < Nx[0]; i++)
{
assert(CudaNdarray_DEV_DATA(%(bx)s) <= x); assert(x < CudaNdarray_DEV_DATA(%(bx)s) + CudaNdarray_SIZE(%(bx)s));
assert(CudaNdarray_DEV_DATA(%(by)s) <= y); assert(y < CudaNdarray_DEV_DATA(%(by)s) + CudaNdarray_SIZE(%(by)s));
assert(CudaNdarray_DEV_DATA(%(bz)s) <= z); assert(z < CudaNdarray_DEV_DATA(%(bz)s) + CudaNdarray_SIZE(%(bz)s));
cublasSetStream(handle, streams[i %% N_STREAMS]);
cublasStatus_t status;
switch(unit)
{
case 0x000: status = cublasSgemm(handle, N, N, Nz[2], Nz[1], Nx[2], &alpha, y, sy_1, x, sx_1, &beta, z, sz_1); break;
case 0x100: status = cublasSgemm(handle, N, T, Nz[2], Nz[1], Nx[2], &alpha, y, sy_1, x, sx_2, &beta, z, sz_1); break;
case 0x010: status = cublasSgemm(handle, T, N, Nz[2], Nz[1], Nx[2], &alpha, y, sy_2, x, sx_1, &beta, z, sz_1); break;
case 0x110: status = cublasSgemm(handle, T, T, Nz[2], Nz[1], Nx[2], &alpha, y, sy_2, x, sx_2, &beta, z, sz_1); break;
case 0x001: status = cublasSgemm(handle, T, T, Nz[1], Nz[2], Nx[2], &alpha, x, sx_1, y, sy_1, &beta, z, sz_2); break;
case 0x101: status = cublasSgemm(handle, N, T, Nz[1], Nz[2], Nx[2], &alpha, x, sx_2, y, sy_1, &beta, z, sz_2); break;
case 0x011: status = cublasSgemm(handle, T, N, Nz[1], Nz[2], Nx[2], &alpha, x, sx_1, y, sy_2, &beta, z, sz_2); break;
case 0x111: status = cublasSgemm(handle, N, N, Nz[1], Nz[2], Nx[2], &alpha, x, sx_2, y, sy_2, &beta, z, sz_2); break;
default: PyErr_Format(PyExc_ValueError, "some matrix has no unit stride (unit=%%x)", unit); %(fail)s;
}
if (status != CUBLAS_STATUS_SUCCESS) {
PyErr_Format(PyExc_RuntimeError,
"cublasSgemm failed (%%i) %%s\\n"
" unit=%%x N=%%d,"
" x shape=[%%d %%d %%d], y shape=[%%d %%d %%d], z shape=[%%d %%d %%d]"
" x strides=[%%d %%d %%d], y strides=[%%d %%d %%d], z strides=[%%d %%d %%d]",
status, cublasGetErrorString(status), unit, N,
Nx[0], Nx[1], Nx[2], Sx[0], Sx[1], Sx[2],
Ny[0], Ny[1], Ny[2], Sy[0], Sy[1], Sy[2],
Nz[0], Nz[1], Nz[2], Sz[0], Sz[1], Sz[2]);
%(fail)s;
}
x += Sx[0]; y += Sy[0]; z += Sz[0];
};
CNDA_THREAD_SYNC;
for(int i = 0; i < N_STREAMS; i++) {
cudaStreamSynchronize(streams[i]);
cudaStreamDestroy(streams[i]);
}
}
""") % locals()
def c_support_code(self): def c_support_code(self):
return """ return """
...@@ -210,7 +333,7 @@ class GpuBatchedDot(GpuOp): ...@@ -210,7 +333,7 @@ class GpuBatchedDot(GpuOp):
return rval return rval
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (2,)
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
xshp, yshp = shapes xshp, yshp = shapes
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论