提交 967ba312 authored 作者: abalkin's avatar abalkin

Added an explicit cast in c_code for AdvancedSubtensor1 Op to match logic in…

Added an explicit cast in c_code for AdvancedSubtensor1 Op to match logic in perform() and avoid test failures on 32 bit platforms. The real issue seem to be that x.shape is a 64-bit tensor regardless of the size of platform's size_t type. This leads to int64 type appearance in any computation involving shape and resulting index cannot be passed back to numpy.take() necessitating the cast.
上级 856c24a6
......@@ -6559,23 +6559,30 @@ class AdvancedSubtensor1(Op):
output_name = output_names[0]
fail = sub['fail']
return """
PyObject *indices;
// This cast makes c_code mimic the logic in perform(), but
// also makes theano code less safe.
indices = PyArray_Cast(%(i_name)s, NPY_INTP);
if (indices == NULL) {
%(fail)s;
}
if (%(output_name)s != NULL) {
npy_intp nd, i, *shape;
nd = PyArray_NDIM(%(a_name)s) + PyArray_NDIM(%(i_name)s) - 1;
nd = PyArray_NDIM(%(a_name)s) + PyArray_NDIM(indices) - 1;
if (PyArray_NDIM(%(output_name)s) != nd) {
Py_CLEAR(%(output_name)s);
}
else {
shape = PyArray_DIMS(%(output_name)s);
for (i = 0; i < PyArray_NDIM(%(i_name)s); i++) {
if (shape[i] != PyArray_DIMS(%(i_name)s)[i]) {
for (i = 0; i < PyArray_NDIM(indices); i++) {
if (shape[i] != PyArray_DIMS(indices)[i]) {
Py_CLEAR(%(output_name)s);
break;
}
}
if (%(output_name)s != NULL) {
for (; i < nd; i++) {
if (shape[i] != PyArray_DIMS(%(a_name)s)[i-PyArray_NDIM(%(i_name)s)+1]) {
if (shape[i] != PyArray_DIMS(%(a_name)s)[i-PyArray_NDIM(indices)+1]) {
Py_CLEAR(%(output_name)s);
break;
}
......@@ -6583,13 +6590,14 @@ class AdvancedSubtensor1(Op):
}
}
}
%(output_name)s = (PyArrayObject*)PyArray_TakeFrom(%(a_name)s, (PyObject*)%(i_name)s, 0,
%(output_name)s = (PyArrayObject*)PyArray_TakeFrom(%(a_name)s, indices, 0,
%(output_name)s, NPY_RAISE);
Py_DECREF(indices);
if (%(output_name)s == NULL) %(fail)s;
""" % locals()
def c_code_cache_version(self):
return (0, 0, 2)
return (0, 0, 3)
advanced_subtensor1 = AdvancedSubtensor1()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论