提交 4b831a4a authored 作者: Vivek Kulkarni's avatar Vivek Kulkarni

Addressing code review comments.

上级 dee8d3e6
...@@ -2457,7 +2457,7 @@ class GpuAdvancedIncSubtensor1(tensor.AdvancedIncSubtensor1, GpuOp): ...@@ -2457,7 +2457,7 @@ class GpuAdvancedIncSubtensor1(tensor.AdvancedIncSubtensor1, GpuOp):
return """ return """
PyObject *x_obj, *y_obj, *row_x, *row_y; PyObject *x_obj, *y_obj, *row_x, *row_y;
PyObject *x_rowind_obj, *y_rowind_obj; PyObject *x_rowind_obj, *y_rowind_obj;
int *p_index; dtype_%(ind)s *p_index;
int num_indices, j; int num_indices, j;
Py_XDECREF(%(out)s); Py_XDECREF(%(out)s);
...@@ -2474,13 +2474,22 @@ class GpuAdvancedIncSubtensor1(tensor.AdvancedIncSubtensor1, GpuOp): ...@@ -2474,13 +2474,22 @@ class GpuAdvancedIncSubtensor1(tensor.AdvancedIncSubtensor1, GpuOp):
for (j = 0;j < num_indices; j++) { for (j = 0;j < num_indices; j++) {
p_index = (int *)PyArray_GETPTR1(%(ind)s, j); p_index = (dtype_%(ind)s *)PyArray_GETPTR1(%(ind)s, j);
x_rowind_obj = PyInt_FromLong(*p_index); x_rowind_obj = PyInt_FromLong(*p_index);
assert(PyInt_AsLong(x_rowind_obj) == (*p_index));
y_rowind_obj = PyInt_FromLong(j); y_rowind_obj = PyInt_FromLong(j);
assert(PyInt_AsLong(y_rowind_obj) == j);
row_x = CudaNdarray_Subscript(x_obj, x_rowind_obj); row_x = CudaNdarray_Subscript(x_obj, x_rowind_obj);
row_y = CudaNdarray_Subscript(y_obj, y_rowind_obj); row_y = CudaNdarray_Subscript(y_obj, y_rowind_obj);
CudaNdarray_inplace_add(row_x, row_y); CudaNdarray_inplace_elemwise(row_x, row_y, IADD);
Py_XDECREF(x_rowind_obj);
Py_XDECREF(y_rowind_obj);
} }
Py_XDECREF(x_obj);
Py_XDECREF(y_obj);
if (!%(out)s) { if (!%(out)s) {
%(fail)s %(fail)s
......
...@@ -746,15 +746,6 @@ PyObject * CudaNdarray_View(const CudaNdarray * self) ...@@ -746,15 +746,6 @@ PyObject * CudaNdarray_View(const CudaNdarray * self)
return (PyObject*)rval; return (PyObject*)rval;
} }
enum operator_t
{
IADD=0,
IDIV,
CPY,
N_ELEMWISE_OPS // This is to know the number of operation
};
/* /*
* d0,... are the output dims * d0,... are the output dims
* indices are a list of index to operate on * indices are a list of index to operate on
......
...@@ -95,6 +95,15 @@ struct CudaNdarray ...@@ -95,6 +95,15 @@ struct CudaNdarray
real* devdata; //pointer to data element [0,..,0]. real* devdata; //pointer to data element [0,..,0].
}; };
enum operator_t
{
IADD=0,
IDIV,
CPY,
N_ELEMWISE_OPS // This is to know the number of operation
};
/* /*
* Return a CudaNdarray whose 'nd' dimensions are all 0. * Return a CudaNdarray whose 'nd' dimensions are all 0.
* if nd==-1, it is not initialized. * if nd==-1, it is not initialized.
...@@ -478,8 +487,8 @@ int fprint_CudaNdarray(FILE * fd, const CudaNdarray *self); ...@@ -478,8 +487,8 @@ int fprint_CudaNdarray(FILE * fd, const CudaNdarray *self);
PyObject * CudaNdarray_View(const CudaNdarray * self); PyObject * CudaNdarray_View(const CudaNdarray * self);
PyObject * CudaNdarray_inplace_add(PyObject* py_self, PyObject * py_other);
PyObject * CudaNdarray_Subscript(PyObject * py_self, PyObject * key); PyObject * CudaNdarray_Subscript(PyObject * py_self, PyObject * key);
int CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t fct_nb);
// Ensures that *arr is a pointer to a contiguous ndarray of the specified // Ensures that *arr is a pointer to a contiguous ndarray of the specified
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论