提交 3de42af7 authored 作者: Frederic's avatar Frederic

We add dimensions to CudaNdarray to automatically broadcast before copy.

This affect all the function that call this inner function.
上级 697d3eb7
......@@ -2790,15 +2790,29 @@ int CudaNdarray_CopyFromCudaNdarray(CudaNdarray * self,
"can't copy into un-initialized CudaNdarray");
return -1;
}
if (self->nd != other->nd)
CudaNdarray * new_other = NULL;
if (self->nd < other->nd)
{
PyErr_Format(PyExc_NotImplementedError,
"CudaNdarray_CopyFromCudaNdarray: need same number of"
" dims. destination nd=%d, source nd=%d."
" No broadcasting implemented.",
"CudaNdarray_CopyFromCudaNdarray: The destination need more or the"
" same number of dimensions then the source. Got %d and %d.",
self->nd, other->nd);
return -1;
}
else if (self->nd != other->nd)
{
CudaNdarray * new_other = (CudaNdarray *) CudaNdarray_View(other);
int added_dims = self->nd - other->nd;
int pattern[self->nd];
for(int i = 0; i < added_dims; i++)
pattern[i] = -1;
for(int i = 0; i < other->nd; i++)
pattern[i + added_dims] = i;
CudaNdarray_dimshuffle(new_other, self->nd, pattern);
other = new_other;
}
assert(self->nd == other->nd);
//standard elemwise dim checks (also compute total size)
unsigned int size = 1;
unsigned int size_source = 1;
......@@ -2812,13 +2826,15 @@ int CudaNdarray_CopyFromCudaNdarray(CudaNdarray * self,
" destination=%d, source=%d",
i, CudaNdarray_HOST_DIMS(self)[i],
CudaNdarray_HOST_DIMS(other)[i]);
return -1;
Py_XDECREF(new_other);
return -1;
}
size *= (unsigned int) CudaNdarray_HOST_DIMS(self)[i];
size_source *= (unsigned int) CudaNdarray_HOST_DIMS(other)[i];
}
if (0 == size)
{
Py_XDECREF(new_other);
return 0; //nothing to copy, we're done.
}
if (CudaNdarray_is_c_contiguous(self) &&
......@@ -2831,6 +2847,7 @@ int CudaNdarray_CopyFromCudaNdarray(CudaNdarray * self,
cublasScopy(size, CudaNdarray_DEV_DATA(other), 1,
CudaNdarray_DEV_DATA(self), 1);
CNDA_THREAD_SYNC;
Py_XDECREF(new_other);
if (CUBLAS_STATUS_SUCCESS != cublasGetError())
{
PyErr_SetString(PyExc_RuntimeError, "Error copying memory");
......@@ -2868,6 +2885,7 @@ int CudaNdarray_CopyFromCudaNdarray(CudaNdarray * self,
"Cuda error: %s: %s. (n_blocks=%i,"
" n_threads_per_block=%i)\n", "k_copy_1d",
cudaGetErrorString(err), n_blocks, n_threads);
Py_XDECREF(new_other);
return -1;
}
}; break;
......@@ -2912,10 +2930,12 @@ int CudaNdarray_CopyFromCudaNdarray(CudaNdarray * self,
"k_elemwise_unary_rowmajor_copy",
cudaGetErrorString(err), n_blocks,
threads_per_block);
Py_XDECREF(new_other);
return -1;
}
}
};
Py_XDECREF(new_other);
return 0;
}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论