提交 bc1320a5 authored 作者: Frederic's avatar Frederic

Allow a_cuda_ndarray.strides = [...]

上级 09017085
...@@ -2391,22 +2391,37 @@ CudaNdarray_get_strides(CudaNdarray *self, void *closure) ...@@ -2391,22 +2391,37 @@ CudaNdarray_get_strides(CudaNdarray *self, void *closure)
static int static int
CudaNdarray_set_strides(CudaNdarray *self, PyObject *value, void *closure) CudaNdarray_set_strides(CudaNdarray *self, PyObject *value, void *closure)
{ {
if (!PyTuple_Check(value)){ //npy_intp newstrides_bytes[PyTuple_Size(value)];
PyErr_SetString(PyExc_ValueError, if (PyTuple_Check(value)){
"The new strides need to be encoded in a tupe"); if (PyTuple_Size(value) != CudaNdarray_NDIM(self)){
return -1; PyErr_SetString(PyExc_ValueError,
} "The new strides tuple must have the same lenght"
if (PyTuple_Size(value) != CudaNdarray_NDIM(self)){ " as the number of dimensions");
return -1;
}
}else if (PyList_Check(value)){
if (PyList_Size(value) != CudaNdarray_NDIM(self)){
PyErr_SetString(PyExc_ValueError,
"The new strides tuple must have the same lenght"
" as the number of dimensions");
return -1;
}
}else{
PyErr_SetString(PyExc_ValueError, PyErr_SetString(PyExc_ValueError,
"The new strides tuple must have the same lenght" "The new strides need to be encoded in a tuple or list");
" as the number of dimensions");
return -1; return -1;
} }
npy_intp newstrides[PyTuple_Size(value)]; npy_intp newstrides[CudaNdarray_NDIM(self)];
//npy_intp newstrides_bytes[PyTuple_Size(value)]; if (PyTuple_Check(value)){
for(int i=0; i < CudaNdarray_NDIM(self); i++){ for(int i=0; i < CudaNdarray_NDIM(self); i++){
newstrides[i] = PyInt_AsLong(PyTuple_GetItem(value, Py_ssize_t(i))); newstrides[i] = PyInt_AsLong(PyTuple_GetItem(value, Py_ssize_t(i)));
//newstrides_bytes[i] = newstrides[i] * 4; //newstrides_bytes[i] = newstrides[i] * 4;
}
}else if (PyList_Check(value)){
for(int i=0; i < CudaNdarray_NDIM(self); i++){
newstrides[i] = PyInt_AsLong(PyList_GetItem(value, Py_ssize_t(i)));
//newstrides_bytes[i] = newstrides[i] * 4;
}
} }
/* /*
// Don't do the check as ExtractDiag need that and NumPy seam to don't do // Don't do the check as ExtractDiag need that and NumPy seam to don't do
......
...@@ -944,12 +944,16 @@ def test_base(): ...@@ -944,12 +944,16 @@ def test_base():
def test_set_strides(): def test_set_strides():
a = cuda_ndarray.CudaNdarray.zeros((5, 5)) a = cuda_ndarray.CudaNdarray.zeros((5, 5))
a.strides = (a.strides[1], a.strides[0])
try: # Test with tuple
a.strides = [a.strides[1], a.strides[0]] new_strides = (a.strides[1], a.strides[0])
assert False a.strides = new_strides
except ValueError: assert a.strides == new_strides
pass
# Test with list
new_strides = (a.strides[1], a.strides[0])
a.strides = [a.strides[1], a.strides[0]]
assert a.strides == new_strides
try: try:
a.strides = (a.strides[1],) a.strides = (a.strides[1],)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论