提交 1519d758 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add C code to SparseBlockGemvSS

上级 0bc12fe9
......@@ -124,6 +124,69 @@ class SparseBlockGemvSS(GpuOp):
out[0] = o
def c_code(self, node, inputs, outputs, sub):
o, W, h, inputIdx, outputIdx = inputs
out = outputs[0]
res = None
if self.inplace:
res = """
Py_XDECREF(%(out)s);
%(out)s = %(o);
Py_INCREF(%(out)s);
""" % dict(out=out, o=o)
else:
res = """
if (CudaNdarray_prep_output(%(out)s, 2, CudaNdarray_HOST_DIMS(%(o)s)))
{
PyErr_SetString(PyExc_RuntimeError, "Cannot allocate output");
%(fail)s
}
if (CudaNdarray_CopyFromCudaNdarray(%(out)s, %(o)s)) {
PyErr_SetString(PyExc_RuntimeError, "Cannot copy data to output");
%(fail)s
}
""" % dict(out=out, o=o, fail=sub['fail'])
return res + """
CudaNdarray *W_part = CudaNdarray_new_nd(2);
CudaNdarray *h_part = CudaNdarray_new_nd(1);
CudaNdarray *out_part = CudaNdarray_new_nd(1);
if (W_part == NULL || h_part == NULL || o_part == NULL) {
Py_XDECREF(W_part);
Py_XDECREF(h_part);
Py_XDECREF(out_part);
}
CudaNdarray_set_dim(W_part, 0, CudaNdarray_HOST_DIMS(%(W)s)[2]);
CudaNdarray_set_stride(W_part, 0, CudaNdarray_HOST_STRIDES(%(W)s)[2]);
CudaNdarray_set_dim(W_part, 1, CudaNdarray_HOST_DIMS(%(W)s)[3]);
CudaNdarray_set_stride(W_part, 1, CudaNdarray_HOST_STRIDES(%(W)s)[3]);
CudaNdarray_set_dim(h_part, 0, CudaNdarray_HOST_DIMS(%(h)s)[1]);
CudaNdarray_set_stride(h_part, 0, CudaNdarray_HOST_STRIDES(%(h)s)[1]);
CudaNdarray_set_dim(out_part, 0, CudaNdarray_HOST_DIMS(%(out)s)[1]);
CudaNdarray_set_stride(out_part, 0, CudaNdarray_HOST_STRIDES(%(out)s)[1]);
for (int j = 0; j < CudaNdarray_HOST_DIMS(%(o)s)[0]; j++) {
npy_intp out_id = *(dtype_%(outputIdx)s *)PyArray_GETPTR1(%(outputIdx)s, j);
CudaNdarray_set_device_data(out_part, CudaNdarray_DEV_DATA(%(out)s) +
CudaNdarray_HOST_STRIDES(%(out)s)[0] * j, %(out)s);
for (int i = 0; i < CudaNdarray_HOST_DIMS(%(h)s)[0]; i++) {
npy_intp inp_id = *(dtype_%(inputIdx)s *)PyArray_GETPTR1(%(inputIdx)s, i);
CudaNdarray_set_device_data(h_part, CudaNdarray_DEV_DATA(%(h)s) +
CudaNdarray_HOST_STRIDES(%(h)s)[0] * i, %(h)s);
CudaNdarray_set_device_data(W_part, CudaNdarray_DEV_DATA(%(W)s) +
(CudaNdarray_HOST_STRIDES(%(W)s)[0] * inp_id) +
(CudaNdarray_HOST_STRIDES(%(W)s)[1] * out_id), %(W)s);
if (CudaNdarray_sgemv(1.0f, W_part, h_part, 1.0f, o_part)) {
%(fail)s
}
}
}
""" % dict(out=out, h=h, o=o, inputIdx=inputIdx, outputIdx=outputIdx,
W=W, fail=sub['fail'])
def grad(self, inputs, grads):
o, W, h, inputIdx, outputIdx = inputs
go = grads[0]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论