提交 738eae2e authored 作者: Frederic's avatar Frederic

Make the CudaNdarray += CudaNdarray automatically add broadcasted dimensions.

This is needed for GpuAdvancedIncSubetensor1
上级 ee76025c
...@@ -1002,21 +1002,40 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t ...@@ -1002,21 +1002,40 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t
self->nd, other->nd); self->nd, other->nd);
} }
//standard elemwise size checks //standard elemwise nb dim checks
if (self->nd != other->nd) if (self->nd < other->nd)
{ {
PyErr_Format( PyErr_Format(
PyExc_TypeError, PyExc_TypeError,
"CudaNdarray_inplace_elemwise: need same number of dims. Got %d and %d", "CudaNdarray_inplace_elemwise: The destination need more or the"
" same number of dimensions then the source. Got %d and %d.",
self->nd, other->nd); self->nd, other->nd);
return -1; return -1;
} }
//broadcast to the same number of dimensions.
int other_dims[self->nd];
int other_strides[self->nd];
int added_dims = self->nd - other->nd;
// Add the added broadcasted dimensions
for (int i = 0; i< added_dims; ++i)
{
other_dims[i] = 1;
other_strides[i] = 0;
}
// Copy the existing dimensions
for (int i = 0; i< other->nd; ++i)
{
other_dims[i+added_dims] = CudaNdarray_HOST_DIMS(other)[i];
other_strides[i+added_dims] = CudaNdarray_HOST_STRIDES(other)[i];
}
//standard elemwise dim checks //standard elemwise dim checks
unsigned int size = 1; unsigned int size = 1;
for (int i = 0; i< self->nd; ++i) for (int i = 0; i< self->nd; ++i)
{ {
if ((CudaNdarray_HOST_DIMS(self)[i] != CudaNdarray_HOST_DIMS(other)[i]) if ((CudaNdarray_HOST_DIMS(self)[i] != other_dims[i])
&& (CudaNdarray_HOST_DIMS(other)[i] != 1)) && (other_dims[i] != 1))
{ {
PyErr_SetString( PyErr_SetString(
PyExc_ValueError, PyExc_ValueError,
...@@ -1024,8 +1043,8 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t ...@@ -1024,8 +1043,8 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t
return -1; return -1;
} }
// if we're broadcasting other, then make sure it has stride 0 // if we're broadcasting other, then make sure it has stride 0
assert ((CudaNdarray_HOST_DIMS(self)[i] == CudaNdarray_HOST_DIMS(other)[i]) assert ((CudaNdarray_HOST_DIMS(self)[i] == other_dims[i])
|| (CudaNdarray_HOST_STRIDES(other)[i] == 0)); || (other_strides[i] == 0));
size *= (unsigned int) CudaNdarray_HOST_DIMS(self)[i]; size *= (unsigned int) CudaNdarray_HOST_DIMS(self)[i];
} }
...@@ -1090,7 +1109,7 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t ...@@ -1090,7 +1109,7 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t
CudaNdarray_DEV_DATA(other), CudaNdarray_DEV_DATA(other),
1, //strides 1, //strides
1, 1,
CudaNdarray_HOST_STRIDES(other)[0]); other_strides[0]);
CNDA_THREAD_SYNC; CNDA_THREAD_SYNC;
cudaError_t err = cudaGetLastError(); cudaError_t err = cudaGetLastError();
if (cudaSuccess != err) if (cudaSuccess != err)
...@@ -1126,8 +1145,8 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t ...@@ -1126,8 +1145,8 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t
CudaNdarray_HOST_STRIDES(self)[1], CudaNdarray_HOST_STRIDES(self)[1],
CudaNdarray_DEV_DATA(other), CudaNdarray_DEV_DATA(other),
1, 1,
CudaNdarray_HOST_STRIDES(other)[0], other_strides[0],
CudaNdarray_HOST_STRIDES(other)[1]); other_strides[1]);
CNDA_THREAD_SYNC; CNDA_THREAD_SYNC;
cudaError_t err = cudaGetLastError(); cudaError_t err = cudaGetLastError();
if (cudaSuccess != err) if (cudaSuccess != err)
...@@ -1165,9 +1184,9 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t ...@@ -1165,9 +1184,9 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t
CudaNdarray_HOST_STRIDES(self)[1], CudaNdarray_HOST_STRIDES(self)[1],
CudaNdarray_HOST_STRIDES(self)[2], CudaNdarray_HOST_STRIDES(self)[2],
CudaNdarray_DEV_DATA(other), CudaNdarray_DEV_DATA(other),
CudaNdarray_HOST_STRIDES(other)[0], other_strides[0],
CudaNdarray_HOST_STRIDES(other)[1], other_strides[1],
CudaNdarray_HOST_STRIDES(other)[2]); other_strides[2]);
CNDA_THREAD_SYNC; CNDA_THREAD_SYNC;
cudaError_t err = cudaGetLastError(); cudaError_t err = cudaGetLastError();
if (cudaSuccess != err) if (cudaSuccess != err)
...@@ -1208,10 +1227,10 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t ...@@ -1208,10 +1227,10 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t
CudaNdarray_HOST_STRIDES(self)[2], CudaNdarray_HOST_STRIDES(self)[2],
CudaNdarray_HOST_STRIDES(self)[3], CudaNdarray_HOST_STRIDES(self)[3],
CudaNdarray_DEV_DATA(other), CudaNdarray_DEV_DATA(other),
CudaNdarray_HOST_STRIDES(other)[0], other_strides[0],
CudaNdarray_HOST_STRIDES(other)[1], other_strides[1],
CudaNdarray_HOST_STRIDES(other)[2], other_strides[2],
CudaNdarray_HOST_STRIDES(other)[3]); other_strides[3]);
CNDA_THREAD_SYNC; CNDA_THREAD_SYNC;
cudaError_t err = cudaGetLastError(); cudaError_t err = cudaGetLastError();
if (cudaSuccess != err) if (cudaSuccess != err)
...@@ -1252,11 +1271,11 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t ...@@ -1252,11 +1271,11 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t
CudaNdarray_HOST_STRIDES(self)[2], CudaNdarray_HOST_STRIDES(self)[2],
CudaNdarray_HOST_STRIDES(self)[3], CudaNdarray_HOST_STRIDES(self)[3],
CudaNdarray_HOST_STRIDES(self)[4], CudaNdarray_HOST_STRIDES(self)[4],
CudaNdarray_DEV_DATA(other) + i * CudaNdarray_HOST_STRIDES(other)[0], CudaNdarray_DEV_DATA(other) + i * other_strides[0],
CudaNdarray_HOST_STRIDES(other)[1], other_strides[1],
CudaNdarray_HOST_STRIDES(other)[2], other_strides[2],
CudaNdarray_HOST_STRIDES(other)[3], other_strides[3],
CudaNdarray_HOST_STRIDES(other)[4]); other_strides[4]);
CNDA_THREAD_SYNC; CNDA_THREAD_SYNC;
cudaError_t err = cudaGetLastError(); cudaError_t err = cudaGetLastError();
if( cudaSuccess != err) if( cudaSuccess != err)
...@@ -1280,6 +1299,8 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t ...@@ -1280,6 +1299,8 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t
return -1; return -1;
} }
} }
if (verbose)
fprintf(stderr, "INPLACE ADD/DIV end\n");
return 0; return 0;
} }
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论