提交 17460952 authored 作者: fsavard's avatar fsavard

Simplified zeros in cudandarray

上级 9b441ceb
......@@ -268,9 +268,7 @@ PyObject* CudaNdarray_Zeros(PyObject* dummy, PyObject* shape)
return NULL;
}
//fprintf(stdout, "Pattern length: %d\n", shplen);
int* newdims = (int *)malloc(sizeof(int) * 2 * shplen);
int* newdims = (int *)malloc(sizeof(int) * shplen);
if (!newdims)
{
......@@ -279,10 +277,8 @@ PyObject* CudaNdarray_Zeros(PyObject* dummy, PyObject* shape)
return NULL;
}
int* newstrides = newdims + shplen;
// strides are in number of floats, not bytes
int cur_stride = 1;
int total_elements = 1;
// start from the end to compute strides
for (int i = shplen-1; i >= 0; --i)
......@@ -305,25 +301,13 @@ PyObject* CudaNdarray_Zeros(PyObject* dummy, PyObject* shape)
return NULL;
}
// based on alloc_contiguous, we set
// stride=0 if the dim == 1
if (shp_el == 1)
{
// broadcast
newdims[i] = 1;
newstrides[i] = 0;
}
else
{
newdims[i] = shp_el;
newstrides[i] = cur_stride;
}
newdims[i] = shp_el;
cur_stride *= newdims[i];
total_elements *= newdims[i];
}
// cur_stride now contains the size of the array, in reals
int total_size = cur_stride * sizeof(real);
// total_elements now contains the size of the array, in reals
int total_size = total_elements * sizeof(real);
CudaNdarray* rval = (CudaNdarray*)CudaNdarray_new_null();
if (!rval)
......@@ -345,8 +329,7 @@ PyObject* CudaNdarray_Zeros(PyObject* dummy, PyObject* shape)
//fprintf(stdout, "Sizeof: %d\n", total_size);
if (cudaSuccess != cudaMemset(rval->devdata, 0, total_size))
{
fprintf(stderr, "Error memsetting %d bytes of device memory.\n", cur_stride);
PyErr_Format(PyExc_MemoryError, "Error memsetting %d bytes of device memory.", cur_stride);
PyErr_Format(PyExc_MemoryError, "Error memsetting %d bytes of device memory.", total_size);
free(newdims);
Py_DECREF(rval);
return NULL;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论