提交 738eae2e authored 作者: Frederic's avatar Frederic

Make the CudaNdarray += CudaNdarray automatically add broadcasted dimensions.

This is needed for GpuAdvancedIncSubetensor1
上级 ee76025c
......@@ -1002,21 +1002,40 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t
self->nd, other->nd);
}
//standard elemwise size checks
if (self->nd != other->nd)
//standard elemwise nb dim checks
if (self->nd < other->nd)
{
PyErr_Format(
PyExc_TypeError,
"CudaNdarray_inplace_elemwise: need same number of dims. Got %d and %d",
"CudaNdarray_inplace_elemwise: The destination need more or the"
" same number of dimensions then the source. Got %d and %d.",
self->nd, other->nd);
return -1;
}
//broadcast to the same number of dimensions.
int other_dims[self->nd];
int other_strides[self->nd];
int added_dims = self->nd - other->nd;
// Add the added broadcasted dimensions
for (int i = 0; i< added_dims; ++i)
{
other_dims[i] = 1;
other_strides[i] = 0;
}
// Copy the existing dimensions
for (int i = 0; i< other->nd; ++i)
{
other_dims[i+added_dims] = CudaNdarray_HOST_DIMS(other)[i];
other_strides[i+added_dims] = CudaNdarray_HOST_STRIDES(other)[i];
}
//standard elemwise dim checks
unsigned int size = 1;
for (int i = 0; i< self->nd; ++i)
{
if ((CudaNdarray_HOST_DIMS(self)[i] != CudaNdarray_HOST_DIMS(other)[i])
&& (CudaNdarray_HOST_DIMS(other)[i] != 1))
if ((CudaNdarray_HOST_DIMS(self)[i] != other_dims[i])
&& (other_dims[i] != 1))
{
PyErr_SetString(
PyExc_ValueError,
......@@ -1024,8 +1043,8 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t
return -1;
}
// if we're broadcasting other, then make sure it has stride 0
assert ((CudaNdarray_HOST_DIMS(self)[i] == CudaNdarray_HOST_DIMS(other)[i])
|| (CudaNdarray_HOST_STRIDES(other)[i] == 0));
assert ((CudaNdarray_HOST_DIMS(self)[i] == other_dims[i])
|| (other_strides[i] == 0));
size *= (unsigned int) CudaNdarray_HOST_DIMS(self)[i];
}
......@@ -1090,7 +1109,7 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t
CudaNdarray_DEV_DATA(other),
1, //strides
1,
CudaNdarray_HOST_STRIDES(other)[0]);
other_strides[0]);
CNDA_THREAD_SYNC;
cudaError_t err = cudaGetLastError();
if (cudaSuccess != err)
......@@ -1126,8 +1145,8 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t
CudaNdarray_HOST_STRIDES(self)[1],
CudaNdarray_DEV_DATA(other),
1,
CudaNdarray_HOST_STRIDES(other)[0],
CudaNdarray_HOST_STRIDES(other)[1]);
other_strides[0],
other_strides[1]);
CNDA_THREAD_SYNC;
cudaError_t err = cudaGetLastError();
if (cudaSuccess != err)
......@@ -1165,9 +1184,9 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t
CudaNdarray_HOST_STRIDES(self)[1],
CudaNdarray_HOST_STRIDES(self)[2],
CudaNdarray_DEV_DATA(other),
CudaNdarray_HOST_STRIDES(other)[0],
CudaNdarray_HOST_STRIDES(other)[1],
CudaNdarray_HOST_STRIDES(other)[2]);
other_strides[0],
other_strides[1],
other_strides[2]);
CNDA_THREAD_SYNC;
cudaError_t err = cudaGetLastError();
if (cudaSuccess != err)
......@@ -1208,10 +1227,10 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t
CudaNdarray_HOST_STRIDES(self)[2],
CudaNdarray_HOST_STRIDES(self)[3],
CudaNdarray_DEV_DATA(other),
CudaNdarray_HOST_STRIDES(other)[0],
CudaNdarray_HOST_STRIDES(other)[1],
CudaNdarray_HOST_STRIDES(other)[2],
CudaNdarray_HOST_STRIDES(other)[3]);
other_strides[0],
other_strides[1],
other_strides[2],
other_strides[3]);
CNDA_THREAD_SYNC;
cudaError_t err = cudaGetLastError();
if (cudaSuccess != err)
......@@ -1252,11 +1271,11 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t
CudaNdarray_HOST_STRIDES(self)[2],
CudaNdarray_HOST_STRIDES(self)[3],
CudaNdarray_HOST_STRIDES(self)[4],
CudaNdarray_DEV_DATA(other) + i * CudaNdarray_HOST_STRIDES(other)[0],
CudaNdarray_HOST_STRIDES(other)[1],
CudaNdarray_HOST_STRIDES(other)[2],
CudaNdarray_HOST_STRIDES(other)[3],
CudaNdarray_HOST_STRIDES(other)[4]);
CudaNdarray_DEV_DATA(other) + i * other_strides[0],
other_strides[1],
other_strides[2],
other_strides[3],
other_strides[4]);
CNDA_THREAD_SYNC;
cudaError_t err = cudaGetLastError();
if( cudaSuccess != err)
......@@ -1280,6 +1299,8 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t
return -1;
}
}
if (verbose)
fprintf(stderr, "INPLACE ADD/DIV end\n");
return 0;
}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论