提交 17460952 authored 作者: fsavard's avatar fsavard

Simplified zeros in cudandarray

上级 9b441ceb
...@@ -268,9 +268,7 @@ PyObject* CudaNdarray_Zeros(PyObject* dummy, PyObject* shape) ...@@ -268,9 +268,7 @@ PyObject* CudaNdarray_Zeros(PyObject* dummy, PyObject* shape)
return NULL; return NULL;
} }
//fprintf(stdout, "Pattern length: %d\n", shplen); int* newdims = (int *)malloc(sizeof(int) * shplen);
int* newdims = (int *)malloc(sizeof(int) * 2 * shplen);
if (!newdims) if (!newdims)
{ {
...@@ -279,10 +277,8 @@ PyObject* CudaNdarray_Zeros(PyObject* dummy, PyObject* shape) ...@@ -279,10 +277,8 @@ PyObject* CudaNdarray_Zeros(PyObject* dummy, PyObject* shape)
return NULL; return NULL;
} }
int* newstrides = newdims + shplen;
// strides are in number of floats, not bytes // strides are in number of floats, not bytes
int cur_stride = 1; int total_elements = 1;
// start from the end to compute strides // start from the end to compute strides
for (int i = shplen-1; i >= 0; --i) for (int i = shplen-1; i >= 0; --i)
...@@ -305,25 +301,13 @@ PyObject* CudaNdarray_Zeros(PyObject* dummy, PyObject* shape) ...@@ -305,25 +301,13 @@ PyObject* CudaNdarray_Zeros(PyObject* dummy, PyObject* shape)
return NULL; return NULL;
} }
// based on alloc_contiguous, we set
// stride=0 if the dim == 1
if (shp_el == 1)
{
// broadcast
newdims[i] = 1;
newstrides[i] = 0;
}
else
{
newdims[i] = shp_el; newdims[i] = shp_el;
newstrides[i] = cur_stride;
}
cur_stride *= newdims[i]; total_elements *= newdims[i];
} }
// cur_stride now contains the size of the array, in reals // total_elements now contains the size of the array, in reals
int total_size = cur_stride * sizeof(real); int total_size = total_elements * sizeof(real);
CudaNdarray* rval = (CudaNdarray*)CudaNdarray_new_null(); CudaNdarray* rval = (CudaNdarray*)CudaNdarray_new_null();
if (!rval) if (!rval)
...@@ -345,8 +329,7 @@ PyObject* CudaNdarray_Zeros(PyObject* dummy, PyObject* shape) ...@@ -345,8 +329,7 @@ PyObject* CudaNdarray_Zeros(PyObject* dummy, PyObject* shape)
//fprintf(stdout, "Sizeof: %d\n", total_size); //fprintf(stdout, "Sizeof: %d\n", total_size);
if (cudaSuccess != cudaMemset(rval->devdata, 0, total_size)) if (cudaSuccess != cudaMemset(rval->devdata, 0, total_size))
{ {
fprintf(stderr, "Error memsetting %d bytes of device memory.\n", cur_stride); PyErr_Format(PyExc_MemoryError, "Error memsetting %d bytes of device memory.", total_size);
PyErr_Format(PyExc_MemoryError, "Error memsetting %d bytes of device memory.", cur_stride);
free(newdims); free(newdims);
Py_DECREF(rval); Py_DECREF(rval);
return NULL; return NULL;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论