提交 5e6ce10a authored 作者: Pascal Lamblin's avatar Pascal Lamblin 提交者: David Warde-Farley

Ensure that indices are suitable for numpy.take.

On a 32-bit machine, numpy.take may refuse to use int64 arrays for indices. (incl. typo fix by @dwf)
上级 9bd1bff5
...@@ -39,16 +39,18 @@ def _asarray(a, dtype, order=None): ...@@ -39,16 +39,18 @@ def _asarray(a, dtype, order=None):
if rval.dtype.num != dtype.num: if rval.dtype.num != dtype.num:
# Type mismatch between the data type we asked for, and the one # Type mismatch between the data type we asked for, and the one
# returned by numpy.asarray. # returned by numpy.asarray.
if (dtype.num == numpy.dtype(numpy.int32).num or # If both types have the same string description (byte order, basic
dtype.num == numpy.dtype(numpy.int64).num): # type, and number of bytes), then it is safe to return a view.
if (dtype.str == rval.dtype.str):
# Silent fix. # Silent fix.
return rval.view(dtype=dtype) return rval.view(dtype=dtype)
else: else:
# Unexpected mismatch: better know what is going on! # Unexpected mismatch: better know what is going on!
raise TypeError('numpy.array did not return the data type we ' raise TypeError('numpy.array did not return the data type we '
'asked for (%s #%s), instead it returned type %s #%s: function ' 'asked for (%s %s #%s), instead it returned type '
'theano._asarray may need to be extended to handle this ' '%s %s #%s: function '
'data type as well.' % 'theano._asarray may need to be modified to handle this '
(dtype, dtype.num, rval.dtype, rval.dtype.num)) 'data type.' %
(dtype, dtype.str, dtype.num, rval.dtype, rval.str, rval.dtype.num))
else: else:
return rval return rval
...@@ -4909,8 +4909,13 @@ class AdvancedSubtensor1(Op): ...@@ -4909,8 +4909,13 @@ class AdvancedSubtensor1(Op):
else: else:
o = None o = None
# I have read that using clip or wrap mode make it a lot faster when # If i.dtype is more precise than numpy.intc (int32 on 32-bit machines,
# the output is provided # int64 on 64-bit machines), numpy may raise the following error:
# TypeError: array cannot be safely cast to required type.
# Since we will probably not have an array with more than 2**31 items
# on a 32-bit arch, I suppose it is safe to cast i into intc.
i = theano._asarray(i, dtype=numpy.intc)
out[0] = x.take(i, axis=0, out=o) out[0] = x.take(i, axis=0, out=o)
def grad(self, inputs, grads): def grad(self, inputs, grads):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论