提交 3de42af7 authored 作者: Frederic's avatar Frederic

We add dimensions to the source CudaNdarray so that it is automatically broadcast before the copy.

This affects all the functions that call this inner function.
上级 697d3eb7
...@@ -2790,15 +2790,29 @@ int CudaNdarray_CopyFromCudaNdarray(CudaNdarray * self, ...@@ -2790,15 +2790,29 @@ int CudaNdarray_CopyFromCudaNdarray(CudaNdarray * self,
"can't copy into un-initialized CudaNdarray"); "can't copy into un-initialized CudaNdarray");
return -1; return -1;
} }
if (self->nd != other->nd) CudaNdarray * new_other = NULL;
if (self->nd < other->nd)
{ {
PyErr_Format(PyExc_NotImplementedError, PyErr_Format(PyExc_NotImplementedError,
"CudaNdarray_CopyFromCudaNdarray: need same number of" "CudaNdarray_CopyFromCudaNdarray: The destination need more or the"
" dims. destination nd=%d, source nd=%d." " same number of dimensions then the source. Got %d and %d.",
" No broadcasting implemented.",
self->nd, other->nd); self->nd, other->nd);
return -1; return -1;
} }
else if (self->nd != other->nd)
{
CudaNdarray * new_other = (CudaNdarray *) CudaNdarray_View(other);
int added_dims = self->nd - other->nd;
int pattern[self->nd];
for(int i = 0; i < added_dims; i++)
pattern[i] = -1;
for(int i = 0; i < other->nd; i++)
pattern[i + added_dims] = i;
CudaNdarray_dimshuffle(new_other, self->nd, pattern);
other = new_other;
}
assert(self->nd == other->nd);
//standard elemwise dim checks (also compute total size) //standard elemwise dim checks (also compute total size)
unsigned int size = 1; unsigned int size = 1;
unsigned int size_source = 1; unsigned int size_source = 1;
...@@ -2812,13 +2826,15 @@ int CudaNdarray_CopyFromCudaNdarray(CudaNdarray * self, ...@@ -2812,13 +2826,15 @@ int CudaNdarray_CopyFromCudaNdarray(CudaNdarray * self,
" destination=%d, source=%d", " destination=%d, source=%d",
i, CudaNdarray_HOST_DIMS(self)[i], i, CudaNdarray_HOST_DIMS(self)[i],
CudaNdarray_HOST_DIMS(other)[i]); CudaNdarray_HOST_DIMS(other)[i]);
return -1; Py_XDECREF(new_other);
return -1;
} }
size *= (unsigned int) CudaNdarray_HOST_DIMS(self)[i]; size *= (unsigned int) CudaNdarray_HOST_DIMS(self)[i];
size_source *= (unsigned int) CudaNdarray_HOST_DIMS(other)[i]; size_source *= (unsigned int) CudaNdarray_HOST_DIMS(other)[i];
} }
if (0 == size) if (0 == size)
{ {
Py_XDECREF(new_other);
return 0; //nothing to copy, we're done. return 0; //nothing to copy, we're done.
} }
if (CudaNdarray_is_c_contiguous(self) && if (CudaNdarray_is_c_contiguous(self) &&
...@@ -2831,6 +2847,7 @@ int CudaNdarray_CopyFromCudaNdarray(CudaNdarray * self, ...@@ -2831,6 +2847,7 @@ int CudaNdarray_CopyFromCudaNdarray(CudaNdarray * self,
cublasScopy(size, CudaNdarray_DEV_DATA(other), 1, cublasScopy(size, CudaNdarray_DEV_DATA(other), 1,
CudaNdarray_DEV_DATA(self), 1); CudaNdarray_DEV_DATA(self), 1);
CNDA_THREAD_SYNC; CNDA_THREAD_SYNC;
Py_XDECREF(new_other);
if (CUBLAS_STATUS_SUCCESS != cublasGetError()) if (CUBLAS_STATUS_SUCCESS != cublasGetError())
{ {
PyErr_SetString(PyExc_RuntimeError, "Error copying memory"); PyErr_SetString(PyExc_RuntimeError, "Error copying memory");
...@@ -2868,6 +2885,7 @@ int CudaNdarray_CopyFromCudaNdarray(CudaNdarray * self, ...@@ -2868,6 +2885,7 @@ int CudaNdarray_CopyFromCudaNdarray(CudaNdarray * self,
"Cuda error: %s: %s. (n_blocks=%i," "Cuda error: %s: %s. (n_blocks=%i,"
" n_threads_per_block=%i)\n", "k_copy_1d", " n_threads_per_block=%i)\n", "k_copy_1d",
cudaGetErrorString(err), n_blocks, n_threads); cudaGetErrorString(err), n_blocks, n_threads);
Py_XDECREF(new_other);
return -1; return -1;
} }
}; break; }; break;
...@@ -2912,10 +2930,12 @@ int CudaNdarray_CopyFromCudaNdarray(CudaNdarray * self, ...@@ -2912,10 +2930,12 @@ int CudaNdarray_CopyFromCudaNdarray(CudaNdarray * self,
"k_elemwise_unary_rowmajor_copy", "k_elemwise_unary_rowmajor_copy",
cudaGetErrorString(err), n_blocks, cudaGetErrorString(err), n_blocks,
threads_per_block); threads_per_block);
Py_XDECREF(new_other);
return -1; return -1;
} }
} }
}; };
Py_XDECREF(new_other);
return 0; return 0;
} }
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论