提交 08857dc5 作者: abergeron

Merge pull request #4066 from cooijmanstim/big_batched_dot

GpuBatchedDot: streams implementation (WIP)
...@@ -16,7 +16,10 @@ from theano.tensor import as_tensor_variable ...@@ -16,7 +16,10 @@ from theano.tensor import as_tensor_variable
class GpuBatchedDot(GpuOp): class GpuBatchedDot(GpuOp):
__props__ = () __props__ = ("stream_threshold",)
# Create the op.
#
# stream_threshold: size cut-off used by the generated C code to choose
# between two GPU implementations. The C code takes the cublasSgemmBatched
# path when Nx[1] * Nx[2] * Ny[2] < stream_threshold ** 3 (compared in double
# precision); otherwise it issues one cublasSgemm call per batch entry,
# spread over CUDA streams. Defaults to 650.
def __init__(self, stream_threshold=650):
# Kept as an instance attribute: it is listed in __props__ and read back by
# c_code when the C template is filled in.
self.stream_threshold = stream_threshold
def make_node(self, inp1, inp2): def make_node(self, inp1, inp2):
inp1 = gpu_contiguous(as_cuda_ndarray_variable(inp1)) inp1 = gpu_contiguous(as_cuda_ndarray_variable(inp1))
...@@ -39,79 +42,83 @@ class GpuBatchedDot(GpuOp): ...@@ -39,79 +42,83 @@ class GpuBatchedDot(GpuOp):
bx, by = input_names bx, by = input_names
bz, = output_names bz, = output_names
fail = sub['fail'] fail = sub['fail']
return """ threshold = self.stream_threshold
float alpha = 1.0; return ("""
float beta = 0.0; float alpha = 1.0, beta = 0.0;
int i, x_dim0, x_dim1, x_dim2, y_dim0, y_dim1, y_dim2;
int x_stride, y_stride, z_stride, total_size;
int ptr_array_size = 3 * CudaNdarray_HOST_DIMS(%(bx)s)[0] * sizeof(float *);
int out_dim[3];
cublasStatus_t err;
cudaError_t err1;
float **host_x = NULL;
float **host_z = NULL;
float **host_y = NULL;
float **gpu_x = NULL;
float **gpu_y = NULL;
float **gpu_z = NULL;
x_dim0 = CudaNdarray_HOST_DIMS(%(bx)s)[0]; const int* Nx = CudaNdarray_HOST_DIMS(%(bx)s);
x_dim1 = CudaNdarray_HOST_DIMS(%(bx)s)[1]; const int* Ny = CudaNdarray_HOST_DIMS(%(by)s);
x_dim2 = CudaNdarray_HOST_DIMS(%(bx)s)[2]; int Nz[3] = {0};
y_dim0 = CudaNdarray_HOST_DIMS(%(by)s)[0]; // use parallel cublasSgemm calls rather than cublasSgemmBatched for large products
y_dim1 = CudaNdarray_HOST_DIMS(%(by)s)[1]; // (compute products in double because they can be large and we don't need to be exact)
y_dim2 = CudaNdarray_HOST_DIMS(%(by)s)[2]; bool use_cublas_sgemm_batched = (
double(Nx[1]) * double(Nx[2]) * double(Ny[2]) <
double(%(threshold)s) * double(%(threshold)s) * double(%(threshold)s));
if (x_dim0 != y_dim0) if (Nx[0] != Ny[0]) {
{
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
"The batchsizes (%%d, %%d) don't match.\\n", "The batchsizes (%%d, %%d) don't match.\\n",
x_dim0, x_dim1); Nx[0], Ny[0]);
%(fail)s; %(fail)s;
} }
if (x_dim2 != y_dim1) if (Nx[2] != Ny[1]) {
{
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
"Shape mismatch. (%%d, %%d, %%d) (%%d, %%d, %%d)\\n", "Shape mismatch. (%%d, %%d, %%d) (%%d, %%d, %%d)\\n",
x_dim0, x_dim1, x_dim2, y_dim0, y_dim1, y_dim2); Nx[0], Nx[1], Nx[2], Ny[0], Ny[1], Ny[2]);
%(fail)s; %(fail)s;
} }
out_dim[0] = x_dim0; Nz[0] = Nx[0];
out_dim[1] = x_dim1; Nz[1] = Nx[1];
out_dim[2] = y_dim2; Nz[2] = Ny[2];
if ( !(%(bz)s if ( !(%(bz)s
&& %(bz)s->nd==3 && %(bz)s->nd==3
&& CudaNdarray_is_c_contiguous(%(bz)s) && CudaNdarray_is_c_contiguous(%(bz)s)
&& CudaNdarray_HOST_DIMS(%(bz)s)[0]==out_dim[0] && CudaNdarray_HOST_DIMS(%(bz)s)[0] == Nz[0]
&& CudaNdarray_HOST_DIMS(%(bz)s)[1]==out_dim[1] && CudaNdarray_HOST_DIMS(%(bz)s)[1] == Nz[1]
&& CudaNdarray_HOST_DIMS(%(bz)s)[2]==out_dim[2])) && CudaNdarray_HOST_DIMS(%(bz)s)[2] == Nz[2]))
{ {
Py_XDECREF(%(bz)s); Py_XDECREF(%(bz)s);
%(bz)s = (CudaNdarray*)CudaNdarray_NewDims(3,out_dim); %(bz)s = (CudaNdarray*)CudaNdarray_NewDims(3, Nz);
if (NULL == %(bz)s) if (NULL == %(bz)s) {
{
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
"Failed to allocate output of %%d x %%d x %%d", "Failed to allocate output of %%d x %%d x %%d",
out_dim[0], out_dim[1], out_dim[2]); Nz[0], Nz[1], Nz[2]);
%(fail)s; %(fail)s;
} }
} }
if (x_dim0 != 0 && y_dim0 != 0 && if (Nx[0] == 0 || Nx[1] == 0 || Nx[2] == 0 ||
x_dim1 != 0 && y_dim1 != 0 && Ny[0] == 0 || Ny[1] == 0 || Ny[2] == 0)
x_dim2 != 0 && y_dim2 != 0)
{ {
x_stride = CudaNdarray_HOST_STRIDES(%(bx)s)[0]; const int total_size = Nz[0] * Nz[1] * Nz[2] * sizeof(float);
y_stride = CudaNdarray_HOST_STRIDES(%(by)s)[0]; if (cudaSuccess != cudaMemset(CudaNdarray_DEV_DATA(%(bz)s), 0, total_size))
z_stride = CudaNdarray_HOST_STRIDES(%(bz)s)[0]; {
PyErr_Format(PyExc_RuntimeError,
"Failed to fill output with zeros");
%(fail)s;
}
}
else if (use_cublas_sgemm_batched)
{
cublasStatus_t err;
cudaError_t err1;
float **host_x = NULL;
float **host_z = NULL;
float **host_y = NULL;
float **gpu_x = NULL;
float **gpu_y = NULL;
float **gpu_z = NULL;
const int ptr_array_size = 3 * Nx[0] * sizeof(float *);
const int x_stride = CudaNdarray_HOST_STRIDES(%(bx)s)[0];
const int y_stride = CudaNdarray_HOST_STRIDES(%(by)s)[0];
const int z_stride = CudaNdarray_HOST_STRIDES(%(bz)s)[0];
host_x = (float **) malloc (ptr_array_size); host_x = (float **) malloc (ptr_array_size);
...@@ -123,14 +130,14 @@ class GpuBatchedDot(GpuOp): ...@@ -123,14 +130,14 @@ class GpuBatchedDot(GpuOp):
%(fail)s; %(fail)s;
} }
host_y = &host_x[x_dim0]; host_y = &host_x[Nx[0]];
host_z = &host_y[x_dim0]; host_z = &host_y[Nx[0]];
host_x[0] = CudaNdarray_DEV_DATA(%(bx)s); host_x[0] = CudaNdarray_DEV_DATA(%(bx)s);
host_y[0] = CudaNdarray_DEV_DATA(%(by)s); host_y[0] = CudaNdarray_DEV_DATA(%(by)s);
host_z[0] = CudaNdarray_DEV_DATA(%(bz)s); host_z[0] = CudaNdarray_DEV_DATA(%(bz)s);
for (i = 1; i < out_dim[0]; i++) for (int i = 1; i < Nz[0]; i++)
{ {
host_x[i] = host_x[i - 1] + x_stride; host_x[i] = host_x[i - 1] + x_stride;
host_y[i] = host_y[i - 1] + y_stride; host_y[i] = host_y[i - 1] + y_stride;
...@@ -143,8 +150,8 @@ class GpuBatchedDot(GpuOp): ...@@ -143,8 +150,8 @@ class GpuBatchedDot(GpuOp):
%(fail)s; %(fail)s;
} }
gpu_y = &gpu_x[x_dim0]; gpu_y = &gpu_x[Nx[0]];
gpu_z = &gpu_y[x_dim0]; gpu_z = &gpu_y[Nx[0]];
err1 = cudaMemcpy(gpu_x, host_x, ptr_array_size, cudaMemcpyHostToDevice); err1 = cudaMemcpy(gpu_x, host_x, ptr_array_size, cudaMemcpyHostToDevice);
...@@ -157,13 +164,14 @@ class GpuBatchedDot(GpuOp): ...@@ -157,13 +164,14 @@ class GpuBatchedDot(GpuOp):
} }
err = cublasSgemmBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N, err = cublasSgemmBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N,
y_dim2, x_dim1, x_dim2, &alpha, Ny[2], Nx[1], Nx[2], &alpha,
(const float **) gpu_y, y_dim2, (const float **) gpu_y, Ny[2],
(const float **) gpu_x, x_dim2, &beta, (const float **) gpu_x, Nx[2],
gpu_z, y_dim2, x_dim0); &beta, gpu_z, Ny[2], Nx[0]);
CLEANUP(); CNDA_THREAD_SYNC;
CLEANUP();
if (CUBLAS_STATUS_SUCCESS != err) if (CUBLAS_STATUS_SUCCESS != err)
{ {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
...@@ -171,19 +179,129 @@ class GpuBatchedDot(GpuOp): ...@@ -171,19 +179,129 @@ class GpuBatchedDot(GpuOp):
err, cublasGetErrorString(err)); err, cublasGetErrorString(err));
%(fail)s; %(fail)s;
} }
} } else {
else // copy inputs if not contiguous
{ """ +
total_size = x_dim0 * x_dim1 * y_dim2 * sizeof(float); ("\n".join("""
if (cudaSuccess != cudaMemset(CudaNdarray_DEV_DATA(%(bz)s), 0, total_size)) if (( CudaNdarray_HOST_DIMS(%(var)s)[0] > 1 && CudaNdarray_HOST_STRIDES(%(var)s)[0] != 1
&& CudaNdarray_HOST_DIMS(%(var)s)[1] > 1 && CudaNdarray_HOST_STRIDES(%(var)s)[1] != 1
&& CudaNdarray_HOST_DIMS(%(var)s)[2] > 1 && CudaNdarray_HOST_STRIDES(%(var)s)[2] != 1)
|| CudaNdarray_HOST_STRIDES(%(var)s)[0] < 0
|| CudaNdarray_HOST_STRIDES(%(var)s)[1] < 0
|| CudaNdarray_HOST_STRIDES(%(var)s)[2] < 0)
{
CudaNdarray *_copy = (CudaNdarray*) CudaNdarray_Copy(%(var)s);
if (!_copy)
%(fail)s;
Py_XDECREF(%(var)s);
%(var)s = _copy;
}
""" % dict(var=var, fail=fail) for var in (bx, by)))
+ """
// fail if the output is not contiguous; we can't copy it because we
// need to write to the original memory
if (( CudaNdarray_HOST_DIMS(%(bz)s)[0] > 1 && CudaNdarray_HOST_STRIDES(%(bz)s)[0] != 1
&& CudaNdarray_HOST_DIMS(%(bz)s)[1] > 1 && CudaNdarray_HOST_STRIDES(%(bz)s)[1] != 1
&& CudaNdarray_HOST_DIMS(%(bz)s)[2] > 1 && CudaNdarray_HOST_STRIDES(%(bz)s)[2] != 1)
|| CudaNdarray_HOST_STRIDES(%(bz)s)[0] < 0
|| CudaNdarray_HOST_STRIDES(%(bz)s)[1] < 0
|| CudaNdarray_HOST_STRIDES(%(bz)s)[2] < 0)
{ {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_AssertionError,
"Failed to fill output with zeros"); "non-unit or negative stride in output arg %(bz)s (%%i, %%i, %%i) of shape (%%i, %%i, %%i)",
CudaNdarray_HOST_STRIDES(%(bz)s)[0],
CudaNdarray_HOST_STRIDES(%(bz)s)[1],
CudaNdarray_HOST_STRIDES(%(bz)s)[2],
CudaNdarray_HOST_DIMS(%(bz)s)[0],
CudaNdarray_HOST_DIMS(%(bz)s)[1],
CudaNdarray_HOST_DIMS(%(bz)s)[2]);
%(fail)s; %(fail)s;
} }
}
""" % locals() const int* Sx = CudaNdarray_HOST_STRIDES(%(bx)s);
const int* Sy = CudaNdarray_HOST_STRIDES(%(by)s);
const int* Sz = CudaNdarray_HOST_STRIDES(%(bz)s);
/* encode the stride structure of _x,_y,_z into a single integer. */
int unit = 0;
unit |= ((Sx[2] == 1 || Nx[2] == 1) ? 0x0 : (Sx[1] == 1 || Nx[1] == 1) ? 0x1 : 0x2) << 8;
unit |= ((Sy[2] == 1 || Ny[2] == 1) ? 0x0 : (Sy[1] == 1 || Ny[1] == 1) ? 0x1 : 0x2) << 4;
unit |= ((Sz[2] == 1 || Nz[2] == 1) ? 0x0 : (Sz[1] == 1 || Nz[1] == 1) ? 0x1 : 0x2) << 0;
/* create appropriate strides for malformed matrices that are row or column
* vectors, or empty matrices.
* In that case, the value of the stride does not really matter, but
* some versions of BLAS insist that:
* - they are not smaller than the number of elements in the array,
* - they are not 0.
*/
int sx_1 = (Nx[1] > 1) ? Sx[1] : (Nx[2] + 1);
int sx_2 = (Nx[2] > 1) ? Sx[2] : (Nx[1] + 1);
int sy_1 = (Ny[1] > 1) ? Sy[1] : (Ny[2] + 1);
int sy_2 = (Ny[2] > 1) ? Sy[2] : (Ny[1] + 1);
int sz_1 = (Nz[1] > 1) ? Sz[1] : (Nz[2] + 1);
int sz_2 = (Nz[2] > 1) ? Sz[2] : (Nz[1] + 1);
cublasOperation_t N = CUBLAS_OP_N, T = CUBLAS_OP_T;
float* x = CudaNdarray_DEV_DATA(%(bx)s);
float* y = CudaNdarray_DEV_DATA(%(by)s);
float* z = CudaNdarray_DEV_DATA(%(bz)s);
float* xend = x + CudaNdarray_SIZE(%(bx)s);
float* yend = y + CudaNdarray_SIZE(%(by)s);
float* zend = z + CudaNdarray_SIZE(%(bz)s);
#define N_STREAMS 32
cudaStream_t streams[N_STREAMS];
for (int i = 0; i < N_STREAMS; i++) {
cudaStreamCreate(&streams[i]);
}
cudaStreamSynchronize(0);
for (int i = 0; i < Nx[0]; i++)
{
assert(CudaNdarray_DEV_DATA(%(bx)s) <= x); assert(x < CudaNdarray_DEV_DATA(%(bx)s) + CudaNdarray_SIZE(%(bx)s));
assert(CudaNdarray_DEV_DATA(%(by)s) <= y); assert(y < CudaNdarray_DEV_DATA(%(by)s) + CudaNdarray_SIZE(%(by)s));
assert(CudaNdarray_DEV_DATA(%(bz)s) <= z); assert(z < CudaNdarray_DEV_DATA(%(bz)s) + CudaNdarray_SIZE(%(bz)s));
cublasSetStream(handle, streams[i %% N_STREAMS]);
cublasStatus_t status;
switch(unit)
{
case 0x000: status = cublasSgemm(handle, N, N, Nz[2], Nz[1], Nx[2], &alpha, y, sy_1, x, sx_1, &beta, z, sz_1); break;
case 0x100: status = cublasSgemm(handle, N, T, Nz[2], Nz[1], Nx[2], &alpha, y, sy_1, x, sx_2, &beta, z, sz_1); break;
case 0x010: status = cublasSgemm(handle, T, N, Nz[2], Nz[1], Nx[2], &alpha, y, sy_2, x, sx_1, &beta, z, sz_1); break;
case 0x110: status = cublasSgemm(handle, T, T, Nz[2], Nz[1], Nx[2], &alpha, y, sy_2, x, sx_2, &beta, z, sz_1); break;
case 0x001: status = cublasSgemm(handle, T, T, Nz[1], Nz[2], Nx[2], &alpha, x, sx_1, y, sy_1, &beta, z, sz_2); break;
case 0x101: status = cublasSgemm(handle, N, T, Nz[1], Nz[2], Nx[2], &alpha, x, sx_2, y, sy_1, &beta, z, sz_2); break;
case 0x011: status = cublasSgemm(handle, T, N, Nz[1], Nz[2], Nx[2], &alpha, x, sx_1, y, sy_2, &beta, z, sz_2); break;
case 0x111: status = cublasSgemm(handle, N, N, Nz[1], Nz[2], Nx[2], &alpha, x, sx_2, y, sy_2, &beta, z, sz_2); break;
default: PyErr_Format(PyExc_ValueError, "some matrix has no unit stride (unit=%%x)", unit); %(fail)s;
}
if (status != CUBLAS_STATUS_SUCCESS) {
PyErr_Format(PyExc_RuntimeError,
"cublasSgemm failed (%%i) %%s\\n"
" unit=%%x N=%%d,"
" x shape=[%%d %%d %%d], y shape=[%%d %%d %%d], z shape=[%%d %%d %%d]"
" x strides=[%%d %%d %%d], y strides=[%%d %%d %%d], z strides=[%%d %%d %%d]",
status, cublasGetErrorString(status), unit, N,
Nx[0], Nx[1], Nx[2], Sx[0], Sx[1], Sx[2],
Ny[0], Ny[1], Ny[2], Sy[0], Sy[1], Sy[2],
Nz[0], Nz[1], Nz[2], Sz[0], Sz[1], Sz[2]);
%(fail)s;
}
x += Sx[0]; y += Sy[0]; z += Sz[0];
};
cublasSetStream(handle, NULL);
for (int i = 0; i < N_STREAMS; i++) {
cudaStreamSynchronize(streams[i]);
cudaStreamDestroy(streams[i]);
}
}
""") % locals()
def c_support_code(self): def c_support_code(self):
return """ return """
...@@ -199,8 +317,8 @@ class GpuBatchedDot(GpuOp): ...@@ -199,8 +317,8 @@ class GpuBatchedDot(GpuOp):
x, y = inp x, y = inp
gz, = grads gz, = grads
xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1)) xgrad = GpuBatchedDot(stream_threshold=self.stream_threshold)(gz, y.dimshuffle(0, 2, 1))
ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz) ygrad = GpuBatchedDot(stream_threshold=self.stream_threshold)(x.dimshuffle(0, 2, 1), gz)
rval = xgrad, ygrad rval = xgrad, ygrad
...@@ -210,7 +328,7 @@ class GpuBatchedDot(GpuOp): ...@@ -210,7 +328,7 @@ class GpuBatchedDot(GpuOp):
return rval return rval
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (3,)
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
xshp, yshp = shapes xshp, yshp = shapes
......
...@@ -48,45 +48,48 @@ class TestBatchedDot(unittest_tools.InferShapeTester): ...@@ -48,45 +48,48 @@ class TestBatchedDot(unittest_tools.InferShapeTester):
mode = mode_with_gpu mode = mode_with_gpu
def test_batched_dot_correctness(self): def test_batched_dot_correctness(self):
# test both implementations
for threshold in [0, 100]:
batched_dot = GpuBatchedDot(stream_threshold=threshold)
def cmp(a_shp, b_shp): def cmp(a_shp, b_shp):
a=numpy.random.randn(*a_shp).astype(numpy.float32) a=numpy.random.randn(*a_shp).astype(numpy.float32)
b=numpy.random.randn(*b_shp).astype(numpy.float32) b=numpy.random.randn(*b_shp).astype(numpy.float32)
x=tensor.ftensor3() x=tensor.ftensor3()
y=tensor.ftensor3() y=tensor.ftensor3()
f=theano.function([x,y], batched_dot(x,y), mode=mode_with_gpu) f=theano.function([x,y], batched_dot(x,y), mode=mode_with_gpu)
z0=numpy.asarray(f(a,b)) z0=numpy.asarray(f(a,b))
ga = cuda_ndarray.CudaNdarray(a) ga = cuda_ndarray.CudaNdarray(a)
gb = cuda_ndarray.CudaNdarray(b) gb = cuda_ndarray.CudaNdarray(b)
z1=numpy.asarray(f(ga,gb)) z1=numpy.asarray(f(ga,gb))
z_test = numpy.sum(a[:,:,:,None]*b[:,None,:,:],axis=-2) z_test = numpy.sum(a[:,:,:,None]*b[:,None,:,:],axis=-2)
unittest_tools.assert_allclose(z0, z_test) unittest_tools.assert_allclose(z0, z_test)
unittest_tools.assert_allclose(z1, z_test) unittest_tools.assert_allclose(z1, z_test)
cmp((5,4,3), (5,3,2)) cmp((5,4,3), (5,3,2))
cmp((5,3,3), (5,3,3)) cmp((5,3,3), (5,3,3))
cmp((5,2,6), (5,6,3)) cmp((5,2,6), (5,6,3))
# Test dimensions of 0 # Test dimensions of 0
cmp((0,2,6), (0,6,3)) cmp((0,2,6), (0,6,3))
cmp((5,0,3), (5,3,2)) cmp((5,0,3), (5,3,2))
cmp((5,4,0), (5,0,2)) cmp((5,4,0), (5,0,2))
cmp((5,4,3), (5,3,0)) cmp((5,4,3), (5,3,0))
cmp((0,0,0), (0,0,0)) cmp((0,0,0), (0,0,0))
# Test dimensions of 1 # Test dimensions of 1
cmp((1,2,6), (1,6,3)) cmp((1,2,6), (1,6,3))
cmp((5,1,3), (5,3,2)) cmp((5,1,3), (5,3,2))
cmp((5,4,1), (5,1,2)) cmp((5,4,1), (5,1,2))
cmp((5,4,3), (5,3,1)) cmp((5,4,3), (5,3,1))
def test_batched_dot_errors(self): def test_batched_dot_errors(self):
...@@ -109,11 +112,12 @@ class TestBatchedDot(unittest_tools.InferShapeTester): ...@@ -109,11 +112,12 @@ class TestBatchedDot(unittest_tools.InferShapeTester):
self.assertRaises(RuntimeError, fail, (5,4,3), (5,2,2)) self.assertRaises(RuntimeError, fail, (5,4,3), (5,2,2))
def test_batched_dot_gradient(self): def test_batched_dot_gradient(self):
unittest_tools.verify_grad( for threshold in [0, 100]:
batched_dot, unittest_tools.verify_grad(
[numpy.random.randn(5,7,2).astype(numpy.float32), GpuBatchedDot(stream_threshold=threshold),
numpy.random.randn(5,2,6).astype(numpy.float32)], [numpy.random.randn(5,7,2).astype(numpy.float32),
mode=mode_with_gpu) numpy.random.randn(5,2,6).astype(numpy.float32)],
mode=mode_with_gpu)
def test_infer_shape(self): def test_infer_shape(self):
# only matrix/matrix is supported # only matrix/matrix is supported
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论