提交 7258afc9 authored 作者: abalkin's avatar abalkin

Issue #1185 (Implemented c_code for AdvancedSubtensor1): Do not cast if indices…

Issue #1185 (Implemented c_code for AdvancedSubtensor1): Do not cast if indices array already has correct dtype.
上级 a3282db2
......@@ -6549,7 +6549,8 @@ class AdvancedSubtensor1(Op):
# TypeError: array cannot be safely cast to required type.
# Since we will probably not have an array with more than 2**31 items
# on a 32-bit arch, I suppose it is safe to cast i into intp.
i = theano._asarray(i, dtype=numpy.intp)
if i.dtype != numpy.intp:
i = theano._asarray(i, dtype=numpy.intp)
out[0] = x.take(i, axis=0, out=o)
......@@ -6594,11 +6595,17 @@ class AdvancedSubtensor1(Op):
fail = sub['fail']
return """
PyObject *indices;
// This cast makes c_code mimic the logic in perform(), but
// also makes theano code less safe.
indices = PyArray_Cast(%(i_name)s, NPY_INTP);
if (indices == NULL) {
%(fail)s;
if (PyArray_TYPE(%(i_name)s) != NPY_INTP) {
// This cast makes c_code mimic the logic in perform(), but
// also makes theano code less safe.
indices = PyArray_Cast(%(i_name)s, NPY_INTP);
if (indices == NULL) {
%(fail)s;
}
}
else {
indices = (PyObject *)%(i_name)s;
Py_INCREF(indices);
}
if (%(output_name)s != NULL) {
npy_intp nd, i, *shape;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论