提交 32203693 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Only perform expensive check if the cast is unsafe

上级 19d49319
...@@ -6552,9 +6552,11 @@ class AdvancedSubtensor1(Op): ...@@ -6552,9 +6552,11 @@ class AdvancedSubtensor1(Op):
# many elements on a 32-bit arch). # many elements on a 32-bit arch).
if i.dtype != numpy.intp: if i.dtype != numpy.intp:
i_ = theano._asarray(i, dtype=numpy.intp) i_ = theano._asarray(i, dtype=numpy.intp)
if not numpy.can_cast(i.dtype, numpy.intp):
# Check if there was actually an incorrect conversion
if numpy.any(i != i_): if numpy.any(i != i_):
raise IndexError('index contains values that are bigger than ' raise IndexError('index contains values that are bigger '
'the maximum array size on this system.', i) 'than the maximum array size on this system.', i)
i = i_ i = i_
out[0] = x.take(i, axis=0, out=o) out[0] = x.take(i, axis=0, out=o)
...@@ -6603,6 +6605,8 @@ class AdvancedSubtensor1(Op): ...@@ -6603,6 +6605,8 @@ class AdvancedSubtensor1(Op):
if (PyArray_TYPE(%(i_name)s) != NPY_INTP) { if (PyArray_TYPE(%(i_name)s) != NPY_INTP) {
// Cast %(i_name)s to NPY_INTP (expected by PyArray_TakeFrom), // Cast %(i_name)s to NPY_INTP (expected by PyArray_TakeFrom),
// if all values fit. // if all values fit.
if (!PyArray_CanCastSafely(PyArray_TYPE(%(i_name)s), NPY_INTP))
{
PyObject* py_min_val, py_max_val; PyObject* py_min_val, py_max_val;
npy_int64 min_val, max_val; npy_int64 min_val, max_val;
py_min_val = PyArray_Min(%(i_name)s, NPY_MAXDIMS, min_val); py_min_val = PyArray_Min(%(i_name)s, NPY_MAXDIMS, min_val);
...@@ -6619,6 +6623,7 @@ class AdvancedSubtensor1(Op): ...@@ -6619,6 +6623,7 @@ class AdvancedSubtensor1(Op):
"size on this system."); "size on this system.");
%(fail)s; %(fail)s;
} }
}
indices = PyArray_Cast(%(i_name)s, NPY_INTP); indices = PyArray_Cast(%(i_name)s, NPY_INTP);
if (indices == NULL) { if (indices == NULL) {
%(fail)s; %(fail)s;
...@@ -6659,7 +6664,7 @@ class AdvancedSubtensor1(Op): ...@@ -6659,7 +6664,7 @@ class AdvancedSubtensor1(Op):
""" % locals() """ % locals()
def c_code_cache_version(self): def c_code_cache_version(self):
return (0, 0, 4) return (0, 0, 5)
advanced_subtensor1 = AdvancedSubtensor1() advanced_subtensor1 = AdvancedSubtensor1()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论