提交 4ae437db authored 作者: James Bergstra's avatar James Bergstra

cuda - Made CudaNdarray __getitem__ handle int-like objects in addition to true

integers. Obvious example - numpy.int* and numpy.uint* objects.
上级 3bc48cc9
......@@ -956,21 +956,26 @@ CudaNdarray_Subscript(PyObject * py_self, PyObject * key)
CudaNdarray * self = (CudaNdarray*) py_self;
PyObject * py_rval = NULL;
CudaNdarray * rval = NULL;
PyObject * intobj = NULL;
//PyObject_Print(key, stderr, 0);
if (key == Py_Ellipsis)
{
Py_INCREF(py_self);
return py_self;
}
else if (PyInt_Check(key)) //INDEXING BY INTEGER
if ((intobj=PyNumber_Int(key))) //INDEXING BY INTEGER
//else if (PyInt_Check(key)) //INDEXING BY INTEGER
{
int d_idx = PyInt_AsLong(intobj);
Py_DECREF(intobj); intobj=NULL;
//int d_idx = PyInt_AsLong(key);
if (self->nd == 0)
{
PyErr_SetString(PyExc_NotImplementedError, "index into 0-d array");
return NULL;
}
int d_idx = PyInt_AsLong(key);
int d_dim = CudaNdarray_HOST_DIMS(self)[0];
int offset = 0;
......@@ -1009,7 +1014,11 @@ CudaNdarray_Subscript(PyObject * py_self, PyObject * key)
CudaNdarray_set_dim(rval, d-1, CudaNdarray_HOST_DIMS(self)[d]);
}
}
else if (PySlice_Check(key)) //INDEXING BY SLICE
else
{
PyErr_Clear();
}
if (PySlice_Check(key)) //INDEXING BY SLICE
{
if (self->nd == 0)
{
......@@ -1057,7 +1066,7 @@ CudaNdarray_Subscript(PyObject * py_self, PyObject * key)
CudaNdarray_set_dim(rval, d, CudaNdarray_HOST_DIMS(self)[d]);
}
}
else if (PyTuple_Check(key)) //INDEXING BY TUPLE
if (PyTuple_Check(key)) //INDEXING BY TUPLE
{
//elements of the tuple can be either integers or slices
//the dimensionality of the view we will return is diminished for each slice in the tuple
......@@ -1127,9 +1136,11 @@ CudaNdarray_Subscript(PyObject * py_self, PyObject * key)
}
++rval_d;
}
else if (PyInt_Check(key_d))
else if ((intobj=PyNumber_Int(key_d)))
{
int d_idx = PyInt_AsLong(key_d);
int d_idx = PyInt_AsLong(intobj);
Py_DECREF(intobj);
intobj = NULL;
int d_dim = CudaNdarray_HOST_DIMS(self)[d];
if ((d_idx >= 0) && (d_idx < d_dim))
......@@ -1151,6 +1162,7 @@ CudaNdarray_Subscript(PyObject * py_self, PyObject * key)
}
else
{
PyErr_Clear(); // clear the error set by PyNumber_Int
PyErr_SetString(PyExc_IndexError, "index must be either int or slice");
Py_DECREF(rval);
return NULL;
......@@ -1158,16 +1170,16 @@ CudaNdarray_Subscript(PyObject * py_self, PyObject * key)
}
}
}
else
{
PyErr_SetString(PyExc_NotImplementedError, "Unknown key type");
return NULL;
}
if (py_rval)
{
if (verbose) fprint_CudaNdarray(stderr, self);
if (verbose) fprint_CudaNdarray(stderr, rval);
}
else
{
PyErr_SetString(PyExc_NotImplementedError, "Unknown key type");
return NULL;
}
return py_rval;
}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论