提交 a36d54ae authored 作者: abalkin's avatar abalkin

Added error checking code.

上级 012ff768
...@@ -6611,21 +6611,35 @@ class AdvancedSubtensor1(Op): ...@@ -6611,21 +6611,35 @@ class AdvancedSubtensor1(Op):
fail = sub['fail'] fail = sub['fail']
return """ return """
PyObject *indices; PyObject *indices;
if (PyArray_TYPE(%(i_name)s) != NPY_INTP) { int i_type = PyArray_TYPE(%(i_name)s);
if (i_type != NPY_INTP) {
// Cast %(i_name)s to NPY_INTP (expected by PyArray_TakeFrom), // Cast %(i_name)s to NPY_INTP (expected by PyArray_TakeFrom),
// if all values fit. // if all values fit.
if (!PyArray_CanCastSafely(PyArray_TYPE(%(i_name)s), NPY_INTP)) if (!PyTypeNum_ISINTEGER(i_type)) {
{ PyErr_SetString(PyExc_TypeError, "Index must be an integer tensor.");
%(fail)s;
}
if (!PyArray_CanCastSafely(i_type, NPY_INTP)) {
npy_int64 min_val, max_val; npy_int64 min_val, max_val;
PyObject* py_min_val = PyArray_Min(%(i_name)s, NPY_MAXDIMS, NULL); PyObject* py_min_val = PyArray_Min(%(i_name)s, NPY_MAXDIMS, NULL);
PyObject* py_max_val = PyArray_Max(%(i_name)s, NPY_MAXDIMS, NULL); if (py_min_val == NULL) {
%(fail)s;
}
min_val = PyLong_AsLongLong(py_min_val); min_val = PyLong_AsLongLong(py_min_val);
Py_DECREF(py_min_val);
if (min_val == -1 && PyErr_Occurred()) {
%(fail)s;
}
PyObject* py_max_val = PyArray_Max(%(i_name)s, NPY_MAXDIMS, NULL);
if (py_max_val == NULL) {
%(fail)s;
}
max_val = PyLong_AsLongLong(py_max_val); max_val = PyLong_AsLongLong(py_max_val);
Py_CLEAR(py_min_val); Py_DECREF(py_max_val);
Py_CLEAR(py_max_val); if (max_val == -1 && PyErr_Occurred()) {
%(fail)s;
if ((min_val < NPY_MIN_INTP) || (max_val > NPY_MAX_INTP)) }
{ if (min_val < NPY_MIN_INTP || max_val > NPY_MAX_INTP) {
PyErr_SetString(PyExc_IndexError, "Index contains values " PyErr_SetString(PyExc_IndexError, "Index contains values "
"that are bigger than the maximum array " "that are bigger than the maximum array "
"size on this system."); "size on this system.");
...@@ -6672,7 +6686,7 @@ class AdvancedSubtensor1(Op): ...@@ -6672,7 +6686,7 @@ class AdvancedSubtensor1(Op):
""" % locals() """ % locals()
def c_code_cache_version(self): def c_code_cache_version(self):
return (0, 0, 6) return (0, 1, 0)
advanced_subtensor1 = AdvancedSubtensor1() advanced_subtensor1 = AdvancedSubtensor1()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论