提交 bc1320a5 authored 作者: Frederic's avatar Frederic

Allow a_cuda_ndarray.strides = [...]

上级 09017085
...@@ -2391,23 +2391,38 @@ CudaNdarray_get_strides(CudaNdarray *self, void *closure) ...@@ -2391,23 +2391,38 @@ CudaNdarray_get_strides(CudaNdarray *self, void *closure)
static int static int
CudaNdarray_set_strides(CudaNdarray *self, PyObject *value, void *closure) CudaNdarray_set_strides(CudaNdarray *self, PyObject *value, void *closure)
{ {
if (!PyTuple_Check(value)){ //npy_intp newstrides_bytes[PyTuple_Size(value)];
if (PyTuple_Check(value)){
if (PyTuple_Size(value) != CudaNdarray_NDIM(self)){
PyErr_SetString(PyExc_ValueError, PyErr_SetString(PyExc_ValueError,
"The new strides need to be encoded in a tupe"); "The new strides tuple must have the same lenght"
" as the number of dimensions");
return -1; return -1;
} }
if (PyTuple_Size(value) != CudaNdarray_NDIM(self)){ }else if (PyList_Check(value)){
if (PyList_Size(value) != CudaNdarray_NDIM(self)){
PyErr_SetString(PyExc_ValueError, PyErr_SetString(PyExc_ValueError,
"The new strides tuple must have the same lenght" "The new strides tuple must have the same lenght"
" as the number of dimensions"); " as the number of dimensions");
return -1; return -1;
} }
npy_intp newstrides[PyTuple_Size(value)]; }else{
//npy_intp newstrides_bytes[PyTuple_Size(value)]; PyErr_SetString(PyExc_ValueError,
"The new strides need to be encoded in a tuple or list");
return -1;
}
npy_intp newstrides[CudaNdarray_NDIM(self)];
if (PyTuple_Check(value)){
for(int i=0; i < CudaNdarray_NDIM(self); i++){ for(int i=0; i < CudaNdarray_NDIM(self); i++){
newstrides[i] = PyInt_AsLong(PyTuple_GetItem(value, Py_ssize_t(i))); newstrides[i] = PyInt_AsLong(PyTuple_GetItem(value, Py_ssize_t(i)));
//newstrides_bytes[i] = newstrides[i] * 4; //newstrides_bytes[i] = newstrides[i] * 4;
} }
}else if (PyList_Check(value)){
for(int i=0; i < CudaNdarray_NDIM(self); i++){
newstrides[i] = PyInt_AsLong(PyList_GetItem(value, Py_ssize_t(i)));
//newstrides_bytes[i] = newstrides[i] * 4;
}
}
/* /*
// Don't do the check as ExtractDiag need that and NumPy seam to don't do // Don't do the check as ExtractDiag need that and NumPy seam to don't do
// it. // it.
......
...@@ -944,12 +944,16 @@ def test_base(): ...@@ -944,12 +944,16 @@ def test_base():
def test_set_strides(): def test_set_strides():
a = cuda_ndarray.CudaNdarray.zeros((5, 5)) a = cuda_ndarray.CudaNdarray.zeros((5, 5))
a.strides = (a.strides[1], a.strides[0])
try: # Test with tuple
new_strides = (a.strides[1], a.strides[0])
a.strides = new_strides
assert a.strides == new_strides
# Test with list
new_strides = (a.strides[1], a.strides[0])
a.strides = [a.strides[1], a.strides[0]] a.strides = [a.strides[1], a.strides[0]]
assert False assert a.strides == new_strides
except ValueError:
pass
try: try:
a.strides = (a.strides[1],) a.strides = (a.strides[1],)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论