提交 19da7c2d authored 作者: Frederic's avatar Frederic

CudaNdarray_inplace_elemwise: accept inputs other than CudaNdarray.

上级 b9bc0142
...@@ -663,7 +663,7 @@ PyObject * CudaNdarray_Reshape(CudaNdarray * self, PyObject * shape) ...@@ -663,7 +663,7 @@ PyObject * CudaNdarray_Reshape(CudaNdarray * self, PyObject * shape)
return (PyObject*)rval; return (PyObject*)rval;
} }
PyObject * CudaNdarray_View(CudaNdarray * self) PyObject * CudaNdarray_View(const CudaNdarray * self)
{ {
CudaNdarray * rval = (CudaNdarray*)CudaNdarray_New(self->nd); CudaNdarray * rval = (CudaNdarray*)CudaNdarray_New(self->nd);
if (!rval || CudaNdarray_set_device_data(rval, CudaNdarray_DEV_DATA(self), self)) if (!rval || CudaNdarray_set_device_data(rval, CudaNdarray_DEV_DATA(self), self))
...@@ -985,11 +985,19 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t ...@@ -985,11 +985,19 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t
"CudaNdarray_inplace_elemwise need a CudaNdarray on left"); "CudaNdarray_inplace_elemwise need a CudaNdarray on left");
return -1; return -1;
} }
CudaNdarray * new_other = NULL;
if (!CudaNdarray_Check(py_other)) { if (!CudaNdarray_Check(py_other)) {
PyErr_SetString( new_other = (CudaNdarray*) CudaNdarray_New();
PyExc_TypeError, if(!new_other)
"CudaNdarray_inplace_elemwise need a CudaNdarray on right"); {
return -1; return -1;
}
if(CudaNdarray_CopyFromArray(new_other, (PyArrayObject *) py_other))
{
Py_XDECREF(new_other);
return -1;
}
py_other = (PyObject *) new_other;
} }
CudaNdarray * self = (CudaNdarray *)py_self; CudaNdarray * self = (CudaNdarray *)py_self;
...@@ -1010,6 +1018,7 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t ...@@ -1010,6 +1018,7 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t
"CudaNdarray_inplace_elemwise: The destination need more or the" "CudaNdarray_inplace_elemwise: The destination need more or the"
" same number of dimensions then the source. Got %d and %d.", " same number of dimensions then the source. Got %d and %d.",
self->nd, other->nd); self->nd, other->nd);
Py_XDECREF(new_other);
return -1; return -1;
} }
...@@ -1040,6 +1049,7 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t ...@@ -1040,6 +1049,7 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t
PyErr_SetString( PyErr_SetString(
PyExc_ValueError, PyExc_ValueError,
"CudaNdarray_inplace_elemwise need same dimensions (or broadcastable dimension)"); "CudaNdarray_inplace_elemwise need same dimensions (or broadcastable dimension)");
Py_XDECREF(new_other);
return -1; return -1;
} }
// if we're broadcasting other, then make sure it has stride 0 // if we're broadcasting other, then make sure it has stride 0
...@@ -1055,8 +1065,10 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t ...@@ -1055,8 +1065,10 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t
PyErr_SetString( PyErr_SetString(
PyExc_ValueError, PyExc_ValueError,
"CudaNdarray_inplace_elemwise cannot work inplace on an un-initialized array"); "CudaNdarray_inplace_elemwise cannot work inplace on an un-initialized array");
Py_XDECREF(new_other);
return 0; return 0;
} }
Py_XDECREF(new_other);
return 0; return 0;
} }
...@@ -1087,6 +1099,7 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t ...@@ -1087,6 +1099,7 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t
"Cuda error: %s: %s.\n", "Cuda error: %s: %s.\n",
"k3", "k3",
cudaGetErrorString(err)); cudaGetErrorString(err));
Py_XDECREF(new_other);
return -1; return -1;
} }
} }
...@@ -1119,6 +1132,7 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t ...@@ -1119,6 +1132,7 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t
"Cuda error: %s: %s.\n", "Cuda error: %s: %s.\n",
"k3", "k3",
cudaGetErrorString(err)); cudaGetErrorString(err));
Py_XDECREF(new_other);
return -1; return -1;
} }
} }
...@@ -1156,6 +1170,7 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t ...@@ -1156,6 +1170,7 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t
"Cuda error: %s: %s.\n", "Cuda error: %s: %s.\n",
"k3", "k3",
cudaGetErrorString(err)); cudaGetErrorString(err));
Py_XDECREF(new_other);
return -1; return -1;
} }
} }
...@@ -1196,6 +1211,7 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t ...@@ -1196,6 +1211,7 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t
"Cuda error: %s: %s.\n", "Cuda error: %s: %s.\n",
"k3", "k3",
cudaGetErrorString(err)); cudaGetErrorString(err));
Py_XDECREF(new_other);
return -1; return -1;
} }
} }
...@@ -1240,6 +1256,7 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t ...@@ -1240,6 +1256,7 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t
"Cuda error: %s: %s.\n", "Cuda error: %s: %s.\n",
"k4", "k4",
cudaGetErrorString(err)); cudaGetErrorString(err));
Py_XDECREF(new_other);
return -1; return -1;
} }
} }
...@@ -1285,6 +1302,7 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t ...@@ -1285,6 +1302,7 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t
"Cuda error: %s: %s.\n", "Cuda error: %s: %s.\n",
"k4", "k4",
cudaGetErrorString(err)); cudaGetErrorString(err));
Py_XDECREF(new_other);
return -1; return -1;
} }
} }
...@@ -1296,11 +1314,13 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t ...@@ -1296,11 +1314,13 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t
PyExc_NotImplementedError, PyExc_NotImplementedError,
"inplace_elemwise w nd=%i\n", "inplace_elemwise w nd=%i\n",
self->nd); self->nd);
Py_XDECREF(new_other);
return -1; return -1;
} }
} }
if (verbose) if (verbose)
fprintf(stderr, "INPLACE ADD/DIV end\n"); fprintf(stderr, "INPLACE ADD/DIV end\n");
Py_XDECREF(new_other);
return 0; return 0;
} }
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论