提交 57865538 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add C code using gemmBatched to SparseBlockDotOuterSS (the gradient).

上级 98a15fa1
...@@ -400,6 +400,84 @@ class SparseBlockOuterSS(GpuOp): ...@@ -400,6 +400,84 @@ class SparseBlockOuterSS(GpuOp):
out[0] = o out[0] = o
def c_support_code(self):
return """
__global__ void
SparseBlockOuter_fill_lists(
int n,
const float **x_list,
const float **y_list,
float **out_list,
const float *x, int x_str_0,
const float *y, int y_str_0,
float *out, int o_str_0, int o_str_1,
const npy_intp *xIdx,
const npy_intp *yIdx
) {
int i = threadIdx.x + blockDim.x * blockIdx.x;
int j = threadIdx.y + blockDim.y * blockIdx.y;
int p = i + j * blockDim.x * gridDim.x;
if (p >= n) return;
x_list[p] = &x[i * x_str_0];
y_list[p] = &y[j * y_str_0];
out_list[p] = &out[xIdx[i] * o_str_0 + yIdx[j] * o_str_1];
}
static int SparseBlockOuter_copy(PyArrayObject *a, npy_intp *b) {
cudaError_t err;
PyArrayObject *aa = (PyArrayObject *)PyArray_Cast(a, NPY_INTP);
if (aa == NULL) { return -1; }
err = cudaMemcpy(b, PyArray_DATA(aa), PyArray_NBYTES(aa),
cudaMemcpyHostToDevice);
Py_DECREF(aa);
if (err != cudaSuccess) {
PyErr_SetString(PyExc_RuntimeError, "Cannot copy index data to GPU");
return -1;
}
return 0;
}
"""
def c_support_code_apply(self, node, name):
return """
/* statics are initialized with 0 */
static float **%(n)s_out_list;
static const float **%(n)s_x_list;
static const float **%(n)s_y_list;
static size_t %(n)s_list_len;
static npy_intp *%(n)s_xIdx;
static size_t %(n)s_xIdx_len;
static npy_intp *%(n)s_yIdx;
static size_t %(n)s_yIdx_len;
// This is batch-ready
static int %(n)s_prep(int b, int i, int j) {
int s = b*i*j;
if (%(n)s_list_len < s) {
cudaFree(%(n)s_x_list);
cudaFree(%(n)s_y_list);
cudaFree(%(n)s_out_list);
if (cudaMalloc(&%(n)s_x_list, s*sizeof(float *)) != cudaSuccess) return -1;
if (cudaMalloc(&%(n)s_y_list, s*sizeof(float *)) != cudaSuccess) return -1;
if (cudaMalloc(&%(n)s_out_list, s*sizeof(float *)) != cudaSuccess) return -1;
%(n)s_list_len = s;
}
if (%(n)s_xIdx_len < b*i) {
cudaFree(%(n)s_xIdx);
if (cudaMalloc(&%(n)s_xIdx, b*i*sizeof(npy_intp)) != cudaSuccess)
return -1;
%(n)s_xIdx_len = b*i;
}
if (%(n)s_yIdx_len < b*j) {
cudaFree(%(n)s_yIdx);
if (cudaMalloc(&%(n)s_yIdx, b*j*sizeof(npy_intp)) != cudaSuccess)
return -1;
%(n)s_yIdx_len = b*j;
}
return 0;
}
""" % dict(n=name)
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
o, x, y, xIdx, yIdx = inputs o, x, y, xIdx, yIdx = inputs
out = outputs[0] out = outputs[0]
...@@ -422,48 +500,51 @@ if (CudaNdarray_CopyFromCudaNdarray(%(out)s, %(o)s)) { ...@@ -422,48 +500,51 @@ if (CudaNdarray_CopyFromCudaNdarray(%(out)s, %(o)s)) {
} }
""" % dict(out=out, o=o, fail=sub['fail']) """ % dict(out=out, o=o, fail=sub['fail'])
return res + """{ return res + """
CudaNdarray *x_part = (CudaNdarray *)CudaNdarray_new_nd(1); if (%(name)s_prep(1, CudaNdarray_HOST_DIMS(%(x)s)[0],
CudaNdarray *y_part = (CudaNdarray *)CudaNdarray_new_nd(1); CudaNdarray_HOST_DIMS(%(y)s)[0]) == -1) {
CudaNdarray *out_part = (CudaNdarray *)CudaNdarray_new_nd(2); PyErr_SetString(PyExc_RuntimeError, "Could not allocate working memory.");
if (x_part == NULL || y_part == NULL || out_part == NULL) { %(fail)s
Py_XDECREF(x_part);
Py_XDECREF(y_part);
Py_XDECREF(out_part);
} }
CudaNdarray_set_dim(x_part, 0, CudaNdarray_HOST_DIMS(%(x)s)[1]); if (SparseBlockOuter_copy(%(xIdx)s, %(name)s_xIdx) == -1)
CudaNdarray_set_stride(x_part, 0, CudaNdarray_HOST_STRIDES(%(x)s)[1]); { %(fail)s }
CudaNdarray_set_dim(y_part, 0, CudaNdarray_HOST_DIMS(%(y)s)[1]); if (SparseBlockOuter_copy(%(yIdx)s, %(name)s_yIdx) == -1)
CudaNdarray_set_stride(y_part, 0, CudaNdarray_HOST_STRIDES(%(y)s)[1]); { %(fail)s }
CudaNdarray_set_dim(out_part, 0, CudaNdarray_HOST_DIMS(%(out)s)[2]); {
CudaNdarray_set_stride(out_part, 0, CudaNdarray_HOST_STRIDES(%(out)s)[2]); dim3 block;
CudaNdarray_set_dim(out_part, 1, CudaNdarray_HOST_DIMS(%(out)s)[3]); block.x = CudaNdarray_HOST_DIMS(%(x)s)[0];
CudaNdarray_set_stride(out_part, 1, CudaNdarray_HOST_STRIDES(%(out)s)[3]); block.y = CudaNdarray_HOST_DIMS(%(y)s)[0];
SparseBlockOuter_fill_lists<<<block, 1>>>(
for (int j = 0; j < CudaNdarray_HOST_DIMS(%(y)s)[0]; j++) { block.x * block.y,
npy_intp y_id = *(dtype_%(xIdx)s *)PyArray_GETPTR1(%(yIdx)s, j); %(name)s_x_list,
CudaNdarray_set_device_data(y_part, CudaNdarray_DEV_DATA(%(y)s) + %(name)s_y_list,
CudaNdarray_HOST_STRIDES(%(y)s)[0] * j, %(y)s); %(name)s_out_list,
for (int i = 0; i < CudaNdarray_HOST_DIMS(%(x)s)[0]; i++) { CudaNdarray_DEV_DATA(%(x)s), CudaNdarray_HOST_STRIDES(%(x)s)[0],
npy_intp x_id = *(dtype_%(xIdx)s *)PyArray_GETPTR1(%(xIdx)s, i); CudaNdarray_DEV_DATA(%(y)s), CudaNdarray_HOST_STRIDES(%(y)s)[0],
CudaNdarray_set_device_data(x_part, CudaNdarray_DEV_DATA(%(x)s) + CudaNdarray_DEV_DATA(%(out)s),
CudaNdarray_HOST_STRIDES(%(x)s)[0] * i, %(x)s); CudaNdarray_HOST_STRIDES(%(out)s)[0], CudaNdarray_HOST_STRIDES(%(out)s)[1],
CudaNdarray_set_device_data(out_part, CudaNdarray_DEV_DATA(%(out)s) + %(name)s_xIdx,
(CudaNdarray_HOST_STRIDES(%(out)s)[0] * x_id) + %(name)s_yIdx);
(CudaNdarray_HOST_STRIDES(%(out)s)[1] * y_id), %(out)s);
if (CudaNdarray_sger(1.0f, x_part, y_part, out_part)) {
%(fail)s
}
}
} }
Py_DECREF(x_part); {
Py_DECREF(y_part); cublasStatus_t err;
Py_DECREF(out_part); float alpha = 1.0f;
}""" % dict(x=x, y=y, out=out, xIdx=xIdx, yIdx=yIdx, fail=sub['fail']) float beta = 1.0f;
err = cublasSgemmBatched(handle, CUBLAS_OP_N, CUBLAS_OP_T,
CudaNdarray_HOST_DIMS(%(y)s)[1], CudaNdarray_HOST_DIMS(%(x)s)[1], 1,
&alpha, %(name)s_y_list, CudaNdarray_HOST_STRIDES(%(y)s)[0],
%(name)s_x_list, CudaNdarray_HOST_STRIDES(%(x)s)[0], &beta,
%(name)s_out_list, CudaNdarray_HOST_STRIDES(%(out)s)[2],
CudaNdarray_HOST_DIMS(%(x)s)[0] * CudaNdarray_HOST_DIMS(%(y)s)[0]);
if (err != CUBLAS_STATUS_SUCCESS) {
PyErr_SetString(PyExc_RuntimeError, "SgemmBatched failed");
%(fail)s
}
}""" % dict(x=x, y=y, out=out, xIdx=xIdx, yIdx=yIdx, name=name,
fail=sub['fail'])
def c_code_cache_version(self): def c_code_cache_version(self):
return (0,) return (1,)
sparse_block_outer_ss = SparseBlockOuterSS(False) sparse_block_outer_ss = SparseBlockOuterSS(False)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论