提交 4ae437db authored 作者: James Bergstra's avatar James Bergstra

cuda - Made CudaNdarray __getitem__ handle int-like objects in addition to true

integers. Obvious example - numpy.int* and numpy.uint* objects.
上级 3bc48cc9
...@@ -956,21 +956,26 @@ CudaNdarray_Subscript(PyObject * py_self, PyObject * key) ...@@ -956,21 +956,26 @@ CudaNdarray_Subscript(PyObject * py_self, PyObject * key)
CudaNdarray * self = (CudaNdarray*) py_self; CudaNdarray * self = (CudaNdarray*) py_self;
PyObject * py_rval = NULL; PyObject * py_rval = NULL;
CudaNdarray * rval = NULL; CudaNdarray * rval = NULL;
PyObject * intobj = NULL;
//PyObject_Print(key, stderr, 0);
if (key == Py_Ellipsis) if (key == Py_Ellipsis)
{ {
Py_INCREF(py_self); Py_INCREF(py_self);
return py_self; return py_self;
} }
else if (PyInt_Check(key)) //INDEXING BY INTEGER if ((intobj=PyNumber_Int(key))) //INDEXING BY INTEGER
//else if (PyInt_Check(key)) //INDEXING BY INTEGER
{ {
int d_idx = PyInt_AsLong(intobj);
Py_DECREF(intobj); intobj=NULL;
//int d_idx = PyInt_AsLong(key);
if (self->nd == 0) if (self->nd == 0)
{ {
PyErr_SetString(PyExc_NotImplementedError, "index into 0-d array"); PyErr_SetString(PyExc_NotImplementedError, "index into 0-d array");
return NULL; return NULL;
} }
int d_idx = PyInt_AsLong(key);
int d_dim = CudaNdarray_HOST_DIMS(self)[0]; int d_dim = CudaNdarray_HOST_DIMS(self)[0];
int offset = 0; int offset = 0;
...@@ -1009,7 +1014,11 @@ CudaNdarray_Subscript(PyObject * py_self, PyObject * key) ...@@ -1009,7 +1014,11 @@ CudaNdarray_Subscript(PyObject * py_self, PyObject * key)
CudaNdarray_set_dim(rval, d-1, CudaNdarray_HOST_DIMS(self)[d]); CudaNdarray_set_dim(rval, d-1, CudaNdarray_HOST_DIMS(self)[d]);
} }
} }
else if (PySlice_Check(key)) //INDEXING BY SLICE else
{
PyErr_Clear();
}
if (PySlice_Check(key)) //INDEXING BY SLICE
{ {
if (self->nd == 0) if (self->nd == 0)
{ {
...@@ -1057,7 +1066,7 @@ CudaNdarray_Subscript(PyObject * py_self, PyObject * key) ...@@ -1057,7 +1066,7 @@ CudaNdarray_Subscript(PyObject * py_self, PyObject * key)
CudaNdarray_set_dim(rval, d, CudaNdarray_HOST_DIMS(self)[d]); CudaNdarray_set_dim(rval, d, CudaNdarray_HOST_DIMS(self)[d]);
} }
} }
else if (PyTuple_Check(key)) //INDEXING BY TUPLE if (PyTuple_Check(key)) //INDEXING BY TUPLE
{ {
//elements of the tuple can be either integers or slices //elements of the tuple can be either integers or slices
//the dimensionality of the view we will return is diminished for each slice in the tuple //the dimensionality of the view we will return is diminished for each slice in the tuple
...@@ -1127,9 +1136,11 @@ CudaNdarray_Subscript(PyObject * py_self, PyObject * key) ...@@ -1127,9 +1136,11 @@ CudaNdarray_Subscript(PyObject * py_self, PyObject * key)
} }
++rval_d; ++rval_d;
} }
else if (PyInt_Check(key_d)) else if ((intobj=PyNumber_Int(key_d)))
{ {
int d_idx = PyInt_AsLong(key_d); int d_idx = PyInt_AsLong(intobj);
Py_DECREF(intobj);
intobj = NULL;
int d_dim = CudaNdarray_HOST_DIMS(self)[d]; int d_dim = CudaNdarray_HOST_DIMS(self)[d];
if ((d_idx >= 0) && (d_idx < d_dim)) if ((d_idx >= 0) && (d_idx < d_dim))
...@@ -1151,6 +1162,7 @@ CudaNdarray_Subscript(PyObject * py_self, PyObject * key) ...@@ -1151,6 +1162,7 @@ CudaNdarray_Subscript(PyObject * py_self, PyObject * key)
} }
else else
{ {
PyErr_Clear(); // clear the error set by PyNumber_Int
PyErr_SetString(PyExc_IndexError, "index must be either int or slice"); PyErr_SetString(PyExc_IndexError, "index must be either int or slice");
Py_DECREF(rval); Py_DECREF(rval);
return NULL; return NULL;
...@@ -1158,16 +1170,16 @@ CudaNdarray_Subscript(PyObject * py_self, PyObject * key) ...@@ -1158,16 +1170,16 @@ CudaNdarray_Subscript(PyObject * py_self, PyObject * key)
} }
} }
} }
else
{
PyErr_SetString(PyExc_NotImplementedError, "Unknown key type");
return NULL;
}
if (py_rval) if (py_rval)
{ {
if (verbose) fprint_CudaNdarray(stderr, self); if (verbose) fprint_CudaNdarray(stderr, self);
if (verbose) fprint_CudaNdarray(stderr, rval); if (verbose) fprint_CudaNdarray(stderr, rval);
} }
else
{
PyErr_SetString(PyExc_NotImplementedError, "Unknown key type");
return NULL;
}
return py_rval; return py_rval;
} }
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论