提交 4b831a4a authored 作者: Vivek Kulkarni's avatar Vivek Kulkarni

Addressing code review comments.

上级 dee8d3e6
......@@ -2457,7 +2457,7 @@ class GpuAdvancedIncSubtensor1(tensor.AdvancedIncSubtensor1, GpuOp):
return """
PyObject *x_obj, *y_obj, *row_x, *row_y;
PyObject *x_rowind_obj, *y_rowind_obj;
int *p_index;
dtype_%(ind)s *p_index;
int num_indices, j;
Py_XDECREF(%(out)s);
......@@ -2474,13 +2474,22 @@ class GpuAdvancedIncSubtensor1(tensor.AdvancedIncSubtensor1, GpuOp):
for (j = 0;j < num_indices; j++) {
p_index = (int *)PyArray_GETPTR1(%(ind)s, j);
p_index = (dtype_%(ind)s *)PyArray_GETPTR1(%(ind)s, j);
x_rowind_obj = PyInt_FromLong(*p_index);
assert(PyInt_AsLong(x_rowind_obj) == (*p_index));
y_rowind_obj = PyInt_FromLong(j);
assert(PyInt_AsLong(y_rowind_obj) == j);
row_x = CudaNdarray_Subscript(x_obj, x_rowind_obj);
row_y = CudaNdarray_Subscript(y_obj, y_rowind_obj);
CudaNdarray_inplace_add(row_x, row_y);
CudaNdarray_inplace_elemwise(row_x, row_y, IADD);
Py_XDECREF(x_rowind_obj);
Py_XDECREF(y_rowind_obj);
}
Py_XDECREF(x_obj);
Py_XDECREF(y_obj);
if (!%(out)s) {
%(fail)s
......
......@@ -746,15 +746,6 @@ PyObject * CudaNdarray_View(const CudaNdarray * self)
return (PyObject*)rval;
}
enum operator_t
{
IADD=0,
IDIV,
CPY,
N_ELEMWISE_OPS // This is to know the number of operation
};
/*
* d0,... are the output dims
* indices are a list of index to operate on
......
......@@ -95,6 +95,15 @@ struct CudaNdarray
real* devdata; //pointer to data element [0,..,0].
};
enum operator_t
{
IADD=0,
IDIV,
CPY,
N_ELEMWISE_OPS // This is to know the number of operation
};
/*
* Return a CudaNdarray whose 'nd' dimensions are all 0.
* if nd==-1, it is not initialized.
......@@ -478,8 +487,8 @@ int fprint_CudaNdarray(FILE * fd, const CudaNdarray *self);
PyObject * CudaNdarray_View(const CudaNdarray * self);
PyObject * CudaNdarray_inplace_add(PyObject* py_self, PyObject * py_other);
PyObject * CudaNdarray_Subscript(PyObject * py_self, PyObject * key);
int CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t fct_nb);
// Ensures that *arr is a pointer to a contiguous ndarray of the specified
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论