提交 bc1320a5 authored 作者: Frederic's avatar Frederic

Allow a_cuda_ndarray.strides = [...]

上级 09017085
......@@ -2391,23 +2391,38 @@ CudaNdarray_get_strides(CudaNdarray *self, void *closure)
static int
CudaNdarray_set_strides(CudaNdarray *self, PyObject *value, void *closure)
{
if (!PyTuple_Check(value)){
//npy_intp newstrides_bytes[PyTuple_Size(value)];
if (PyTuple_Check(value)){
if (PyTuple_Size(value) != CudaNdarray_NDIM(self)){
PyErr_SetString(PyExc_ValueError,
"The new strides need to be encoded in a tupe");
"The new strides tuple must have the same lenght"
" as the number of dimensions");
return -1;
}
if (PyTuple_Size(value) != CudaNdarray_NDIM(self)){
}else if (PyList_Check(value)){
if (PyList_Size(value) != CudaNdarray_NDIM(self)){
PyErr_SetString(PyExc_ValueError,
"The new strides tuple must have the same lenght"
" as the number of dimensions");
return -1;
}
npy_intp newstrides[PyTuple_Size(value)];
//npy_intp newstrides_bytes[PyTuple_Size(value)];
}else{
PyErr_SetString(PyExc_ValueError,
"The new strides need to be encoded in a tuple or list");
return -1;
}
npy_intp newstrides[CudaNdarray_NDIM(self)];
if (PyTuple_Check(value)){
for(int i=0; i < CudaNdarray_NDIM(self); i++){
newstrides[i] = PyInt_AsLong(PyTuple_GetItem(value, Py_ssize_t(i)));
//newstrides_bytes[i] = newstrides[i] * 4;
}
}else if (PyList_Check(value)){
for(int i=0; i < CudaNdarray_NDIM(self); i++){
newstrides[i] = PyInt_AsLong(PyList_GetItem(value, Py_ssize_t(i)));
//newstrides_bytes[i] = newstrides[i] * 4;
}
}
/*
// Don't do the check as ExtractDiag need that and NumPy seam to don't do
// it.
......
......@@ -944,12 +944,16 @@ def test_base():
def test_set_strides():
a = cuda_ndarray.CudaNdarray.zeros((5, 5))
a.strides = (a.strides[1], a.strides[0])
try:
# Test with tuple
new_strides = (a.strides[1], a.strides[0])
a.strides = new_strides
assert a.strides == new_strides
# Test with list
new_strides = (a.strides[1], a.strides[0])
a.strides = [a.strides[1], a.strides[0]]
assert False
except ValueError:
pass
assert a.strides == new_strides
try:
a.strides = (a.strides[1],)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论