提交 dee8d3e6 authored 作者: Vivek Kulkarni's avatar Vivek Kulkarni

C code for GPUAdvancedIncSubtensor1.Handles one case to make it faster

上级 4e5e642e
...@@ -2442,6 +2442,50 @@ class GpuAdvancedIncSubtensor1(tensor.AdvancedIncSubtensor1, GpuOp): ...@@ -2442,6 +2442,50 @@ class GpuAdvancedIncSubtensor1(tensor.AdvancedIncSubtensor1, GpuOp):
x[i] += y x[i] += y
out[0] = x out[0] = x
def c_code(self, node, name, inputs, outputs, sub):
if (self.set_instead_of_inc) or \
(node.inputs[0].ndim != node.inputs[1].ndim):
raise NotImplementedError("This case does not have C code yet.")
x = inputs[0]
y = inputs[1]
ind = inputs[2]
out = outputs[0]
fail = sub['fail']
inplace = int(self.inplace)
return """
PyObject *x_obj, *y_obj, *row_x, *row_y;
PyObject *x_rowind_obj, *y_rowind_obj;
int *p_index;
int num_indices, j;
Py_XDECREF(%(out)s);
if (!%(inplace)s) {
%(out)s = (CudaNdarray*)CudaNdarray_Copy(%(x)s);
} else {
%(out)s = %(x)s;
Py_XINCREF(%(out)s);
}
x_obj = (PyObject*)CudaNdarray_View(%(x)s);
y_obj = (PyObject*)CudaNdarray_View(%(y)s);
num_indices = PyArray_SIZE(%(ind)s);
for (j = 0;j < num_indices; j++) {
p_index = (int *)PyArray_GETPTR1(%(ind)s, j);
x_rowind_obj = PyInt_FromLong(*p_index);
y_rowind_obj = PyInt_FromLong(j);
row_x = CudaNdarray_Subscript(x_obj, x_rowind_obj);
row_y = CudaNdarray_Subscript(y_obj, y_rowind_obj);
CudaNdarray_inplace_add(row_x, row_y);
}
if (!%(out)s) {
%(fail)s
}
""" %locals()
class GpuIncSubtensor(tensor.IncSubtensor, GpuOp): class GpuIncSubtensor(tensor.IncSubtensor, GpuOp):
""" """
......
...@@ -1889,7 +1889,7 @@ CudaNdarray_len(PyObject * py_self) ...@@ -1889,7 +1889,7 @@ CudaNdarray_len(PyObject * py_self)
} }
// Will by called by __getitem__ in Python // Will by called by __getitem__ in Python
static PyObject * PyObject *
CudaNdarray_Subscript(PyObject * py_self, PyObject * key) CudaNdarray_Subscript(PyObject * py_self, PyObject * key)
{ {
int verbose = 0; int verbose = 0;
......
...@@ -479,7 +479,7 @@ int fprint_CudaNdarray(FILE * fd, const CudaNdarray *self); ...@@ -479,7 +479,7 @@ int fprint_CudaNdarray(FILE * fd, const CudaNdarray *self);
PyObject * CudaNdarray_View(const CudaNdarray * self); PyObject * CudaNdarray_View(const CudaNdarray * self);
PyObject * CudaNdarray_inplace_add(PyObject* py_self, PyObject * py_other); PyObject * CudaNdarray_inplace_add(PyObject* py_self, PyObject * py_other);
PyObject * CudaNdarray_Subscript(PyObject * py_self, PyObject * key);
// Ensures that *arr is a pointer to a contiguous ndarray of the specified // Ensures that *arr is a pointer to a contiguous ndarray of the specified
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论