Commit 29db8ffb authored by Arnaud Bergeron

Fix errors in the C code and add a c_code cache version. It passes the tests and works.

Parent 1519d758
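The fixes below are mostly one-character bugs that break the op before any CUDA code runs: the inplace template wrote `%(o)` where Python's %-dict substitution needs `%(o)s`; the non-inplace branch passed `%(out)s` by value although `CudaNdarray_prep_output` needs the address of the variable it may reallocate; the `CudaNdarray_new_nd` results lacked casts to `CudaNdarray *`; and a stale `o_part` name survived the rename to `out_part`. The template bug is the subtlest, since it fails at string-substitution time rather than at compile time. A minimal sketch of that failure mode (the `V3`/`V1` values are hypothetical stand-ins for the generated C variable names):

```python
# Minimal reproduction of the template bug fixed below; the dict keys
# mirror the substitution dict in c_code, the values are hypothetical
# stand-ins for generated C variable names.
names = dict(out="V3", o="V1")

good = "%(out)s = %(o)s;"   # what the fix writes
bad = "%(out)s = %(o);"     # missing the trailing 's' conversion

print(good % names)          # -> V3 = V1;
try:
    bad % names
except ValueError as e:
    # %-formatting rejects ';' as a conversion character, so the op
    # failed before any C code was even handed to the compiler.
    print("fails:", e)
```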
@@ -124,7 +124,7 @@ class SparseBlockGemvSS(GpuOp):
         out[0] = o
 
-    def c_code(self, node, inputs, outputs, sub):
+    def c_code(self, node, nodename, inputs, outputs, sub):
         o, W, h, inputIdx, outputIdx = inputs
         out = outputs[0]
@@ -133,12 +133,12 @@ class SparseBlockGemvSS(GpuOp):
         if self.inplace:
             res = """
         Py_XDECREF(%(out)s);
-        %(out)s = %(o);
+        %(out)s = %(o)s;
         Py_INCREF(%(out)s);
         """ % dict(out=out, o=o)
         else:
             res = """
-        if (CudaNdarray_prep_output(%(out)s, 2, CudaNdarray_HOST_DIMS(%(o)s)))
+        if (CudaNdarray_prep_output(&%(out)s, 2, CudaNdarray_HOST_DIMS(%(o)s)))
         {
             PyErr_SetString(PyExc_RuntimeError, "Cannot allocate output");
             %(fail)s
@@ -150,18 +150,19 @@ class SparseBlockGemvSS(GpuOp):
         """ % dict(out=out, o=o, fail=sub['fail'])
         return res + """
-        CudaNdarray *W_part = CudaNdarray_new_nd(2);
-        CudaNdarray *h_part = CudaNdarray_new_nd(1);
-        CudaNdarray *out_part = CudaNdarray_new_nd(1);
-        if (W_part == NULL || h_part == NULL || o_part == NULL) {
+        {
+        CudaNdarray *W_part = (CudaNdarray *)CudaNdarray_new_nd(2);
+        CudaNdarray *h_part = (CudaNdarray *)CudaNdarray_new_nd(1);
+        CudaNdarray *out_part = (CudaNdarray *)CudaNdarray_new_nd(1);
+        if (W_part == NULL || h_part == NULL || out_part == NULL) {
             Py_XDECREF(W_part);
             Py_XDECREF(h_part);
             Py_XDECREF(out_part);
         }
-        CudaNdarray_set_dim(W_part, 0, CudaNdarray_HOST_DIMS(%(W)s)[2]);
-        CudaNdarray_set_stride(W_part, 0, CudaNdarray_HOST_STRIDES(%(W)s)[2]);
-        CudaNdarray_set_dim(W_part, 1, CudaNdarray_HOST_DIMS(%(W)s)[3]);
-        CudaNdarray_set_stride(W_part, 1, CudaNdarray_HOST_STRIDES(%(W)s)[3]);
+        CudaNdarray_set_dim(W_part, 0, CudaNdarray_HOST_DIMS(%(W)s)[3]);
+        CudaNdarray_set_stride(W_part, 0, CudaNdarray_HOST_STRIDES(%(W)s)[3]);
+        CudaNdarray_set_dim(W_part, 1, CudaNdarray_HOST_DIMS(%(W)s)[2]);
+        CudaNdarray_set_stride(W_part, 1, CudaNdarray_HOST_STRIDES(%(W)s)[2]);
         CudaNdarray_set_dim(h_part, 0, CudaNdarray_HOST_DIMS(%(h)s)[1]);
         CudaNdarray_set_stride(h_part, 0, CudaNdarray_HOST_STRIDES(%(h)s)[1]);
         CudaNdarray_set_dim(out_part, 0, CudaNdarray_HOST_DIMS(%(out)s)[1]);
@@ -179,14 +180,18 @@ class SparseBlockGemvSS(GpuOp):
                     (CudaNdarray_HOST_STRIDES(%(W)s)[0] * inp_id) +
                     (CudaNdarray_HOST_STRIDES(%(W)s)[1] * out_id), %(W)s);
-                if (CudaNdarray_sgemv(1.0f, W_part, h_part, 1.0f, o_part)) {
+                if (CudaNdarray_sgemv(1.0f, W_part, h_part, 1.0f, out_part)) {
                     %(fail)s
                 }
             }
         }
+        }
         """ % dict(out=out, h=h, o=o, inputIdx=inputIdx, outputIdx=outputIdx,
                    W=W, fail=sub['fail'])
+
+    def c_code_cache_version(self):
+        return (0,)
 
     def grad(self, inputs, grads):
         o, W, h, inputIdx, outputIdx = inputs
         go = grads[0]
......
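Two of the changes above are worth a note. The W_part view now takes its rows from dimension 3 of the weight tensor and its columns from dimension 2, the transpose of the old mapping; since `CudaNdarray_sgemv`, going by the BLAS naming, computes out_part ← 1.0·W_part·h_part + 1.0·out_part, the orientation of the 2-D block decides whether the input and output block sizes line up with h_part and out_part. And the new `c_code_cache_version` is what lets Theano cache the compiled module: the cache is keyed on this tuple, so it must be bumped whenever the generated C changes, while an op without one is treated as unversioned and its binaries are not reused between processes. A sketch of the convention, using a deliberately simplified stand-in class (`ScaleBy2` is illustrative and not part of this commit; a real op would subclass `theano.gof.Op` and define `make_node` as well):

```python
# Sketch of the cache-version convention for Theano C ops.  Only the
# two methods this commit touches are shown; the class is a toy.
class ScaleBy2(object):
    def c_code(self, node, nodename, inputs, outputs, sub):
        # 'nodename' (the parameter added to c_code's signature in this
        # commit) lets generated symbols be made unique per apply node.
        x, = inputs
        z, = outputs
        return "%(z)s = 2.0f * %(x)s;" % dict(x=x, z=z)

    def c_code_cache_version(self):
        # A non-empty tuple keys Theano's compilation cache; bump it
        # whenever the generated C changes so stale binaries are not
        # reused.  An empty tuple marks the op unversioned, and its
        # compiled code is not cached between processes.
        return (0,)
```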